Load, Save Model - Paxoo/PyTorch-Best_Practices GitHub Wiki

What is a `state_dict` in PyTorch

https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html

Saving and loading checkpoints

https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html

# Additional information stored in the checkpoint alongside the weights
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

# save: bundle everything needed to resume training into a single dict
torch.save({
    'epoch': EPOCH,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': LOSS,
}, PATH)

# load: rebuild the model and optimizer first, then restore their states
model = Net()
# The optimizer must be constructed from the parameters of ``model`` -- the
# instance whose weights are restored below -- otherwise the restored
# optimizer state would drive the parameters of a different network.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
# Afterwards call model.train() to resume training, or model.eval() for inference.

Saving and loading a model for inference

https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html

# Specify a path for the file that will hold only the model's learned
# parameters (its state_dict), not the whole model object
PATH = "state_dict_model.pt"

# Save
# NOTE(review): ``net`` is assumed to be the trained model instance defined
# elsewhere; its weights are restored into a fresh ``Net`` below.
torch.save(net.state_dict(), PATH)

# Load
# Build a new instance of the same architecture, then copy the saved
# parameters into it.
model = Net()
model.load_state_dict(torch.load(PATH))
# Switch to evaluation mode before inference (per the PyTorch docs this
# changes the behavior of layers such as dropout and batch normalization).
model.eval()

Saving the best model

def save_checkpoint(state, save_path: str, is_best: bool = False,
                    best_filename: str = 'best_model.ckpt'):
    """Save a checkpoint with ``torch.save`` and optionally keep a "best" copy.

    Args:
        state: Object to serialize, e.g. a dict holding the epoch,
            ``model.state_dict()``, ``optimizer.state_dict()`` and the loss.
        save_path (str): Destination path for the checkpoint file.
        is_best (bool): If ``True``, additionally copy the checkpoint to
            ``best_filename`` in the same directory as ``save_path``.
        best_filename (str): File name of the extra copy made when
            ``is_best`` is ``True``. Defaults to ``'best_model.ckpt'``.
    """
    import os
    import shutil

    # save checkpoint
    torch.save(state, save_path)

    # copy best: place the copy next to the checkpoint just written. For a
    # bare file name dirname() is '', so the copy lands in the working dir.
    if is_best:
        save_dir = os.path.dirname(save_path)
        best_path = os.path.join(save_dir, best_filename)
        # copyfile raises SameFileError if source and destination coincide;
        # in that case the checkpoint already *is* the best copy.
        if os.path.abspath(best_path) != os.path.abspath(save_path):
            shutil.copyfile(save_path, best_path)

def main():
    """Sketch: evaluate on the validation set and checkpoint the best model.

    ``...`` marks parts of the training script that are not shown here.
    """
    ...
    # ``nn.Module`` provides ``eval()`` / ``train()``; there is no ``val()``.
    model.eval()
    # NOTE(review): in a real training loop initialize this once, before the
    # epoch loop, so the best score persists across epochs.
    val_best_accuracy = 0
    with torch.no_grad():
        for batch in val_loader:
            ...  # accumulate predictions / correct counts into current_accuracy

    # The weights do not change during a validation pass, so compare the
    # accuracy of the whole pass (not of each individual batch) with the best.
    if current_accuracy > val_best_accuracy:
        val_best_accuracy = current_accuracy
        # save checkpoint
        # NOTE(review): EPOCH and LOSS are the example constants defined above;
        # substitute the current epoch and loss of the training loop.
        cpkt = {
            'epoch': EPOCH,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
        }
        save_checkpoint(cpkt, 'model_checkpoint.ckpt', True)