4.1.2 Diffusion - WangLabTHU/GPro GitHub Wiki
Non-autoregressive (NAR) text generation has attracted much attention in natural language processing: it greatly reduces inference latency but sacrifices generation accuracy. Recently, diffusion models, a class of latent-variable generative models, have been introduced into NAR text generation and have shown improved generation quality. However, performing diffusion on discrete data such as text remains a challenge. Hoogeboom et al. proposed Multinomial Diffusion [1], which implements noising/denoising suitable for discrete spaces via a Transformer. A schematic diagram of the diffusion process has been provided.
Here, to help users understand how the model operates on biological sequences, we provide a more detailed operational pipeline. Note that the diffusion process here is realized through a state transition matrix, and a sampling strategy similar to DDIM has been implemented.
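For intuition, below is a minimal NumPy sketch of the multinomial forward (noising) step from [1] applied to one-hot DNA tokens; the noise level beta_t and the example sequence are illustrative and are not GPro defaults:

```python
import numpy as np

# Minimal sketch of one multinomial forward (noising) step from [1],
# applied to one-hot DNA tokens (K = 4 bases). The state transition
# matrix Q_t keeps a base with probability (1 - beta_t) and otherwise
# resamples it uniformly. beta_t and the sequence are illustrative only.
K = 4
beta_t = 0.05
Q_t = (1 - beta_t) * np.eye(K) + beta_t / K        # state transition matrix

x_prev = np.eye(K)[[0, 2, 3, 1]]                   # one-hot encoding of "AGTC"
probs = x_prev @ Q_t                               # categorical probs per position
probs = probs / probs.sum(axis=1, keepdims=True)   # guard against rounding error
rng = np.random.default_rng(0)
x_t = np.stack([rng.multinomial(1, p) for p in probs])  # sampled noisier state
```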
We suggest that you define all parameters during the initialization phase. There are two types of parameters: one can only be defined during the initialization phase (Fixed), and the other can be redefined during the initialization or training/sampling phase (Flexible). In any case, a parameter can only be defined once. The Fixed parameters are listed first, followed by the Flexible parameters and the stage at which each can be redefined.
params | description | default value |
---|---|---|
batch_size | training batch size | 32 |
update_freq | the optimizer updates after every update_freq epochs | 1 |
lr | learning rate of the diffusion network | |
epochs | number of training epochs | 200 |
eval_every | perform an evaluation after every eval_every epochs | 2 |
check_every | save a checkpoint after every check_every epochs | 20 |
dataset_type | data type of the research | promoter |
length | sequence length of the training dataset | 50 |
diffusion_steps | number of diffusion steps for training or sampling | 100 |
transformer_depth | transformer layer depth | 12 |
transformer_heads | number of transformer heads | 16 |
transformer_local_heads | n_local_attn_heads of class LinearAttentionTransformerEmbedding | 8 |
transformer_local_size | local_attn_window_size of class LinearAttentionTransformerEmbedding; the sequence length should be divisible by this value | 25 |
gamma | parameter that controls the optimizer | 0.99 |
model_name | parameter that controls the saving path under "./checkpoints" | diffusion |
seed | random seed, only defined in generate() | 0 |
num_workers | number of data-loading workers; on Windows, please set this parameter to 0 | 4 |
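As a sketch of how Fixed parameters are supplied (assuming the table entries map one-to-one onto constructor keyword arguments of Diffusion_language), the values below simply restate the defaults from the table above:

```python
from gpro.generator.diffusion.diffusion import Diffusion_language

# Assumption: the Fixed parameters above are constructor keyword arguments.
# The values below simply restate the table defaults.
model = Diffusion_language(length=50, diffusion_steps=100,
                           batch_size=32, model_name="diffusion")
```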
params | description | default value | flexible stage |
---|---|---|---|
dataset | path of the training dataset | None | train() |
savepath | path for saving results | None | train() |
sample_model_path | path of the trained model | None | generate() |
sample_number | number of sequences to sample | None | generate() |
sample_output | path for saving samples | None | generate() |
A demo for model training/sampling is described below:
```python
import os

from gpro.generator.diffusion.diffusion import Diffusion_language

# model training
default_root = "your working directory"
dataset_path = os.path.join(str(default_root), 'data/sequence_data.txt')
checkpoint_path = os.path.join(str(default_root), 'checkpoints/')

model = Diffusion_language(length=50)
model.train(dataset=dataset_path, savepath=checkpoint_path)

# model sampling
sample_model_path = os.path.join(str(default_root), 'checkpoints/diffusion/check/checkpoint.pt')
sample_number = 1000
model.generate(sample_model_path, sample_number)
```
After the training step, you will find a check folder, a tb folder, and several log files under "checkpoint_path/model_name"; after you further perform sampling, you will also get a samples file that contains your generated sequences.
```
/checkpoints/diffusion/model_name
├── check
│   └── checkpoint.pt
├── samples
├── tb
├── args.pickle
├── metrics_eval.pickle
├── metrics_eval.txt
├── metrics_train.pickle
└── metrics_train.txt
```
The detailed information of these files is as follows:

- check: contains checkpoints of the diffusion model
- tb: contains the training tensorboard logs
- args.pickle: contains the parameters used for training; it will be reused in the sampling step
- metrics_eval.pickle/metrics_eval.txt: contain the model performance recorded at each evaluation
- metrics_train.pickle/metrics_train.txt: contain the model performance recorded during training
- samples: a FASTA file that contains the final result of model sampling, which can be further used for biological experiments or sequence optimization (see the sketch below)
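As a starting point for such downstream use, here is a minimal sketch for reading the generated FASTA back into Python; the file path is illustrative and should point at the samples output written by generate():

```python
# Minimal sketch for reading the generated FASTA samples back into Python.
# The path below is illustrative; point it at the samples file written by
# generate() under your checkpoint directory.
def read_fasta(path):
    sequences = {}
    name = None
    with open(path) as handle:
        for line in handle:
            line = line.strip()
            if line.startswith(">"):
                name = line[1:]
                sequences[name] = ""
            elif name is not None:
                sequences[name] += line
    return sequences

samples = read_fasta("checkpoints/diffusion/samples/your_samples.fasta")
print(len(samples), "sequences loaded")
```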
[1] Hoogeboom E, Nielsen D, Jaini P, et al. Argmax flows and multinomial diffusion: Learning categorical distributions[J]. Advances in Neural Information Processing Systems, 2021, 34: 12454-12465.