How to resume training with a saved model - SoojungHong/MachineLearning GitHub Wiki

Reference

https://medium.com/analytics-vidhya/saving-and-loading-your-model-to-resume-training-in-pytorch-cb687352fa61

Summary

This article explains how to save a model as checkpoints and how to load them back to resume training. Training can take a long time and may be interrupted (for example by a crash or a job time limit), so it is useful to save intermediate model states and to have a mechanism for resuming training from them. These intermediate model states are called checkpoints.

What are the contents of a Checkpoint?

  1. Model parameters (the model's state_dict)
  2. Epoch number (the epoch to resume from)
  3. Optimizer state: this matters especially when you use Adam. Adam is an adaptive learning-rate method that keeps running statistics for every parameter, which in effect gives each parameter its own learning rate. To continue training from exactly where you left off, these statistics must be restored along with the weights. Other optimizers with internal state, such as SGD with momentum, need this as well.
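To see what these items look like in practice, the following minimal sketch builds a hypothetical toy model (`nn.Linear`), takes a single Adam step, and lists the keys of the model and optimizer state dictionaries. The toy model and random input are assumptions for illustration only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model and optimizer, for illustration only.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Adam creates its per-parameter state lazily, on the first step.
loss = model(torch.randn(8, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

model_keys = sorted(model.state_dict().keys())
opt_state = optimizer.state_dict()
# State of the first parameter (the weight), keyed by its index 0.
adam_keys = sorted(opt_state["state"][0].keys())

print(model_keys)                # ['bias', 'weight']
print(sorted(opt_state.keys()))  # ['param_groups', 'state']
print(adam_keys)                 # ['exp_avg', 'exp_avg_sq', 'step']
```

`exp_avg` and `exp_avg_sq` are Adam's running first- and second-moment estimates for each parameter; they are what a resumed run would lose if only the model weights were saved.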

Saving a checkpoint in PyTorch

Note that .pt or .pth are common and recommended file extensions for saving files using PyTorch.

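```python
import shutil
from pathlib import Path

import torch


def save_ckp(state, is_best, checkpoint_dir, best_model_dir):
    # Always overwrite the latest checkpoint.
    f_path = Path(checkpoint_dir) / 'checkpoint.pt'
    torch.save(state, f_path)
    # Keep a separate copy of the best model seen so far.
    if is_best:
        best_fpath = Path(best_model_dir) / 'best_model.pt'
        shutil.copyfile(f_path, best_fpath)
```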

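Inside the training loop, collect everything needed to resume into one dictionary and save it:

```python
checkpoint = {
    'epoch': epoch + 1,  # the next epoch to run when resuming
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
save_ckp(checkpoint, is_best, checkpoint_dir, model_dir)
```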
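A minimal end-to-end sketch of saving, assuming a hypothetical toy regression problem, a temporary directory standing in for real checkpoint folders, and "lowest validation loss so far" as the criterion for `is_best` (the original does not say how `is_best` is computed). It repeats `save_ckp` so the block runs on its own.

```python
import shutil
import tempfile
from pathlib import Path

import torch
import torch.nn as nn


def save_ckp(state, is_best, checkpoint_dir, best_model_dir):
    f_path = Path(checkpoint_dir) / "checkpoint.pt"
    torch.save(state, f_path)
    if is_best:
        shutil.copyfile(f_path, Path(best_model_dir) / "best_model.pt")


torch.manual_seed(0)
# Hypothetical toy data: y = 3x + noise, split into train/validation.
x = torch.randn(64, 1)
y = 3 * x + 0.1 * torch.randn(64, 1)
x_train, y_train, x_val, y_val = x[:48], y[:48], x[48:], y[48:]

model = nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

root = Path(tempfile.mkdtemp())  # stand-in for real checkpoint folders
checkpoint_dir, model_dir = root / "checkpoints", root / "best"
checkpoint_dir.mkdir()
model_dir.mkdir()

best_val_loss = float("inf")
for epoch in range(5):
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(x_train), y_train)
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_loss = loss_fn(model(x_val), y_val).item()
    is_best = val_loss < best_val_loss  # assumed criterion for "best"
    best_val_loss = min(best_val_loss, val_loss)

    checkpoint = {
        "epoch": epoch + 1,  # next epoch to run when resuming
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_ckp(checkpoint, is_best, checkpoint_dir, model_dir)

print(sorted(p.name for p in checkpoint_dir.iterdir()))  # ['checkpoint.pt']
```

Because `checkpoint.pt` is overwritten every epoch, it always holds the most recent state, while `best_model.pt` only changes when validation loss improves.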

Loading a checkpoint

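```python
import torch


def load_ckp(checkpoint_fpath, model, optimizer):
    checkpoint = torch.load(checkpoint_fpath)
    # Copy the saved tensors into the already-constructed model and optimizer.
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # 'epoch' is the epoch to start from when resuming.
    return model, optimizer, checkpoint['epoch']
```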
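A quick way to check that `load_ckp` restores state is a round trip: save a checkpoint, load it into a freshly constructed model and optimizer, and compare. This sketch assumes a hypothetical `nn.Linear` model and a temporary file. The `map_location="cpu"` argument is an addition not in the original; it lets a checkpoint written on a GPU machine load on a CPU-only one.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn


def load_ckp(checkpoint_fpath, model, optimizer):
    # map_location="cpu" is an assumption added for portability.
    checkpoint = torch.load(checkpoint_fpath, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return model, optimizer, checkpoint["epoch"]


torch.manual_seed(0)
model = nn.Linear(3, 1)  # hypothetical model
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# One step so the optimizer has state worth saving.
optimizer.zero_grad()
model(torch.randn(5, 3)).sum().backward()
optimizer.step()

path = Path(tempfile.mkdtemp()) / "checkpoint.pt"
torch.save({"epoch": 1,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict()}, path)

# Fresh objects start with different random weights and empty Adam state.
new_model = nn.Linear(3, 1)
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.01)
new_model, new_optimizer, start_epoch = load_ckp(path, new_model, new_optimizer)

print(start_epoch)                                  # 1
print(torch.equal(new_model.weight, model.weight))  # True
```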

Resume training

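The resume step can be sketched as follows. It assumes hypothetical helpers `make_model_and_optimizer` and `train_one_epoch`, toy random data, and a total of `n_epochs = 6`; an earlier run is simulated by training three epochs and saving a checkpoint. The key points are that the model and optimizer are constructed exactly as in the original run before their states are loaded, and that the loop starts at the returned epoch instead of 0.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

torch.manual_seed(0)
x, y = torch.randn(32, 2), torch.randn(32, 1)  # hypothetical training data


def load_ckp(checkpoint_fpath, model, optimizer):
    checkpoint = torch.load(checkpoint_fpath)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return model, optimizer, checkpoint["epoch"]


def make_model_and_optimizer():
    # Hypothetical setup; it must match the run that wrote the checkpoint.
    model = nn.Linear(2, 1)
    return model, torch.optim.Adam(model.parameters(), lr=0.05)


def train_one_epoch(model, optimizer):
    model.train()
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()


# Earlier run: trains epochs 0, 1, 2, saves a checkpoint, then stops.
ckp_path = Path(tempfile.mkdtemp()) / "checkpoint.pt"
model, optimizer = make_model_and_optimizer()
for epoch in range(3):
    train_one_epoch(model, optimizer)
torch.save({"epoch": epoch + 1,
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict()}, ckp_path)

# Resumed run: build fresh objects, restore them, continue from start_epoch.
n_epochs = 6
model, optimizer = make_model_and_optimizer()
model, optimizer, start_epoch = load_ckp(ckp_path, model, optimizer)
epochs_run = []
for epoch in range(start_epoch, n_epochs):
    loss = train_one_epoch(model, optimizer)
    epochs_run.append(epoch)
    print(f"epoch {epoch}: loss {loss:.4f}")

print(epochs_run)  # [3, 4, 5]
```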

In a nutshell

In short: first construct the model and optimizer as in the original run, then restore their state dictionaries with the load_ckp function.

Then pass this model and optimizer to your training loop, starting from the returned epoch, and training resumes from where it left off. You can confirm this by checking the loss after each epoch: it continues from the values observed before training stopped, instead of starting over from those of a freshly initialized model.
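One way to check the claim above is to compare an uninterrupted run with one that is checkpointed halfway and resumed from fresh objects. This sketch assumes a hypothetical toy model, fixed full-batch data on the CPU, and an in-memory buffer (`io.BytesIO`) in place of a checkpoint file; with no data shuffling or dropout involved, the two runs should produce the same per-epoch losses.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)
x, y = torch.randn(32, 2), torch.randn(32, 1)  # hypothetical fixed data


def make_model_and_optimizer():
    torch.manual_seed(1)  # same initial weights every time
    model = nn.Linear(2, 1)
    return model, torch.optim.Adam(model.parameters(), lr=0.05)


def train_epochs(model, optimizer, epochs):
    losses = []
    for _ in epochs:
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


# Reference: 6 epochs without interruption.
model, optimizer = make_model_and_optimizer()
uninterrupted = train_epochs(model, optimizer, range(6))

# Interrupted: 3 epochs, checkpoint to memory, then resume with fresh objects.
model, optimizer = make_model_and_optimizer()
first_half = train_epochs(model, optimizer, range(3))
buffer = io.BytesIO()
torch.save({"epoch": 3, "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict()}, buffer)

buffer.seek(0)
checkpoint = torch.load(buffer)
model, optimizer = make_model_and_optimizer()  # as after a restart
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
second_half = train_epochs(model, optimizer, range(checkpoint["epoch"], 6))

resumed = first_half + second_half
print(all(abs(a - b) < 1e-6 for a, b in zip(uninterrupted, resumed)))  # True
```

If only `state_dict` is restored and the optimizer is skipped, Adam restarts its moment estimates and step count, so the second half will generally no longer match. This checkpoint format also omits learning-rate scheduler and random-number-generator state, which a run using shuffled batches or dropout would need to save as well to match exactly.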