cart pole:simulate - chunhualiao/public-docs GitHub Wiki
simulate_agent()
Detailed Explanation of simulate_agent()
The simulate_agent() function loads a trained Deep Q-Network (DQN) model and runs a simulation of the CartPole-v1 environment to visualize how well the agent performs after training. This function uses real-time rendering and OpenCV visualization to display the agent's performance.
1. Function Purpose
- Loads the trained model (cartpole_dqn.pth).
- Runs the trained agent in the environment without learning.
- Visualizes the agent's performance using OpenCV.
- Displays the time step and total reward on the rendered simulation.
2. Code Breakdown
Step 1: Initialize the Environment and Load the Model
def simulate_agent(model_path="cartpole_dqn.pth", render=True):
    env = gym.make('CartPole-v1', render_mode="rgb_array")
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    agent = Agent(state_size, action_size)
    agent.model.load_state_dict(torch.load(model_path))
    agent.model.eval()  # Set the model to evaluation mode
- gym.make('CartPole-v1', render_mode="rgb_array")
  - Creates the CartPole environment in image-rendering mode (rgb_array), which allows the simulation to generate image frames.
- agent = Agent(state_size, action_size)
  - Creates an Agent object with the same state and action space as during training (an illustrative sketch of the assumed network follows this list).
- agent.model.load_state_dict(torch.load(model_path))
  - Loads the trained weights from the saved checkpoint file (cartpole_dqn.pth).
- agent.model.eval()
  - Sets the model to evaluation mode, which disables dropout and switches batch-normalization layers to inference behavior.
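For context, the Agent loaded here is assumed to wrap a small fully connected Q-network defined in the training code. The sketch below is only illustrative: the real layer sizes and constructor signature come from the training script, not from this simulation function.

import torch
import torch.nn as nn

# Illustrative sketch of what agent.model is assumed to look like; the actual
# architecture (layer sizes, activation choices) is defined in the training code.
class QNetwork(nn.Module):
    def __init__(self, state_size, action_size, hidden_size=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),  # one Q-value per action
        )

    def forward(self, x):
        return self.net(x)

Whatever the exact architecture is, load_state_dict() only succeeds if it matches the one used when cartpole_dqn.pth was saved.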
Step 2: Disable Exploration for Testing
agent.epsilon = -1 # Ensure no exploration during simulation
- In training, the agent follows an ε-greedy strategy, meaning it sometimes chooses random actions for exploration.
- Here, epsilon = -1 ensures the agent always selects the best action (pure exploitation); see the sketch after this list.
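Why a negative epsilon guarantees greedy behavior: the act() method in the training code presumably follows the standard ε-greedy pattern sketched below, in which case any epsilon below 0 makes the exploration branch unreachable.

import random
import torch

# Hypothetical sketch of the ε-greedy selection assumed inside Agent.act();
# the real method is defined in the training code.
def act(agent, state):
    if random.random() <= agent.epsilon:              # with epsilon = -1 this is never true
        return random.randrange(agent.action_size)    # explore: random action
    state_t = torch.FloatTensor(state).unsqueeze(0)   # shape (1, state_size)
    with torch.no_grad():
        q_values = agent.model(state_t)               # Q-values for each action
    return int(q_values.argmax(dim=1).item())         # exploit: highest-valued action

Even if the real implementation checks epsilon differently (for example with np.random.rand()), the effect is the same: a draw in [0, 1) is never less than or equal to a negative threshold, so the explore branch never triggers.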
Step 3: Reset the Environment and Initialize Variables
state, _ = env.reset()
done = False
truncated = False
total_reward = 0
timestep = 0
- Resets the environment to get the initial state (see the note below on what the state contains).
- done = False and truncated = False control when the simulation stops.
- total_reward = 0 tracks the cumulative reward.
- timestep = 0 counts the number of steps taken.
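As a side note, under the Gymnasium API env.reset() returns an (observation, info) tuple, and for CartPole-v1 the observation is a 4-element array. A quick illustrative check:

state, _ = env.reset()
print(state.shape)  # (4,)
# The four entries are: cart position, cart velocity, pole angle (radians), pole angular velocity.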
Step 4: Start the Simulation Loop
while not (done or truncated):
- Runs the simulation until the episode ends, i.e. until either done or truncated becomes True.
Step 5: Render and Display the Simulation Using OpenCV
if render:
    img = env.render()  # Render the environment and get the image frame

    # Add text overlay for timestep and total reward
    text = f"Timestep: {timestep}, Reward: {total_reward}"
    cv2.putText(img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)  # Black color

    cv2.imshow("CartPole Simulation", img)
    cv2.waitKey(20)  # Wait for 20 ms, which controls the speed of the simulation
Rendering the Simulation
img = env.render()
- Generates the current frame of the CartPole environment.
- This is an RGB image showing the cart, pole, and track.
Adding a Text Overlay with OpenCV
cv2.putText(img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
- Adds a real-time text overlay to the image:
  - Displays the current timestep.
  - Shows the total reward earned so far.
  - The text is positioned at (10, 30) pixels in the frame.
  - Uses black text color.
Displaying the Simulation
cv2.imshow("CartPole Simulation", img)
- Displays the rendered frame in an OpenCV window.
cv2.waitKey(20)
- Waits for 20 milliseconds between frames, controlling simulation speed.
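One detail worth noting: with render_mode="rgb_array", env.render() returns the frame in RGB channel order, while cv2.imshow() interprets images as BGR. The code above still runs (only the red and blue channels appear swapped), so converting is optional; a minimal adjustment inside the render block could look like this:

# Optional: convert the rendered RGB frame to BGR so OpenCV shows correct colors.
frame_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
cv2.imshow("CartPole Simulation", frame_bgr)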
Step 6: Select the Best Action and Step Forward
action = agent.act(state)
next_state, reward, done, truncated, _ = env.step(action)
total_reward += reward
state = next_state
timestep += 1
- action = agent.act(state)
  - The trained agent selects an action without randomness.
- env.step(action)
  - The environment executes the action, returning:
    - New state (next_state)
    - Reward received (reward)
    - Episode termination flag (done)
    - Truncation flag (truncated)
- Update counters:
  - total_reward += reward → tracks the total accumulated reward.
  - timestep += 1 → increments the step count.
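For reference, the two end-of-episode flags returned by env.step() mean different things in CartPole-v1: done (terminated) signals a failure condition, while truncated signals that the time limit was hit. The snippet below only illustrates that distinction and is not part of simulate_agent():

next_state, reward, done, truncated, info = env.step(action)
if done:
    print("Episode terminated: the pole fell or the cart left the track")
elif truncated:
    print("Episode truncated: the 500-step time limit was reached")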
Step 7: Display Final Results and Close Windows
print(f"Simulation ended with reward: {total_reward}")
cv2.destroyAllWindows() # Close the OpenCV windows when simulation ends
env.close()
- print(f"Simulation ended with reward: {total_reward}")
  - Displays the final total reward at the end of the episode.
- cv2.destroyAllWindows()
  - Closes the OpenCV simulation window.
- env.close()
  - Properly shuts down the Gym environment.
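If a live OpenCV window is inconvenient (for example on a headless server), the rendered frames could instead be written to a video file. The sketch below uses cv2.VideoWriter; the file name and 50 FPS are arbitrary choices, not part of the original function:

import cv2

# Sketch: record the episode to a video file instead of showing a window.
# Assumes env was created with render_mode="rgb_array" as above.
frame = env.render()
height, width, _ = frame.shape
writer = cv2.VideoWriter("cartpole_run.mp4",
                         cv2.VideoWriter_fourcc(*"mp4v"), 50, (width, height))
# Inside the simulation loop, after each env.render():
#     writer.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
# After the loop, before env.close():
writer.release()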
3. Summary of Function Execution
| Step | Action | Code |
|---|---|---|
| 1. Initialize environment and agent | Creates the CartPole environment and loads the trained model. | env = gym.make('CartPole-v1', render_mode="rgb_array") |
| 2. Disable exploration | Ensures the agent always selects the best action. | agent.epsilon = -1 |
| 3. Reset the environment | Starts a new simulation episode. | state, _ = env.reset() |
| 4. Run simulation loop | Continues until the pole falls or 500 steps are reached. | while not (done or truncated): |
| 5. Render frame & overlay text | Uses OpenCV to visualize the simulation. | cv2.putText(img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, ...) |
| 6. Select action & execute in environment | Uses the trained model to select and execute actions. | action = agent.act(state) → next_state, reward, done, truncated, _ = env.step(action) |
| 7. Display final reward & close simulation | Ends the visualization and shuts down the environment. | print(f"Simulation ended with reward: {total_reward}") |
4. How the Visualization Works
- The CartPole simulation is visualized frame by frame.
- Each frame is rendered as an RGB image.
- Real-time statistics (Timestep & Reward) are overlaid on the image using OpenCV.
- The simulation runs at a controlled speed (cv2.waitKey(20)).
- Once the episode ends, the window closes and the total reward is printed.
5. Expected Output
During the simulation, an OpenCV window displays the CartPole game with a real-time text overlay:
Timestep: 30, Reward: 30
Timestep: 60, Reward: 60
...
Timestep: 500, Reward: 500
Simulation ended with reward: 500
If the agent is well-trained, it balances the pole for 500 timesteps (max score). If it fails early, the reward is lower.
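Because a single episode can be unusually lucky or unlucky, a more reliable check is to average the reward over several episodes. A small sketch, assuming the same env and agent objects as in simulate_agent() (the episode count is arbitrary):

# Average the reward over several evaluation episodes.
num_episodes = 10
rewards = []
for _ in range(num_episodes):
    state, _ = env.reset()
    done = truncated = False
    episode_reward = 0
    while not (done or truncated):
        action = agent.act(state)  # greedy action (exploration disabled)
        state, reward, done, truncated, _ = env.step(action)
        episode_reward += reward
    rewards.append(episode_reward)
print(f"Average reward over {num_episodes} episodes: {sum(rewards) / num_episodes:.1f}")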
6. Conclusion
The simulate_agent() function:
- Loads the trained model and runs a simulation.
- Uses OpenCV to display real-time text overlays.
- Runs the agent in pure exploitation mode (no randomness).
- Provides visual feedback on how well the agent performs.
This function is a key tool for evaluating and debugging the agent's performance after training.