How to resume training with a saved model
Summary
This article describes how to save your model in the form of checkpoints and how to load them back to resume training. There is a clear need to save intermediate model states and to have a mechanism for resuming training from them. We call these intermediate model states checkpoints.
What are the contents of a Checkpoint?
- Model parameters
- Number of epochs
- Optimizer parameters: you need to save the optimizer state, especially when you are using Adam as your optimizer. Adam is an adaptive learning rate method, which means it computes individual learning rates (and running moment estimates) for each parameter; you need that state if you want to continue training exactly where you left off (see the snippet after this list).
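To see what that optimizer state actually contains, the snippet below is a minimal illustration; the toy Linear model, the dummy input, and the single training step are only there to populate Adam's state and are not part of the original article.

    import torch

    # A toy model and an Adam optimizer, just to inspect what gets stored.
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Run one dummy step so Adam populates its per-parameter state.
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()

    # 'state' holds the per-parameter running averages (exp_avg, exp_avg_sq),
    # 'param_groups' holds hyperparameters such as lr and betas.
    print(optimizer.state_dict().keys())  # dict_keys(['state', 'param_groups'])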
Saving a checkpoint in PyTorch
Note that .pt or .pth are common and recommended file extensions for saving files using PyTorch.

    import torch
    import shutil

    def save_ckp(state, is_best, checkpoint_dir, best_model_dir):
        # checkpoint_dir and best_model_dir are expected to be pathlib.Path objects.
        # Write the checkpoint dictionary to checkpoint.pt in the checkpoint directory.
        f_path = checkpoint_dir / 'checkpoint.pt'
        torch.save(state, f_path)
        # If this is the best model seen so far, keep a separate copy of it.
        if is_best:
            best_fpath = best_model_dir / 'best_model.pt'
            shutil.copyfile(f_path, best_fpath)
Then build the checkpoint dictionary and hand it to save_ckp:

    checkpoint = {
        'epoch': epoch + 1,                    # the next epoch to run when training resumes
        'state_dict': model.state_dict(),      # model weights
        'optimizer': optimizer.state_dict()    # optimizer state (e.g. Adam's moment estimates)
    }
    save_ckp(checkpoint, is_best, checkpoint_dir, model_dir)
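How is_best gets set is up to your training loop. One common pattern is to compare the current validation loss against the best value seen so far; the update_best helper and the example loss values below are hypothetical, not from the original article.

    def update_best(valid_loss, best_valid_loss):
        """Return (is_best, new_best_valid_loss) for this epoch's validation loss."""
        is_best = valid_loss < best_valid_loss
        return is_best, min(valid_loss, best_valid_loss)

    # Example call with placeholder loss values, not real training output.
    is_best, best_valid_loss = update_best(valid_loss=0.42, best_valid_loss=0.50)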
Loading a checkpoint

    def load_ckp(checkpoint_fpath, model, optimizer):
        # Load the checkpoint dictionary from disk.
        checkpoint = torch.load(checkpoint_fpath)
        # Restore the model weights and the optimizer state in place.
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Return the epoch to resume from along with the restored model and optimizer.
        return model, optimizer, checkpoint['epoch']
Resume training
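Putting the pieces together, a minimal sketch of resuming training (assuming the save_ckp/load_ckp helpers defined above) might look like this; the Linear stand-in model, ckp_path, train_loader, and n_epochs are placeholders, not part of the original article.

    import torch

    # Recreate the model and optimizer with the same architecture and
    # hyperparameters used before training stopped.
    # (torch.nn.Linear is just a stand-in for your real model.)
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()

    # Restore weights, optimizer state, and the epoch to resume from.
    # ckp_path is a placeholder path to the checkpoint saved earlier.
    ckp_path = 'checkpoints/checkpoint.pt'
    model, optimizer, start_epoch = load_ckp(ckp_path, model, optimizer)

    # Continue training from start_epoch rather than from epoch 0.
    # train_loader and n_epochs are placeholders for your own data loader and budget.
    for epoch in range(start_epoch, n_epochs):
        for data, target in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()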
In a nutshell
Basically, you first initialize your model and optimizer and then update their state dictionaries using the load_ckp function above.
Now you can simply pass this model and optimizer to your training loop, and the model resumes training from where it left off. You can confirm this by looking at the loss values after each epoch: they continue from where the previously observed epochs (before training stopped) ended.