# Checkpointing and Resuming
To safely restart training from where it was stopped due to a session limit (e.g., a 4-hour limit on a GPU cluster), you need to implement checkpointing and resuming mechanisms in your training script. Here's a step-by-step guide to achieve this:
## 1. Implement Checkpointing
Checkpointing involves saving the state of the training process (e.g., model weights, optimizer state, epoch, and other training metadata) at regular intervals. This allows you to resume training from the last saved checkpoint.
### Key Components to Save in a Checkpoint
- Model Weights: The current state of the model's parameters.
- Optimizer State: The state of the optimizer (e.g., learning rate, momentum).
- Epoch/Iteration: The current epoch or iteration number.
- Training Metadata: Any other relevant information, such as loss history or random seeds.
### Example: Saving a Checkpoint

```python
import torch

def save_checkpoint(model, optimizer, epoch, loss, checkpoint_path):
    """Save the training state so it can be restored later."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch} to {checkpoint_path}")
```
## 2. Implement Resuming from a Checkpoint
When restarting the training script, load the last saved checkpoint to resume training from where it left off.
### Example: Loading a Checkpoint

```python
def load_checkpoint(model, optimizer, checkpoint_path):
    """Restore model and optimizer state; return the epoch to resume from."""
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Resuming training from epoch {epoch} with loss {loss}")
    return epoch
```
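If the job may resume on a different device than the one that wrote the checkpoint (for example, a GPU-saved checkpoint inspected on a CPU-only node), pass `map_location` to `torch.load`. A small sketch, assuming you simply want to load onto whatever device is currently available:

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(checkpoint_path, map_location=device)
```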
## 3. Modify the Training Script
Integrate checkpointing and resuming into your training script. Here's an example:
```python
import os
import torch

# Define paths
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, "last_checkpoint.pth")

# Initialize model, optimizer, and data
model = ...          # Your model
optimizer = ...      # Your optimizer
dataloader = ...     # Your DataLoader
num_epochs = 100     # Total number of epochs
checkpoint_freq = 1  # Save every N epochs
start_epoch = 0

# Load checkpoint if it exists
if os.path.exists(checkpoint_path):
    start_epoch = load_checkpoint(model, optimizer, checkpoint_path)

# Training loop
for epoch in range(start_epoch, num_epochs):
    for batch in dataloader:
        # Training step: forward pass, compute loss, backward pass, optimizer step
        loss = ...

    # Save a checkpoint periodically (here: every checkpoint_freq epochs).
    # Saving epoch + 1 means a resumed run starts at the next epoch.
    if (epoch + 1) % checkpoint_freq == 0:
        save_checkpoint(model, optimizer, epoch + 1, loss, checkpoint_path)
```
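The Slurm example in the next section launches this script as `python train.py --resume`. One way to support such a flag is with `argparse`, so the script only loads a checkpoint when explicitly asked. This is a sketch; the flag name and behavior are an assumption rather than part of the original script:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--resume", action="store_true",
                    help="Resume from the last checkpoint if one exists")
args = parser.parse_args()

start_epoch = 0
if args.resume and os.path.exists(checkpoint_path):
    start_epoch = load_checkpoint(model, optimizer, checkpoint_path)
```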
## 4. Handle Session Limits on GPU Clusters
GPU clusters often impose session limits (e.g., 4 hours). To handle this:
- Save Checkpoints Before the Session Ends:
  - Save checkpoints periodically (e.g., every epoch or every N iterations) so you don't lose progress if the session is terminated abruptly.
- Use Job Scheduling Tools:
  - If your cluster uses a job scheduler (e.g., Slurm, PBS), configure your job script to:
    - Request the maximum allowed session time (e.g., 4 hours).
    - Automatically resubmit the job after it ends.
### Example: Slurm Job Script with Resubmission

```bash
#!/bin/bash
#SBATCH --job-name=training
#SBATCH --time=04:00:00          # 4-hour session limit
#SBATCH --gres=gpu:1
#SBATCH --output=training_%j.log

# Load necessary modules
module load cuda/11.3
module load python/3.9

# Activate virtual environment
source venv/bin/activate

# Run the training script (resumes from the last checkpoint if one exists)
python train.py --resume

# Resubmit the job if training has not finished yet.
# Convention (assumed): train.py creates a "training_complete" marker file
# after the final epoch, so its absence means more work remains.
if [ ! -f training_complete ]; then
    sbatch training_job.sh
fi
```
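One caveat with this layout: if Slurm kills the job exactly at the time limit, the resubmission check at the bottom of the script never runs. A common mitigation is to have the training loop save a checkpoint and exit cleanly a few minutes before the limit, and to write the completion marker only when all epochs are done. Here is a sketch built on the loop from section 3; the 10-minute margin and the `training_complete` file name are arbitrary choices, not a fixed convention:

```python
import time

SESSION_LIMIT_SECONDS = 4 * 60 * 60  # matches #SBATCH --time=04:00:00
SAFETY_MARGIN_SECONDS = 10 * 60      # stop ~10 minutes early (arbitrary margin)
start_time = time.time()

out_of_time = False
for epoch in range(start_epoch, num_epochs):
    for batch in dataloader:
        # Training step
        loss = ...

    save_checkpoint(model, optimizer, epoch + 1, loss, checkpoint_path)

    # Stop before Slurm kills the job, so the resubmission check in the
    # job script still gets a chance to run.
    if time.time() - start_time > SESSION_LIMIT_SECONDS - SAFETY_MARGIN_SECONDS:
        print("Approaching the session limit; exiting so the job can be resubmitted.")
        out_of_time = True
        break

if not out_of_time:
    # All epochs finished: create the marker file the job script checks for.
    open("training_complete", "w").close()
```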
## 5. Automate Checkpoint Management
- Checkpoint Naming:
  - Use unique names for checkpoints (e.g., `checkpoint_epoch_{epoch}.pth`) to avoid overwriting earlier saves.
- Clean Up Old Checkpoints:
  - Periodically delete old checkpoints to save disk space, but always keep the most recent few (the example below keeps the last three).
### Example: Clean Up Old Checkpoints

```python
import glob
import os

def clean_old_checkpoints(checkpoint_dir, keep_last=3):
    """Delete all but the `keep_last` most recent checkpoints."""
    # Sort by modification time; sorting by name would put
    # checkpoint_epoch_10.pth before checkpoint_epoch_2.pth.
    checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, "checkpoint_*.pth")),
                         key=os.path.getmtime)
    for checkpoint in checkpoints[:-keep_last]:
        os.remove(checkpoint)
        print(f"Deleted old checkpoint: {checkpoint}")
```
## 6. Test the Restart Mechanism
Before running long training sessions, test the restart mechanism:
- Start training and let it run for a few iterations.
- Manually stop the training (e.g., by killing the process).
- Restart the training and verify that it resumes correctly from the last checkpoint.
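A quick way to verify the resumed state is to inspect the checkpoint file directly and compare it with what the resumed run logs (a sketch, assuming the single-file layout from section 3):

```python
checkpoint = torch.load("checkpoints/last_checkpoint.pth", map_location="cpu")
print(checkpoint['epoch'], checkpoint['loss'])  # should match the "Resuming training..." log line
```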
## 7. Monitor Training Progress
- Use logging tools (e.g., TensorBoard, Weights & Biases) to track training progress and ensure that the resume mechanism works as expected.
- Log the epoch, iteration, and loss values to confirm that training resumes correctly.
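For example, with TensorBoard the global step should continue from the resumed epoch rather than restart at zero, which makes a broken resume easy to spot. A minimal sketch using `torch.utils.tensorboard`; the log directory name is arbitrary:

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/training")  # reuse the same log_dir across restarts
for epoch in range(start_epoch, num_epochs):
    # ... training ...
    writer.add_scalar("loss/train", loss, global_step=epoch)  # curve continues after a resume
```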
## Summary
To safely restart training from where it was stopped due to a session limit:
- Implement checkpointing to save the model, optimizer, and training state periodically.
- Implement resuming to load the last checkpoint and continue training.
- Use job scheduling tools to automatically resubmit jobs after the session limit.
- Test the restart mechanism to ensure it works correctly.
By following these steps, you can ensure that your training process is robust to session limits and can resume seamlessly from the last saved checkpoint.