checkpointing and resuming - chunhualiao/public-docs GitHub Wiki


To safely restart training from where it was stopped due to a session limit (e.g., a 4-hour limit on a GPU cluster), you need to implement checkpointing and resuming mechanisms in your training script. Here's a step-by-step guide to achieve this:


1. Implement Checkpointing

Checkpointing involves saving the state of the training process (e.g., model weights, optimizer state, epoch, and other training metadata) at regular intervals. This allows you to resume training from the last saved checkpoint.

Key Components to Save in a Checkpoint

  • Model Weights: The current state of the model's parameters.
  • Optimizer State: The optimizer's internal buffers (e.g., momentum or Adam moment estimates) and hyperparameters such as the learning rate.
  • Epoch/Iteration: The current epoch or iteration number.
  • Training Metadata: Any other relevant information, such as loss history, learning-rate scheduler state, or random number generator (RNG) states.
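Storing the state of each random number generator, not only the seed it started from, lets a resumed run continue the same random sequence (data shuffling, dropout, augmentation) instead of restarting it. The sketch below shows the idea with Python's built-in `random` module only. In a PyTorch script the corresponding calls would be `torch.get_rng_state()`/`torch.set_rng_state()` (plus `torch.cuda.get_rng_state_all()`/`torch.cuda.set_rng_state_all()` for GPUs), and NumPy offers `np.random.get_state()`/`np.random.set_state()`. The helper names `capture_rng_state` and `restore_rng_state` are hypothetical.

```python
import pickle
import random

def capture_rng_state():
    # random.getstate() returns a picklable tuple describing the generator
    return {"python_random": random.getstate()}

def restore_rng_state(state):
    random.setstate(state["python_random"])

random.seed(0)
warmup = [random.random() for _ in range(5)]    # randomness used before the checkpoint
saved_blob = pickle.dumps(capture_rng_state())  # stored with the checkpoint
expected = [random.random() for _ in range(3)]  # what an uninterrupted run draws next

random.seed(12345)                              # a restarted process starts elsewhere
restore_rng_state(pickle.loads(saved_blob))
resumed = [random.random() for _ in range(3)]
print(resumed == expected)  # True
```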

Example: Saving a Checkpoint

import torch

def save_checkpoint(model, optimizer, epoch, loss, checkpoint_path):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch} to {checkpoint_path}")
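If the session is killed while `torch.save` is still writing, the checkpoint file can be left truncated and unreadable. A common remedy is to write to a temporary file in the same directory and then rename it over the old checkpoint with `os.replace`, which replaces the file atomically on POSIX filesystems. A minimal sketch, assuming a `save_fn(obj, file)` callable such as `pickle.dump` or `torch.save` (which also accepts an open file object); `atomic_save` is a hypothetical helper name.

```python
import os
import pickle
import tempfile

def atomic_save(obj, path, save_fn=pickle.dump):
    """Write obj so that `path` always holds either the old or the new checkpoint."""
    directory = os.path.dirname(os.path.abspath(path))
    # Temporary file in the same directory, so the rename stays on one filesystem
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            save_fn(obj, f)
            f.flush()
            os.fsync(f.fileno())  # push the bytes to disk before renaming
        os.replace(tmp_path, path)  # swap in the complete file in one step
    except BaseException:
        # Remove the partial temporary file if saving failed or was interrupted
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "last_checkpoint.pkl")
    atomic_save({"epoch": 3, "loss": 0.25}, path)
    atomic_save({"epoch": 4, "loss": 0.20}, path)  # replaces the previous file
    with open(path, "rb") as f:
        loaded = pickle.load(f)
    leftovers = [name for name in os.listdir(tmp_dir) if name.endswith(".tmp")]

print(loaded)     # {'epoch': 4, 'loss': 0.2}
print(leftovers)  # []
```

In the PyTorch version, `save_checkpoint` would call `atomic_save(checkpoint, checkpoint_path, save_fn=torch.save)` instead of calling `torch.save` directly.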

2. Implement Resuming from a Checkpoint

When restarting the training script, load the last saved checkpoint to resume training from where it left off.

Example: Loading a Checkpoint

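def load_checkpoint(model, optimizer, checkpoint_path):
    # map_location="cpu" lets a checkpoint saved on a GPU load on any machine;
    # load_state_dict() then copies the tensors to the devices of the model's parameters
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    print(f"Resuming training from epoch {epoch} with loss {loss}")
    # The stored epoch is the next epoch to run, so it becomes start_epoch
    return epoch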
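When a job was killed in the middle of a save, the newest checkpoint may be unreadable. If several numbered checkpoints are kept, the resume step can fall back to the newest one that still loads. A sketch under these assumptions: file names end in a number (e.g. `checkpoint_epoch_12.pkl`), `pickle.load` stands in for `torch.load` (which also accepts an open file), and `load_latest_valid` is a hypothetical helper.

```python
import glob
import os
import pickle
import re
import tempfile

def _checkpoint_number(path):
    # Trailing number in names like checkpoint_epoch_12.pkl; -1 if there is none
    match = re.search(r"(\d+)\.\w+$", os.path.basename(path))
    return int(match.group(1)) if match else -1

def load_latest_valid(checkpoint_dir, pattern="checkpoint_*.pkl", load_fn=pickle.load):
    """Return (path, checkpoint) for the newest readable checkpoint, or (None, None)."""
    paths = glob.glob(os.path.join(checkpoint_dir, pattern))
    for path in sorted(paths, key=_checkpoint_number, reverse=True):  # newest first
        try:
            with open(path, "rb") as f:
                return path, load_fn(f)
        except Exception as exc:  # truncated or corrupt file: try the next older one
            print(f"Skipping unreadable checkpoint {path}: {exc!r}")
    return None, None

with tempfile.TemporaryDirectory() as tmp_dir:
    for epoch in (1, 2):
        with open(os.path.join(tmp_dir, f"checkpoint_epoch_{epoch}.pkl"), "wb") as f:
            pickle.dump({"epoch": epoch}, f)
    # Simulate a kill during the third save: only part of the file was written
    with open(os.path.join(tmp_dir, "checkpoint_epoch_3.pkl"), "wb") as f:
        f.write(pickle.dumps({"epoch": 3})[:5])
    path, ckpt = load_latest_valid(tmp_dir)
    chosen = os.path.basename(path)

print(chosen, ckpt)  # checkpoint_epoch_2.pkl {'epoch': 2}
```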

3. Modify the Training Script

Integrate checkpointing and resuming into your training script. Here's an example:

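import torch
import os

# Define paths
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, "last_checkpoint.pth")

# Initialize model, optimizer, data, and schedule (fill in your own)
model = ...  # Your model
optimizer = ...  # Your optimizer
dataloader = ...  # Your DataLoader
num_epochs = ...  # Total number of epochs to train
checkpoint_freq = ...  # Save a mid-epoch checkpoint every N iterations
start_epoch = 0
global_step = 0  # Iterations completed in this session

# Load checkpoint if it exists
if os.path.exists(checkpoint_path):
    start_epoch = load_checkpoint(model, optimizer, checkpoint_path)

# Training loop
for epoch in range(start_epoch, num_epochs):
    for batch in dataloader:
        # Training step: forward pass, compute `loss`, backward pass, optimizer.step()
        ...
        global_step += 1

        # Mid-epoch checkpoint every `checkpoint_freq` iterations. It stores the
        # current epoch (not epoch + 1), so a resume repeats this epoch from its
        # first batch instead of skipping its unfinished part.
        # loss.item() stores a plain float rather than a tensor.
        if global_step % checkpoint_freq == 0:
            save_checkpoint(model, optimizer, epoch, loss.item(), checkpoint_path)

    # End-of-epoch checkpoint: epoch + 1 is the next epoch to run
    save_checkpoint(model, optimizer, epoch + 1, loss.item(), checkpoint_path)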
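Checkpoints that record only the epoch cannot resume in the middle of an epoch: the resumed run either repeats the interrupted epoch or skips its unfinished part. Recording the index of the next batch as well lets the resumed run skip exactly the batches it already processed. This only works if the batch order within an epoch is reproducible; the sketch assumes that by seeding the shuffle with the epoch number (in PyTorch one might pass a seeded `torch.Generator` to the `DataLoader`). `epoch_batches` and `train` are hypothetical plain-Python stand-ins for a DataLoader and a training loop.

```python
import itertools
import random

def epoch_batches(data, epoch, batch_size):
    # Shuffle with a generator seeded by the epoch number, so every run
    # (interrupted or not) sees the same batch order within a given epoch
    order = list(range(len(data)))
    random.Random(epoch).shuffle(order)
    for i in range(0, len(order), batch_size):
        yield [data[j] for j in order[i:i + batch_size]]

def train(data, num_epochs, batch_size, state, stop_after=None):
    """Run from state['epoch'] / state['batch_idx'] (the next batch to process).

    `state` is updated in place after every batch; a real script would write it
    to the checkpoint. `stop_after` simulates a kill after that many batches.
    Returns the (epoch, batch_idx, batch) tuples processed in this session.
    """
    processed = []
    start_epoch, start_batch = state["epoch"], state["batch_idx"]
    for epoch in range(start_epoch, num_epochs):
        skip = start_batch if epoch == start_epoch else 0
        batches = itertools.islice(epoch_batches(data, epoch, batch_size), skip, None)
        for batch_idx, batch in enumerate(batches, start=skip):
            processed.append((epoch, batch_idx, batch))  # stand-in for a training step
            state.update(epoch=epoch, batch_idx=batch_idx + 1)
            if stop_after is not None and len(processed) == stop_after:
                return processed
        state.update(epoch=epoch + 1, batch_idx=0)
    return processed

data = list(range(7))
full = train(data, num_epochs=3, batch_size=2, state={"epoch": 0, "batch_idx": 0})

state = {"epoch": 0, "batch_idx": 0}
first_session = train(data, 3, 2, state, stop_after=5)  # killed during epoch 1
second_session = train(data, 3, 2, state)               # resumed run
print(first_session + second_session == full)  # True
```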

4. Handle Session Limits on GPU Clusters

GPU clusters often impose session limits (e.g., 4 hours). To handle this:

  1. Save Checkpoints Before the Session Ends:

    • Save checkpoints periodically (e.g., every epoch or every N iterations) to ensure you don't lose progress if the session is terminated abruptly.
  2. Use Job Scheduling Tools:

    • If your cluster uses a job scheduler (e.g., Slurm, PBS), configure your job script to:
      • Request the maximum allowed session time (e.g., 4 hours).
      • Automatically resubmit the job if training has not finished by the time limit.
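Schedulers usually warn a job before killing it. Slurm, for example, sends SIGTERM when the time limit is reached and SIGKILL after a short grace period (its `KillWait` setting), and `sbatch --signal` can request a signal some time before the limit. A training script can catch such a signal, finish the current step, and write one last checkpoint. A minimal sketch with Python's `signal` module; `GracefulStopper` and `run_training` are hypothetical names, and the script raises the signal itself to simulate the scheduler.

```python
import signal

class GracefulStopper:
    """Records that a stop signal arrived so the loop can exit at a safe point."""

    def __init__(self, signals=(signal.SIGTERM,)):
        self.stop_requested = False
        for sig in signals:
            # signal.signal() only works in the main thread of the main interpreter
            signal.signal(sig, self._handle)

    def _handle(self, signum, frame):
        # Keep the handler minimal: only set a flag that the loop checks
        self.stop_requested = True

def run_training(num_steps, start_step, step_fn, save_fn, stopper):
    for step in range(start_step, num_steps):
        step_fn(step)               # one training step
        if stopper.stop_requested:
            save_fn(step + 1)       # final checkpoint: next step to run
            return "interrupted"    # a real script would now exit
    save_fn(num_steps)
    return "finished"

saved_steps = []

def fake_step(step):
    if step == 3:
        signal.raise_signal(signal.SIGTERM)  # simulate the scheduler's warning

stopper = GracefulStopper()
status = run_training(10, 0, fake_step, saved_steps.append, stopper)
signal.signal(signal.SIGTERM, signal.SIG_DFL)  # restore the default behavior
print(status, saved_steps)  # interrupted [4]
```

SIGKILL cannot be caught, so this complements periodic checkpoints rather than replacing them.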

Example: Slurm Job Script with Resubmission

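#!/bin/bash
#SBATCH --job-name=training
#SBATCH --time=04:00:00  # 4-hour session limit
#SBATCH --gres=gpu:1
#SBATCH --output=training_%j.log

# Load necessary modules (module names and versions depend on your cluster)
module load cuda/11.3
module load python/3.9

# Activate virtual environment
source venv/bin/activate

# Run the training script, but stop it 10 minutes before the 4-hour limit.
# When Slurm reaches the time limit it kills this whole batch script, so
# commands after the python line would never run; `timeout` ends training
# first so that the resubmission check below still executes.
timeout 230m python train.py --resume
status=$?

# timeout exits with status 124 when it had to stop the command, i.e. training
# is unfinished: resubmit this script (assumed to be saved as training_job.sh).
# Status 0 means training completed, so nothing is resubmitted.
if [ "$status" -eq 124 ]; then
    sbatch training_job.sh
fi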
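Instead of relying on the scheduler or an external timer to stop the job, the training script can also track its own wall-clock budget and stop cleanly shortly before the limit, after writing a checkpoint. A sketch of such a check: the 4-hour limit matches the job script above, while the 10-minute margin, the `TimeBudget` name, and the idea of exiting with a dedicated status code for the job script to test are assumptions.

```python
import time

class TimeBudget:
    """Tells the training loop when it should checkpoint and stop."""

    def __init__(self, limit_seconds, margin_seconds):
        # time.monotonic() is unaffected by changes to the system clock
        self.deadline = time.monotonic() + limit_seconds - margin_seconds

    def remaining(self):
        return self.deadline - time.monotonic()

    def expired(self):
        return self.remaining() <= 0

# 4-hour session, stop 10 minutes early to leave time for the final checkpoint
budget = TimeBudget(limit_seconds=4 * 3600, margin_seconds=10 * 60)
print(budget.expired())  # False right after start

# Inside the training loop (exit code 3 is a hypothetical "stopped early" signal):
# if budget.expired():
#     save_checkpoint(model, optimizer, epoch, loss.item(), checkpoint_path)
#     sys.exit(3)
```

The clock starts when the Python process starts, slightly after the job itself, so the margin should also absorb module loading and startup time.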

5. Automate Checkpoint Management

  • Checkpoint Naming:
    • Use unique names for checkpoints (e.g., checkpoint_epoch_{epoch}.pth) to avoid overwriting.
  • Clean Up Old Checkpoints:
    • Periodically delete old checkpoints to save disk space, but always keep the most recent one.
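Unique names are easiest to manage when their alphabetical order matches their numeric order. A small sketch assuming zero-padded epoch numbers; the helper `checkpoint_filename` and the 6-digit width are hypothetical choices.

```python
def checkpoint_filename(epoch, width=6):
    # Zero-padding makes alphabetical order match numeric order, so a plain
    # string sort lists checkpoints from oldest to newest
    return f"checkpoint_epoch_{epoch:0{width}d}.pth"

names = [checkpoint_filename(e) for e in (2, 10, 1)]
print(sorted(names))
# ['checkpoint_epoch_000001.pth', 'checkpoint_epoch_000002.pth', 'checkpoint_epoch_000010.pth']

# Without padding the string order is wrong: '..._10.pth' sorts before '..._2.pth'
print(sorted(["checkpoint_epoch_2.pth", "checkpoint_epoch_10.pth"]))
# ['checkpoint_epoch_10.pth', 'checkpoint_epoch_2.pth']
```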

Example: Clean Up Old Checkpoints

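import glob
import os
import re

def clean_old_checkpoints(checkpoint_dir, keep_last=3):
    def epoch_of(path):
        # Trailing number in names like checkpoint_epoch_12.pth (-1 if none)
        match = re.search(r"(\d+)\.pth$", path)
        return int(match.group(1)) if match else -1

    # Sort numerically; a plain string sort puts checkpoint_epoch_10.pth before
    # checkpoint_epoch_2.pth and would delete the wrong files
    checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, "checkpoint_*.pth")), key=epoch_of)
    for checkpoint in checkpoints[:-keep_last]:  # keep the newest keep_last (>= 1) files
        os.remove(checkpoint)
        print(f"Deleted old checkpoint: {checkpoint}")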

6. Test the Restart Mechanism

Before running long training sessions, test the restart mechanism:

  1. Start training and let it run for a few iterations.
  2. Manually stop the training (e.g., by killing the process).
  3. Restart the training and verify that it resumes correctly from the last checkpoint.
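The same check can be automated. The sketch below is a framework-free stand-in: a toy "training" run keeps a running sum, writes a JSON checkpoint after every step, and a `SimulatedKill` exception plays the role of the killed process. The test passes if the interrupted-and-resumed run ends in the same state as an uninterrupted one. All names are hypothetical; with a real model one would compare, for instance, final losses or parameters.

```python
import json
import os
import tempfile

class SimulatedKill(Exception):
    """Stands in for the process being killed at the session limit."""

def train(ckpt_path, total_steps, crash_at=None):
    # Resume from the checkpoint if one exists, otherwise start fresh
    if os.path.exists(ckpt_path):
        with open(ckpt_path) as f:
            state = json.load(f)
    else:
        state = {"step": 0, "value": 0}
    resumed_from = state["step"]
    for step in range(state["step"], total_steps):
        if crash_at is not None and step == crash_at:
            raise SimulatedKill(f"killed before step {step}")
        state["value"] += step      # stand-in for a training update
        state["step"] = step + 1    # next step to run
        with open(ckpt_path, "w") as f:
            json.dump(state, f)     # checkpoint after every step
    return state, resumed_from

with tempfile.TemporaryDirectory() as tmp_dir:
    reference, _ = train(os.path.join(tmp_dir, "reference.json"), total_steps=10)
    run_path = os.path.join(tmp_dir, "run.json")
    try:
        train(run_path, total_steps=10, crash_at=6)  # run, then get "killed"
    except SimulatedKill as exc:
        print("Interrupted:", exc)
    resumed, resumed_from = train(run_path, total_steps=10)  # restart and resume

print(resumed_from)          # 6
print(resumed == reference)  # True
```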

7. Monitor Training Progress

  • Use logging tools (e.g., TensorBoard, Weights & Biases) to track training progress and ensure that the resume mechanism works as expected.
  • Log the epoch, iteration, and loss values to confirm that training resumes correctly.
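One concrete way to use those logs: collect the logged step (iteration) numbers from every session and check them for gaps. Steps repeated after a restart are expected when the last checkpoint predates the kill, while missing steps suggest the resume skipped work. A sketch with a hypothetical helper, assuming the step numbers have already been extracted from the logs in the order they were written:

```python
def check_resume_continuity(logged_steps):
    """Return (missing, repeated) global step numbers from logs collected
    across all sessions, in the order they were logged."""
    seen = set()
    repeated = []
    for step in logged_steps:
        if step in seen:
            repeated.append(step)  # re-run after resuming from an older checkpoint
        seen.add(step)
    if not seen:
        return [], []
    missing = sorted(set(range(min(seen), max(seen) + 1)) - seen)  # skipped work
    return missing, repeated

# Session 1 logged steps 0-7, but its last checkpoint was taken before step 5;
# session 2 resumed from that checkpoint and ran steps 5-9
print(check_resume_continuity(list(range(8)) + list(range(5, 10))))
# ([], [5, 6, 7])
print(check_resume_continuity([0, 1, 2, 5, 6]))
# ([3, 4], [])
```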

Summary

To safely restart training from where it was stopped due to a session limit:

  1. Implement checkpointing to save the model, optimizer, and training state periodically.
  2. Implement resuming to load the last checkpoint and continue training.
  3. Use job scheduling tools to automatically resubmit jobs after the session limit.
  4. Test the restart mechanism to ensure it works correctly.

By following these steps, you can ensure that your training process is robust to session limits and can resume seamlessly from the last saved checkpoint.