4.2.5 AttnBiLSTM - WangLabTHU/GPro GitHub Wiki
hcwang and qxdu edited on Aug 4, 2023, 1 version
AttnBiLSTM combines multi-head attention layers with a bi-directional LSTM (BiLSTM) structure. It was proposed by Aviv Regev's group ([1]) to predict the effect of mutations in non-coding regulatory DNA sequences, that is, which regulatory mutations affect expression and fitness. The model is also suitable for long sequences (>100 bp) and can achieve good performance on them. A schematic diagram of the whole workflow has been provided.
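To make the attention half of the architecture concrete, here is a minimal sketch of generic multi-head scaled dot-product self-attention in plain Python, applied to a sequence of per-position feature vectors such as the concatenated forward/backward hidden states a BiLSTM emits. It is not GPro's implementation: the learned Q/K/V and output projections are omitted (treated as identity), and all sizes are illustrative assumptions.

```python
import math
from typing import List

Matrix = List[List[float]]  # rows = sequence positions, columns = features


def softmax(row: List[float]) -> List[float]:
    """Numerically stable softmax over one row of attention scores."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Single head: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(q[0])
    out = []
    for qi in q:
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k) for kj in k]
        weights = softmax(scores)  # attention of position i over all positions
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


def multi_head_self_attention(x: Matrix, num_heads: int) -> Matrix:
    """Split the feature dimension into heads, attend within each head, concatenate.

    Simplification (assumption): no learned projections, so Q = K = V = slice of x.
    """
    d_model = len(x[0])
    if d_model % num_heads != 0:
        raise ValueError("feature dimension must be divisible by num_heads")
    d_head = d_model // num_heads
    heads = []
    for h in range(num_heads):
        cols = slice(h * d_head, (h + 1) * d_head)
        xh = [row[cols] for row in x]
        heads.append(scaled_dot_product_attention(xh, xh, xh))
    # Concatenate the head outputs position by position.
    return [sum((head[i] for head in heads), []) for i in range(len(x))]
```

For example, `multi_head_self_attention(bilstm_states, num_heads=2)` on a list of 4-dimensional vectors (a hypothetical 2 forward + 2 backward hidden units per position) returns one 4-dimensional, context-mixed vector per position.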
To help users understand how the model processes biological sequences, we also provide a more detailed operational pipeline.
We have encapsulated the source code so that all predictive models share unified input and output parameters. There are two types of parameters: those defined during the initialization phase (Initialization), and those defined during the training/predicting phase (Training/Predicting).
params | description | default value |
---|---|---|
batch_size | training batch size | 64 |
length | sequence length of the training dataset | 50 |
model_name | name that determines the saving path under "./checkpoints" | attnbilstm |
epoch | number of training epochs | 200 |
patience | early-stopping patience: training stops after this many epochs without improvement | 50 |
log_steps | log the model's outputs/metrics every log_steps epochs | 10 |
save_steps | save the model's results every save_steps epochs | 20 |
exp_mode | the processing mode for expression input | log2 |
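The interplay of epoch, patience, log_steps and save_steps can be illustrated with a small simulation. The sketch below replays a list of per-epoch validation losses through a generic early-stopping schedule; it assumes that patience counts consecutive epochs without improvement in a validation loss, and the function name is hypothetical, not part of GPro.

```python
from typing import List, Tuple


def run_training_schedule(
    val_losses: List[float],
    epoch: int = 200,
    patience: int = 50,
    log_steps: int = 10,
    save_steps: int = 20,
) -> Tuple[int, List[int], List[int]]:
    """Hypothetical helper: replay per-epoch validation losses through an early-stopping schedule.

    Returns (last_epoch_run, epochs_logged, epochs_saved); epochs are 1-indexed.
    """
    best = float("inf")
    epochs_without_improvement = 0
    logged, saved = [], []
    last = 0
    for ep in range(1, epoch + 1):
        if ep > len(val_losses):
            break  # ran out of simulated losses
        loss = val_losses[ep - 1]
        last = ep
        if loss < best:
            best = loss
            epochs_without_improvement = 0  # improvement resets the counter
        else:
            epochs_without_improvement += 1
        if ep % log_steps == 0:
            logged.append(ep)  # would print metrics here
        if ep % save_steps == 0:
            saved.append(ep)  # would write a checkpoint here
        if epochs_without_improvement >= patience:
            break  # early stop
    return last, logged, saved
```

With losses `[5, 4, 3, 2, 1]` followed by a plateau at `1.0`, `patience=3`, `log_steps=2` and `save_steps=4`, training stops at epoch 8, well before `epoch=400` is reached.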
params | description | default value | used in |
---|---|---|---|
dataset | path of the training sequences (FASTA file) | None | train() |
labels | path of the training expression values (txt file, one expression value per line, in the same order as the sequences in dataset) | None | train() |
savepath | directory in which the final model is saved | None | train() |
model_path | path of the trained model to load | None | predict() / predict_input() |
data_path | path of the dataset to be predicted (FASTA file) | None | predict() |
inputs | data for predict_input(): a data path (mode "path"), a list of sequences (mode "data") or one-hot encoded data (mode "onehot") | None | predict_input() |
mode | input mode for predict_input(): "path", "data" or "onehot" | "path" | predict_input() |
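predict_input() also accepts one-hot encoded data. The sketch below shows one common way to produce such an encoding; the A/C/G/T channel order, the (length, 4) layout and the use of nested lists are assumptions, so check the layout GPro actually expects (for example, a NumPy array or a transposed (4, length) shape) before passing the result as inputs.

```python
from typing import List

# Assumed channel order; GPro's internal order may differ.
BASES = "ACGT"


def one_hot(seq: str) -> List[List[int]]:
    """Encode a DNA sequence as a len(seq) x 4 matrix; unknown bases such as N become all zeros."""
    index = {b: i for i, b in enumerate(BASES)}
    rows = []
    for base in seq.upper():
        row = [0, 0, 0, 0]
        if base in index:
            row[index[base]] = 1
        rows.append(row)
    return rows


def one_hot_batch(seqs: List[str], length: int) -> List[List[List[int]]]:
    """Encode several sequences, checking each matches the model's `length` parameter."""
    for s in seqs:
        if len(s) != length:
            raise ValueError(f"expected length {length}, got {len(s)}")
    return [one_hot(s) for s in seqs]
```

Checking the length up front catches sequences that do not match the length given at initialization before they reach the model.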
Caution: the predict() function writes its results directly to a file in the checkpoint path, whereas predict_input() returns the results without writing any file.
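Since predict_input() does not write a file, its results must be saved manually if needed. The sketch below assumes predict_input() returns one numeric prediction per input sequence (as a list or 1-D array); that return format and the helper name save_predictions are assumptions, not part of GPro's documented API.

```python
import csv
from typing import Iterable, Sequence


def save_predictions(seqs: Sequence[str], preds: Iterable[float], out_path: str) -> int:
    """Hypothetical helper: write sequence/prediction pairs to a tab-separated file.

    Returns the number of data rows written.
    """
    preds = [float(p) for p in preds]  # also converts NumPy scalars to plain floats
    if len(preds) != len(seqs):
        raise ValueError("number of predictions does not match number of sequences")
    with open(out_path, "w", newline="") as fh:
        writer = csv.writer(fh, delimiter="\t")
        writer.writerow(["sequence", "prediction"])
        writer.writerows(zip(seqs, preds))
    return len(preds)
```

A tab-separated file keeps sequences and predictions aligned and can be opened by most spreadsheet and dataframe tools.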
A demo for model training/predicting is described below:
```python
import os

from gpro.predictor.attnbilstm.attnbilstm import AttnBilstm_language

# Training sequences are 50 bp long
model = AttnBilstm_language(length=50, epoch=400, patience=10)

# Train
default_root = "your working directory"
dataset = os.path.join(default_root, 'data/seq.txt')  # FASTA-formatted sequences
labels = os.path.join(default_root, 'data/exp.txt')   # one expression value per line
save_path = os.path.join(default_root, 'checkpoints/')
model.train(dataset=dataset, labels=labels, savepath=save_path)

# Predict: the subdirectory "attnbilstm" comes from the default model_name;
# results are written to a file in the checkpoint path
model_path = os.path.join(default_root, "checkpoints/attnbilstm/checkpoint.pth")
data_path = os.path.join(default_root, "data/example.txt")
model.predict(model_path=model_path, data_path=data_path)

# Predict input: results are returned (default mode="path"), not written to disk
res = model.predict_input(model_path=model_path, inputs=data_path)
print(res)
```
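To try the demo without real measurements, a toy dataset in the format described by the parameter table (FASTA sequences for dataset, one expression value per line for labels, in the same order) can be generated at the paths the demo uses. This is a hypothetical helper; the header names are arbitrary, the values are random rather than real data, and positive values are used on the assumption that the default exp_mode="log2" takes a base-2 logarithm of each value.

```python
import os
import random


def write_toy_dataset(root: str, n: int = 100, length: int = 50, seed: int = 0) -> None:
    """Hypothetical helper: write random sequences to data/seq.txt (FASTA)
    and matching random values to data/exp.txt (one value per line)."""
    rng = random.Random(seed)
    data_dir = os.path.join(root, "data")
    os.makedirs(data_dir, exist_ok=True)
    with open(os.path.join(data_dir, "seq.txt"), "w") as seq_fh, \
            open(os.path.join(data_dir, "exp.txt"), "w") as exp_fh:
        for i in range(n):
            seq = "".join(rng.choice("ACGT") for _ in range(length))
            seq_fh.write(f">seq{i}\n{seq}\n")
            # Random positive stand-ins for measured expression (not real data).
            exp_fh.write(f"{rng.uniform(0.5, 10.0):.4f}\n")
```

Calling `write_toy_dataset(default_root, length=50)` matches the length=50 used when constructing the model above.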
[1] Vaishnav E D, de Boer C G, Molinet J, et al. The evolution, evolvability and engineering of gene regulatory DNA[J]. Nature, 2022, 603(7901): 455-463.