cart pole: checkpoint vs. final model files - chunhualiao/public-docs GitHub Wiki

The notes below are based on the code in cart-pole.py.

Checkpoint Models

Checkpoints are saved periodically inside the train_agent function. Whenever more than 5 minutes (300 seconds) of wall-clock time have passed since the last checkpoint, this block runs:

if time.time() - start_time > 300:  # 300 seconds = 5 minutes
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    checkpoint_path = f"cartpole_checkpoint_{timestamp}.pth"
    torch.save({
        'model_state_dict': agent.model.state_dict(),
        'optimizer_state_dict': agent.optimizer.state_dict(),
        'epsilon': agent.epsilon,
        'scores': scores,
        'step_count': step_count
    }, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}.")
    start_time = time.time()

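This saves the model's state_dict, the optimizer's state_dict, the current epsilon value, the scores recorded so far, and the step count. The filename includes a timestamp with one-second resolution, so checkpoints written five minutes apart never overwrite each other.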
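The timing logic can be isolated from the training loop. The following is a minimal sketch, not code from cart-pole.py: CheckpointTimer and checkpoint_filename are hypothetical names, and the injectable clock argument is an assumption added so the behavior can be checked without waiting five minutes.

```python
import time


class CheckpointTimer:
    """Hypothetical helper: reports when a checkpoint is due."""

    def __init__(self, interval_seconds=300, clock=time.time):
        self.interval_seconds = interval_seconds
        self.clock = clock  # injectable clock (assumption, for testing)
        self.start_time = clock()

    def due(self):
        """Return True once more than interval_seconds have elapsed,
        and restart the interval, mirroring the snippet above."""
        if self.clock() - self.start_time > self.interval_seconds:
            self.start_time = self.clock()
            return True
        return False


def checkpoint_filename(now=None):
    """Build a timestamped filename in the same format as the snippet above.

    now is seconds since the epoch; None means the current time.
    """
    timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime(now))
    return f"cartpole_checkpoint_{timestamp}.pth"
```

Inside a training loop, a call such as `if timer.due(): torch.save(..., checkpoint_filename())` would take the place of the manual start_time bookkeeping. The comparison is strict, so a checkpoint is written only after the interval has been exceeded, never exactly at 300 seconds.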

Final Model Saving

The final model is saved when training succeeds, that is, when the average score reaches the target (target_avg_score = 450.0). At that point the training loop saves the weights and stops:

if avg_score >= target_avg_score:
    print(f"Solved in {e} episodes!")
    agent.save("cartpole_dqn.pth")
    break

This calls agent.save with the filename "cartpole_dqn.pth". The method is defined as:

def save(self, filename):
    """
    Save model weights.
    """
    torch.save(self.model.state_dict(), filename)

This method saves only the model's state_dict, which contains the learned weights of the neural network. The optimizer state, epsilon, scores, and step count are not written to this file.
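The stopping check above compares avg_score with target_avg_score, but the excerpt does not show how avg_score is computed. A common choice is a mean over the most recent episodes. The sketch below assumes that approach; RollingAverage, is_solved, and the window of 100 episodes are illustrative assumptions, not taken from cart-pole.py.

```python
from collections import deque


class RollingAverage:
    """Hypothetical helper: mean of the most recent `window` scores."""

    def __init__(self, window=100):  # window size is an assumption
        self.scores = deque(maxlen=window)

    def add(self, score):
        """Record a score and return the current rolling mean."""
        self.scores.append(score)
        return sum(self.scores) / len(self.scores)


def is_solved(avg_score, target_avg_score=450.0):
    """Same comparison as the training loop: the target counts as reached
    when the average is greater than or equal to it."""
    return avg_score >= target_avg_score
```

With a helper like this, the training loop would call `avg_score = tracker.add(score)` after each episode and save the final model the first time `is_solved(avg_score)` is true.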

Model Loading in Simulation

The simulate_agent function loads the final model by default:

def simulate_agent(model_path="cartpole_dqn.pth", render=True):
    """
    Loads a trained DQN model and simulates the CartPole game with graphics.
    """
    agent.model.load_state_dict(torch.load(model_path))
    ...

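The default model_path is "cartpole_dqn.pth", the final model file written when training succeeds. The object returned by torch.load is passed straight to load_state_dict, so the file must contain a bare state_dict.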


Key Differences Between Checkpoint and Final Models

1. Different Saving Frequency and Purpose

  • Checkpoints are saved repeatedly during training, both as backups and to allow training to be resumed. Each one captures the state of training at a different point in time.
  • The final model is saved only once, when the average score reaches the target. It holds the weights the network had at that moment.

2. Different Content Saved

  • Checkpoints save more information than the final model:
    • model_state_dict (model weights)
    • optimizer_state_dict
    • epsilon
    • Training scores
    • step_count
  • The final model file contains only the model's state_dict (the same data a checkpoint stores under model_state_dict), because only the trained network weights are needed for deployment or simulation.
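The extra checkpoint fields are what make resuming possible. The following is a minimal sketch of restoring them, assuming the checkpoint dictionary has already been read with torch.load and has the keys shown in the checkpoint snippet. restore_training_state is a hypothetical helper, not a function from cart-pole.py.

```python
def restore_training_state(checkpoint, model, optimizer):
    """Hypothetical helper: restore training state from a checkpoint dict.

    checkpoint is the dictionary written by the checkpoint code, with keys
    model_state_dict, optimizer_state_dict, epsilon, scores, step_count.
    model and optimizer are updated in place through load_state_dict.
    """
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # The remaining fields are plain values that the caller reassigns.
    epsilon = checkpoint["epsilon"]
    scores = list(checkpoint["scores"])
    step_count = checkpoint["step_count"]
    return epsilon, scores, step_count
```

In cart-pole.py this would presumably be called as `restore_training_state(torch.load(checkpoint_path), agent.model, agent.optimizer)`, with the returned epsilon assigned back to agent.epsilon before training continues. Restoring the optimizer state and epsilon matters because a fresh optimizer and a reset exploration rate would change how training proceeds from that point.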

3. Compatibility Issues

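  • The model architecture is the same in both cases (both hold weights for the same DQN class), but the top-level structure of the two files differs.
  • A checkpoint file cannot be passed directly to simulate_agent. torch.load returns the whole checkpoint dictionary, whose keys are model_state_dict, optimizer_state_dict, epsilon, scores, and step_count, while load_state_dict expects keys that match the network's parameter names. With the default strict loading, PyTorch raises a RuntimeError reporting missing and unexpected keys.
  • simulate_agent expects the file itself to be a bare state_dict. In a checkpoint the weights sit one level down, under the model_state_dict key.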
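A loader can be made tolerant of both file types by unwrapping the checkpoint when needed. The sketch below assumes the two layouts described above; extract_state_dict is a hypothetical helper and is not part of cart-pole.py.

```python
def extract_state_dict(loaded):
    """Hypothetical helper: return bare model weights from either file type.

    loaded is whatever torch.load returned:
      - a checkpoint dict, which holds the weights under "model_state_dict"
      - a bare state_dict (the final model file), returned unchanged
    """
    if isinstance(loaded, dict) and "model_state_dict" in loaded:
        return loaded["model_state_dict"]
    return loaded
```

With this in place, simulate_agent could call `agent.model.load_state_dict(extract_state_dict(torch.load(model_path)))` and accept either a checkpoint or the final model. The check relies on the assumption that no network parameter is itself named model_state_dict.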

Summary

  • Checkpoint files and the final model file serve different purposes.
  • They are saved at different times and contain different sets of information.
  • While both contain the model's learned weights, checkpoints include additional training state information for resuming training, whereas the final model is a lightweight file containing only the weights needed for inference or simulation.
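  • They are not directly interchangeable, even though they are based on the same model architecture: a checkpoint's weights must first be extracted from its model_state_dict entry before they can be loaded like the final model.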