cart pole: checkpoint vs. final model files - chunhualiao/public-docs GitHub Wiki
Okay, I've read the content of cart-pole.py. From the code, I can see that:
Checkpoint Models
Checkpoint models are saved periodically during training in the train_agent function, every 5 minutes, using this logic:
    if time.time() - start_time > 300:  # 300 seconds = 5 minutes
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        checkpoint_path = f"cartpole_checkpoint_{timestamp}.pth"
        torch.save({
            'model_state_dict': agent.model.state_dict(),
            'optimizer_state_dict': agent.optimizer.state_dict(),
            'epsilon': agent.epsilon,
            'scores': scores,
            'step_count': step_count
        }, checkpoint_path)
        print(f"Checkpoint saved to {checkpoint_path}.")
        start_time = time.time()
This saves the model's state_dict, the optimizer's state_dict, the epsilon value, the scores, and the step count. The filename includes a timestamp with one-second resolution, so checkpoints saved 5 minutes apart never overwrite each other.
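The time-based trigger above can be isolated into a small helper. The following is a minimal sketch, not code from cart-pole.py: the names CheckpointTimer and checkpoint_filename are hypothetical, and it swaps time.time() for time.monotonic(), which is not affected by system clock adjustments.

```python
import time


class CheckpointTimer:
    """Tracks whether a fixed interval has elapsed since the last checkpoint.

    Hypothetical helper; cart-pole.py inlines this logic in train_agent.
    """

    def __init__(self, interval_seconds=300.0, clock=time.monotonic):
        # The clock is injectable so the timer can be tested without waiting.
        self.interval_seconds = interval_seconds
        self.clock = clock
        self.last_saved = clock()

    def due(self):
        """Return True (and restart the interval) once enough time has passed."""
        now = self.clock()
        if now - self.last_saved > self.interval_seconds:
            self.last_saved = now
            return True
        return False


def checkpoint_filename(prefix="cartpole_checkpoint"):
    """Build a timestamped filename in the same format as the original snippet."""
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    return f"{prefix}_{timestamp}.pth"
```

Inside the training loop, the check would then read `if timer.due(): torch.save({...}, checkpoint_filename())`, with the same dictionary as in the original code.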
Final Model Saving
The final model is saved at the end of the training process, specifically when the average score reaches the target average score (target_avg_score = 450.0):
    if avg_score >= target_avg_score:
        print(f"Solved in {e} episodes!")
        agent.save("cartpole_dqn.pth")
        break
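The excerpt does not show how avg_score is computed. One common approach is a moving average over the most recent episodes. The sketch below assumes a 100-episode window, which is an illustrative choice and not a value taken from cart-pole.py; the function name is_solved is also hypothetical.

```python
def is_solved(scores, target_avg_score=450.0, window=100):
    """Return True when the mean of the last `window` scores meets the target.

    The window size is an assumption made for illustration only.
    """
    # Require a full window so a few lucky early episodes cannot end training.
    if len(scores) < window:
        return False
    recent = scores[-window:]
    return sum(recent) / window >= target_avg_score
```

With a check like this, the final model is written only after performance has been sustained over the whole window, not after a single high-scoring episode.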
This calls agent.save("cartpole_dqn.pth"). The agent.save function is defined as:
    def save(self, filename):
        """
        Save model weights.
        """
        torch.save(self.model.state_dict(), filename)
This function saves only the model's state_dict, which contains the learned weights of the neural network.
Model Loading in Simulation
The simulate_agent function loads the final model by default:
    def simulate_agent(model_path="cartpole_dqn.pth", render=True):
        """
        Loads a trained DQN model and simulates the CartPole game with graphics.
        """
        agent.model.load_state_dict(torch.load(model_path))
        ...
It loads the model from "cartpole_dqn.pth" by default, which is the final model saved after training is complete.
Key Differences Between Checkpoint and Final Models
1. Different Saving Frequency and Purpose
- Checkpoints are saved periodically during training as backups and to allow resuming training. They are more frequent and capture the state of training at different points in time.
- The final model is saved only once, when the average score reaches the target, so it represents the network at the point where training was considered solved.
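Resuming training from a checkpoint means restoring every field that was saved, not just the weights. The following is a minimal sketch under the assumption that the agent exposes model, optimizer, and epsilon attributes as in the saving code; restore_from_checkpoint is a hypothetical helper that does not appear in cart-pole.py.

```python
def restore_from_checkpoint(agent, checkpoint):
    """Restore training state from a loaded checkpoint dictionary.

    `checkpoint` is the object returned by torch.load(checkpoint_path).
    Returns the saved scores and step count so the training loop can continue.
    """
    # Network weights and optimizer state (e.g. moment estimates) come back together.
    agent.model.load_state_dict(checkpoint["model_state_dict"])
    agent.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # Restoring epsilon keeps the exploration schedule from restarting.
    agent.epsilon = checkpoint["epsilon"]
    return checkpoint["scores"], checkpoint["step_count"]


# Possible usage (the path is a placeholder):
# checkpoint = torch.load("cartpole_checkpoint_<timestamp>.pth")
# scores, step_count = restore_from_checkpoint(agent, checkpoint)
```

Depending on the PyTorch version and on what the scores list contains, torch.load may need weights_only=False to unpickle the non-tensor entries; that should only be used for files from a trusted source.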
2. Different Content Saved
- Checkpoints save more information than the final model:
  - model_state_dict (model weights)
  - optimizer_state_dict
  - epsilon
  - scores (training scores)
  - step_count
- The final model file contains only the model's state_dict, saved directly rather than under a model_state_dict key, because only the trained network weights are needed for deployment or simulation.
3. Compatibility Issues
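- While the model architecture is the same for both checkpoint and final models (both are instances of the DQN class), the contents of the two files differ.
- A checkpoint file cannot be used directly where a final model file is expected: it is a dictionary that wraps the weights together with extra training state, which the loading code in simulate_agent is not designed to handle.
- The simulate_agent function passes whatever torch.load returns straight to load_state_dict, so it expects a bare state_dict. Given a checkpoint, load_state_dict would fail on the missing and unexpected keys. To reuse a checkpoint's weights, its model_state_dict entry has to be extracted first.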
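One way to make a loader tolerate both file formats is to unwrap the checkpoint dictionary when it is present. This is a sketch, not part of cart-pole.py: extract_state_dict is a hypothetical helper, and it relies on the assumption that no layer in the network is itself named model_state_dict.

```python
def extract_state_dict(loaded):
    """Return the model weights from either file format.

    `loaded` is the object returned by torch.load(path): a bare state_dict
    (final model file) or a dict with a 'model_state_dict' entry (checkpoint).
    """
    if isinstance(loaded, dict) and "model_state_dict" in loaded:
        # Checkpoint format: the weights are one entry among several.
        return loaded["model_state_dict"]
    # Final-model format: the loaded object already is the state_dict.
    return loaded


# Possible usage inside simulate_agent:
# agent.model.load_state_dict(extract_state_dict(torch.load(model_path)))
```

The same helper can convert a checkpoint into a final-model file by passing its result to torch.save.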
Summary
- Checkpoint files and the final model file serve different purposes.
- They are saved at different times and contain different sets of information.
- While both contain the model's learned weights, checkpoints include additional training state information for resuming training, whereas the final model is a lightweight file containing only the weights needed for inference or simulation.
- They are not directly interchangeable, even though they are based on the same model architecture.