Load, Save Model - Paxoo/PyTorch-Best_Practices GitHub Wiki
State_dict in PyTorch
https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html
Saving and loading Checkpoints
https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html
# Additional information stored in the checkpoint next to the state dicts
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

# Save: bundle the model state, the optimizer state and the bookkeeping
# values into a single dictionary and write it to PATH.
torch.save(
    {
        'epoch': EPOCH,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': LOSS,
    },
    PATH,
)

# Load: construct the model and the optimizer first, then restore their states.
model = Net()
# The optimizer has to be built from the parameters of the model that is being
# restored; building it from another network's parameters would leave this
# model untouched by later optimizer steps.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
Saving and loading a model for inference
https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html
# File that receives the state dict
PATH = "state_dict_model.pt"

# Save: write only the state dict of the trained network
trained_weights = net.state_dict()
torch.save(trained_weights, PATH)

# Load: create a fresh model and copy the stored state dict into it
model = Net()
restored_weights = torch.load(PATH)
model.load_state_dict(restored_weights)
# Switch the model to evaluation mode before using it
model.eval()
Saving the best model
def save_checkpoint(state, save_path: str, is_best: bool = False):
    """Save a checkpoint to disk and optionally keep a copy as the best model.

    Args:
        state: Object handed to ``torch.save`` (for example a dict holding
            the model and optimizer state dicts).
        save_path (str): Destination path of the checkpoint file.
        is_best (bool): If ``True``, the checkpoint is additionally copied to
            ``best_model.ckpt`` in the same directory as ``save_path``.
    """
    # save checkpoint
    torch.save(state, save_path)

    # copy best
    if is_best:
        # Derive the target directory from save_path; an empty dirname (bare
        # file name) makes os.path.join fall back to the working directory.
        best_path = os.path.join(os.path.dirname(save_path), 'best_model.ckpt')
        shutil.copyfile(save_path, best_path)
def main():
    """Sketch of a validation step that checkpoints the best model so far.

    The ``...`` placeholders stand for code that is not part of this example
    (setup, and the per-batch work that yields ``current_accuracy``).
    """
    ...
    # nn.Module exposes eval(), not val(), to switch to evaluation mode.
    model.eval()
    val_best_accuracy = 0
    with torch.no_grad():
        for batch in val_loader:
            ...

    # NOTE(review): assumes current_accuracy is computed by the elided code
    # above and compared once per validation pass - confirm.
    if current_accuracy > val_best_accuracy:
        val_best_accuracy = current_accuracy
        # save checkpoint
        ckpt = {
            'epoch': EPOCH,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
        }
        save_checkpoint(ckpt, 'model_checkpoint.ckpt', True)