cart pole:simulate - chunhualiao/public-docs GitHub Wiki

Detailed Explanation of simulate_agent()

The simulate_agent() function loads a trained Deep Q-Network (DQN) model and runs one episode of the CartPole-v1 environment to show how well the agent performs after training. Frames are rendered in real time and displayed with OpenCV, with the current timestep and cumulative reward overlaid on each frame.


1. Function Purpose

  • Loads the trained model (cartpole_dqn.pth).
  • Runs the trained agent in the environment without learning.
  • Visualizes the agent's performance using OpenCV.
  • Displays time step and total reward on the rendered simulation.

2. Code Breakdown

Step 1: Initialize the Environment and Load the Model

# Required imports (assumed): a Gym / Gymnasium version where env.reset() returns
# (observation, info) and env.step() returns 5 values, plus PyTorch, OpenCV, and
# the Agent class defined in the training script.
import gym        # or: import gymnasium as gym
import torch
import cv2

def simulate_agent(model_path="cartpole_dqn.pth", render=True):
    env = gym.make('CartPole-v1', render_mode="rgb_array")
    state_size = env.observation_space.shape[0]  # 4 state variables (position, velocity, angle, angular velocity)
    action_size = env.action_space.n             # 2 actions: push cart left or right
    agent = Agent(state_size, action_size)       # Agent is the class defined in the training script
    agent.model.load_state_dict(torch.load(model_path))
    agent.model.eval()  # Set the model to evaluation mode
  • gym.make('CartPole-v1', render_mode="rgb_array")

    • Creates the CartPole environment in image-rendering mode (rgb_array).
    • This allows the simulation to generate image frames.
  • agent = Agent(state_size, action_size)

    • Creates an Agent object with the same state and action space as during training.
  • agent.model.load_state_dict(torch.load(model_path))

    • Loads the trained model from the saved checkpoint file (cartpole_dqn.pth).
  • agent.model.eval()

    • Switches the model to evaluation mode so that layers such as dropout or batch normalization behave deterministically at inference time; for a plain fully connected Q-network this changes nothing, but it is good practice. (A minimal sketch of a compatible Q-network follows below.)
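
The Agent class and its underlying network are defined in the training script and are not shown on this page. As a rough idea of what load_state_dict() expects, here is a minimal sketch of a compatible Q-network; the layer sizes, names, and architecture are assumptions, not the project's actual code:

import torch.nn as nn

class QNetwork(nn.Module):
    """Hypothetical Q-network: maps a 4-dim CartPole state to 2 Q-values."""
    def __init__(self, state_size, action_size, hidden_size=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),  # one Q-value per action
        )

    def forward(self, x):
        return self.net(x)

load_state_dict() only succeeds if the checkpoint was saved from a network with exactly matching layer names and shapes, so the real Agent and model definitions from training must be reused unchanged during simulation.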

Step 2: Disable Exploration for Testing

agent.epsilon = -1  # Ensure no exploration during simulation
  • During training, the agent follows an ε-greedy strategy: with probability ε it picks a random action to explore.
  • Setting epsilon = -1 guarantees that the exploration branch is never taken in a typical ε-greedy implementation, because a random number drawn from [0, 1) can never fall below -1; the agent therefore always selects the action with the highest predicted Q-value (pure exploitation). A sketch of the assumed act() logic follows.
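
The act() method itself lives in the training code and is not reproduced on this page. The following is a minimal sketch of a typical ε-greedy act() implementation, included only to show why any negative epsilon forces greedy behavior; the method body is an assumption, not the project's actual code.

import random
import numpy as np
import torch

class Agent:
    # ... constructor, self.model, self.epsilon, self.action_size, etc. are defined in the training script

    def act(self, state):
        # Exploration branch: with epsilon = -1, random.random() (always >= 0)
        # can never be <= epsilon, so this branch is skipped entirely.
        if random.random() <= self.epsilon:
            return random.randrange(self.action_size)
        # Exploitation branch: pick the action with the highest predicted Q-value.
        state_t = torch.as_tensor(np.asarray(state), dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            q_values = self.model(state_t)
        return int(torch.argmax(q_values, dim=1).item())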

Step 3: Reset the Environment and Initialize Variables

state, _ = env.reset()
done = False
truncated = False
total_reward = 0
timestep = 0
  • env.reset() returns the initial state; in the Gymnasium / Gym ≥ 0.26 API it returns a tuple (observation, info), and the info dictionary is discarded here with _.
  • done = False and truncated = False control when the simulation stops.
  • total_reward = 0 tracks cumulative rewards.
  • timestep = 0 counts the number of steps taken.
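
As a side note, the newer Gym / Gymnasium API also lets reset() take a seed, which is useful when a reproducible evaluation episode is wanted (optional; not part of the original function):

# Optional: seed the environment for a reproducible episode (Gym >= 0.26 / Gymnasium API).
state, _ = env.reset(seed=42)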

Step 4: Start the Simulation Loop

while not (done or truncated):
  • Runs the simulation until the episode ends: done becomes True when the pole falls past its angle limit or the cart leaves the track, and truncated becomes True when CartPole-v1's 500-step time limit is reached.

Step 5: Render and Display the Simulation Using OpenCV

if render:
    img = env.render()  # Render the environment and get the current frame (RGB numpy array)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # Convert RGB (Gym) to BGR (OpenCV) so colors display correctly
    # Add text overlay for timestep and total reward
    text = f"Timestep: {timestep}, Reward: {total_reward}"
    cv2.putText(img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)  # Black text
    cv2.imshow("CartPole Simulation", img)
    cv2.waitKey(20)  # Wait 20 ms per frame, which controls the speed of the simulation

Rendering the Simulation

  • img = env.render()
    • Generates the current frame of the CartPole environment as an RGB image showing the cart, pole, and track.
    • Because OpenCV expects BGR channel order, the frame is converted with cv2.cvtColor before being drawn on and displayed.

Adding a Text Overlay with OpenCV

  • cv2.putText(img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
    • Adds a real-time text overlay to the image:
      • Displays current timestep.
      • Shows total reward earned so far.
    • Text is positioned at (10, 30) pixels in the frame.
    • Uses black text color.

Displaying the Simulation

  • cv2.imshow("CartPole Simulation", img)
    • Displays the rendered frame in an OpenCV window.
  • cv2.waitKey(20)
    • Waits 20 milliseconds between frames, controlling the simulation speed. (A headless alternative that writes frames to a video file instead of a window is sketched below.)
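
cv2.imshow() needs a display; on a headless machine (for example a remote server) it will fail. One possible workaround, sketched below under the assumption that the rendered frames are standard H×W×3 uint8 arrays, is to write them to an MP4 file with cv2.VideoWriter instead of showing a window (the helper name, output filename, and FPS are arbitrary choices):

import cv2

def make_video_writer(first_frame, path="cartpole_run.mp4", fps=50):
    """Hypothetical helper: create a VideoWriter sized to the rendered frames."""
    height, width = first_frame.shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    return cv2.VideoWriter(path, fourcc, fps, (width, height))

# Usage sketch: set writer = None before the loop, then inside the loop replace
# cv2.imshow(...) / cv2.waitKey(...) with:
#     if writer is None:
#         writer = make_video_writer(img)
#     writer.write(img)   # frame must already be in BGR order
# and call writer.release() after the loop.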

Step 6: Select the Best Action and Step Forward

action = agent.act(state)
next_state, reward, done, truncated, _ = env.step(action)
total_reward += reward
state = next_state
timestep += 1
  • action = agent.act(state)

    • The trained agent selects an action without randomness.
  • env.step(action)

    • The environment executes the action, returning:
      • New state (next_state)
      • Reward received (reward)
      • Episode termination flag (done)
      • Truncation flag (truncated)
  • Update counters:

    • total_reward += reward → Tracks total accumulated reward.
    • timestep += 1 → Increments step count.
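
For debugging it can help to know whether the episode ended because the pole fell (done) or because the 500-step limit was hit (truncated). A small optional diagnostic that could be added after the loop, assuming the same variable names as above:

# Hypothetical post-loop diagnostic (not part of the original function):
if truncated and not done:
    print(f"Episode truncated at the time limit after {timestep} steps.")
else:
    print(f"Episode terminated (pole fell or cart left the track) after {timestep} steps.")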

Step 7: Display Final Results and Close Windows

print(f"Simulation ended with reward: {total_reward}")
cv2.destroyAllWindows() # Close the OpenCV windows when simulation ends
env.close()
  • print(f"Simulation ended with reward: {total_reward}")

    • Displays the final total reward at the end of the episode.
  • cv2.destroyAllWindows()

    • Closes OpenCV simulation window.
  • env.close()

    • Properly shuts down the Gym environment.
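
Putting it together, a typical way to call the function after training (assuming the training script saved its weights to cartpole_dqn.pth in the current working directory):

if __name__ == "__main__":
    # Watch the trained agent balance the pole; the OpenCV window closes when the episode ends.
    simulate_agent(model_path="cartpole_dqn.pth", render=True)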

3. Summary of Function Execution

| Step | Action | Code |
|------|--------|------|
| 1. Initialize environment and agent | Creates the CartPole environment and loads the trained model. | env = gym.make('CartPole-v1', render_mode="rgb_array") |
| 2. Disable exploration | Ensures the agent always selects the best action. | agent.epsilon = -1 |
| 3. Reset the environment | Starts a new simulation episode. | state, _ = env.reset() |
| 4. Run simulation loop | Continues until the pole falls or 500 steps are reached. | while not (done or truncated): |
| 5. Render frame & overlay text | Uses OpenCV to visualize the simulation. | cv2.putText(img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, ...) |
| 6. Select action & execute in environment | Uses the trained model to select and execute actions. | action = agent.act(state); next_state, reward, done, truncated, _ = env.step(action) |
| 7. Display final reward & close simulation | Ends the visualization and shuts down the environment. | print(f"Simulation ended with reward: {total_reward}") |

4. How the Visualization Works

  • The CartPole simulation is visualized frame by frame.
  • Each frame is rendered as an RGB image.
  • Real-time statistics (Timestep & Reward) are overlaid on the image using OpenCV.
  • The simulation runs at a controlled speed (cv2.waitKey(20)).
  • Once the episode ends, the window closes and the total reward is printed.

5. Expected Output

During the simulation, an OpenCV window displays the CartPole game with a real-time text overlay (rewards may print as floats, e.g. 30.0, since Gym returns float rewards):

Timestep: 30, Reward: 30
Timestep: 60, Reward: 60
...
Timestep: 500, Reward: 500
Simulation ended with reward: 500

If the agent is well-trained, it balances the pole for 500 timesteps (max score). If it fails early, the reward is lower.
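
A single episode can be lucky or unlucky, so a more reliable picture comes from running the agent greedily for several episodes without rendering and averaging the return. A minimal sketch under the same assumptions about the Agent class and the saved checkpoint (the helper name and episode count are arbitrary):

import gym
import torch

def evaluate_agent(model_path="cartpole_dqn.pth", episodes=10):
    """Hypothetical helper: average return over several greedy episodes, no rendering."""
    env = gym.make('CartPole-v1')
    agent = Agent(env.observation_space.shape[0], env.action_space.n)
    agent.model.load_state_dict(torch.load(model_path))
    agent.model.eval()
    agent.epsilon = -1  # greedy actions only

    returns = []
    for _ in range(episodes):
        state, _ = env.reset()
        done = truncated = False
        total = 0.0
        while not (done or truncated):
            state, reward, done, truncated, _ = env.step(agent.act(state))
            total += reward
        returns.append(total)
    env.close()
    print(f"Average reward over {episodes} episodes: {sum(returns) / episodes:.1f}")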


6. Conclusion

The simulate_agent() function:

  • Loads the trained model and runs a simulation.
  • Uses OpenCV to display real-time text overlays.
  • Runs the agent in pure exploitation mode (no randomness).
  • Provides visual feedback on how well the agent performs.

This function is a key tool for evaluating and debugging the agent's performance after training.