4.3.4 Gradient
Gradient descent is an optimization algorithm that finds the parameter values (coefficients) of a function f which minimize a cost function. In our model, the predictor scores the generated sequences, and the hidden space of the generator is optimized by backpropagating the resulting gradient.
The gradient descent algorithm can be used for hidden-space optimization of our WGAN model. A schematic diagram of the workflow is shown below [1].
Caution: The current algorithm defaults to using the WGAN generator and the CNN_K15 predictor. Please provide models you have already trained; this program will search for the most effective hidden space.
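The core idea can be summarized by the sketch below (a minimal illustration, not the GPro API): the generator and predictor are frozen, and only the hidden vectors z are updated by gradient steps on the predicted score. All function and argument names here are hypothetical.

```python
import torch

# Minimal sketch of hidden-space gradient ascent (illustrative only, not the GPro API).
# Assumes `generator` maps a batch of hidden vectors z to one-hot sequences and
# `predictor` maps sequences to scalar expression scores; both models are frozen.
def optimize_hidden_space(generator, predictor, z_dim=128, batch_size=64,
                          max_iter=1000, lr=0.01, mode="max"):
    z = torch.randn(batch_size, z_dim, requires_grad=True)   # hidden points to optimize
    optimizer = torch.optim.Adam([z], lr=lr)
    sign = -1.0 if mode == "max" else 1.0                     # Adam minimizes, so flip the sign for "max"
    for _ in range(max_iter):
        optimizer.zero_grad()
        scores = predictor(generator(z))                      # predictor scores the decoded sequences
        loss = sign * scores.mean()
        loss.backward()                                       # the gradient flows back into z only
        optimizer.step()
    return z.detach()
```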
Model and data parameters:

params | description | default value |
---|---|---|
generator_modelpath | path of the trained generator model | None |
predictor_modelpath | path of the trained predictor model | None |
natural_datapath | path of the natural sequence data | None |
savepath | directory where the final results are saved | None |
z_dim | dimension of the hidden state of the WGAN model | 128 |
seq_len | sequence length | 50 |
is_rnn | set to True if you choose "AttnBiLSTM" or another RNN-related network | False |
Iteration parameters:

params | description | default value |
---|---|---|
MaxIter | maximum number of optimization iterations | 1000 |
MaxPoolsize | number of sequences kept in the final result pool | 2000 |
learning_rate | learning rate of the hidden-space optimization | 0.01 |
mode | whether to maximize ("max") or minimize ("min") the predicted score | "max" |
Before executing the optimizer, you should have trained both a generator and a predictor.
A simple demo works like this:
```python
import os
from gpro.optimizer.model_driven.gradient import GradientAlgorithm

# (1) define the generator
default_root = "your working directory"
generator_modelpath = os.path.join(default_root, 'checkpoints/wgan/checkpoints/net_G_12.pth')

# (2) define the predictor
from gpro.predictor.cnn_k15.cnn_k15 import CNN_K15_language
predictor = CNN_K15_language(length=50)
predictor_modelpath = os.path.join(default_root, 'checkpoints/cnn_k15/checkpoint.pth')

# (3) select the highly expressed sequences
natural_datapath = os.path.join(default_root, 'data/diffusion_prediction/seq.txt')

tmp = GradientAlgorithm(predictor=predictor,
                        generator_modelpath=generator_modelpath,
                        predictor_modelpath=predictor_modelpath,
                        natural_datapath=natural_datapath,
                        savepath="./optimization/Gradient")
tmp.run()
```
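The iteration parameters listed in the second table above can presumably be overridden when calling run(); the keyword arguments below are an assumption based on that table, not a confirmed signature.

```python
# Hypothetical call: assumes the iteration parameters in the table above
# (MaxIter, MaxPoolsize, learning_rate, mode) are keyword arguments of run().
tmp.run(MaxIter=500, MaxPoolsize=1000, learning_rate=0.05, mode="max")
```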
The resulting files consist of compared_with_natural.pdf, ExpIter.txt, ExpIter.csv, and z.png:
files | description |
---|---|
compared_with_natural.pdf | box plot comparing the model-generated results with the natural results |
ExpIter.txt | FASTA file of the final result sequences |
ExpIter.csv | sequences and predicted scores of the final result sequences |
z.png | change in the first dimension of the hidden space before and after optimization |
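To inspect the optimized sequences afterwards, the output files can be read with standard tools; the paths below match the savepath used in the demo, and the exact CSV column layout is an assumption.

```python
import pandas as pd

# Load the optimized sequences and their predicted scores
# (the column names inside ExpIter.csv are an assumption).
results = pd.read_csv("./optimization/Gradient/ExpIter.csv")
print(results.head())

# The FASTA-style output can be read as plain text.
with open("./optimization/Gradient/ExpIter.txt") as f:
    print(f.read().splitlines()[:4])
```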
The box plot (compared_with_natural.pdf) and the hidden-space plot (z.png) are shown below.
[1] Amini, Alexander & Amini, Ava & Karaman, Sertac & Rus, Daniela. (2018). Spatial Uncertainty Sampling for End-to-End Control.