cart pole:train - chunhualiao/public-docs GitHub Wiki


Explain train_agent() of the Agent class, mapping the code to the algorithm in textbook style

Explanation of train_agent() Function in Agent Class

The train_agent() function is responsible for training a Deep Q-Network (DQN) agent to solve the CartPole-v1 environment using reinforcement learning. It follows the standard DQN algorithm by integrating key components such as

  • experience replay,
  • ε-greedy exploration, and
  • target network updates.

Below, we break down the function into algorithmic steps and map each part of the code to its corresponding textbook-style algorithm.


1. Algorithm: Deep Q-Learning with Experience Replay

The DQN training loop follows this algorithm:

Initialize:

  1. Initialize environment $env$ and agent $\mathcal{A}$ with neural network $Q_{\theta}$
  2. Set replay memory $\mathcal{D}$ (experience buffer)
  3. Initialize target network $Q_{\theta^-}$ with weights $\theta^- \gets \theta$

Training Loop (For each episode $e$):
4. Reset environment and obtain initial state $s_0$
5. Repeat until episode terminates:

  • Select action $a_t$ using ε-greedy policy
  • Execute action $a_t$, observe reward $R_t$, and next state $s_{t+1}$
  • Store transition $(s_t, a_t, R_t, s_{t+1}, \text{done})$ in replay buffer $\mathcal{D}$
  • Sample mini-batch from replay buffer
  • Compute target Q-value using Bellman equation: $Q_{\text{target}} = R_t + \gamma \max_{a'} Q_{\theta^-}(s', a')$
  • Update network parameters $\theta$ by minimizing the loss: $L = (Q_{\theta}(s, a) - Q_{\text{target}})^2$
  • Update target network periodically: $\theta^- \gets \theta$

6. Decay exploration rate $\epsilon$
7. Check termination condition (solved once the average reward reaches a target threshold, here 450)

ε-greedy policy (used for action selection in step 5):

$$
a_t =
\begin{cases}
\text{random action}, & \text{with probability } \epsilon \\
\arg\max_a Q_{\theta}(s_t, a), & \text{otherwise}
\end{cases}
$$

2. Mapping train_agent() Code to Algorithm

Now, let’s break down the code and map it to the algorithm steps.

Step 1: Initialize Environment and Agent

env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = Agent(state_size, action_size)
  • Maps to: Steps 1 & 2 of the algorithm
    • The environment (env) is created
    • The state dimension and the number of actions are read from the observation and action spaces
    • The DQN agent $\mathcal{A}$ is initialized

Step 2: Initialize Training Parameters

batch_size = 64
target_update_freq = 100  # Update target network every 100 steps
scores = []
step_count = 0
target_avg_score = 450.0
  • Maps to: Step 3
    • Experience replay batch size set to 64
    • Target network update frequency set to 100 steps
    • Tracking episode scores and step count

Step 3: Training Loop - Episode Execution

while True:  # Continue training indefinitely
    e += 1
    state, _ = env.reset()
    total_reward = 0
    done = False
    truncated = False
  • Maps to: Step 4
    • Reset the environment at the beginning of each episode
    • Initialize total_reward to track episode performance
    • done and truncated track normal termination and time-limit truncation, respectively, so each episode ends properly

Step 4: Select Action Using ε-Greedy Policy

action = agent.act(state)
  • Maps to: Step 5 (action selection)
    • The agent chooses an action using the ε-greedy policy (a minimal sketch of such an act() method follows below)
    • Exploration (probability $\epsilon$): choose a random action
    • Exploitation (probability $1-\epsilon$): choose the action with the highest predicted Q-value
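
For reference, here is a minimal sketch of what an ε-greedy act() method typically looks like. The attribute names self.epsilon, self.model, and self.action_size, and the module-level device, are assumptions based on the rest of this page; the actual Agent implementation may differ.

```python
import random
import torch

def act(self, state):
    # Exploration: with probability epsilon, pick a uniformly random action
    if random.random() < self.epsilon:
        return random.randrange(self.action_size)
    # Exploitation: pick the action with the highest predicted Q-value
    state_t = torch.FloatTensor(state).unsqueeze(0).to(device)
    with torch.no_grad():
        q_values = self.model(state_t)
    return q_values.argmax(dim=1).item()
```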

Step 5: Execute Action, Observe Next State & Reward

next_state, reward, done, truncated, _ = env.step(action)
  • Maps to: Step 5 (observe $s'$ and $R$)
    • The agent executes the action in the environment
    • It observes the next state $s'$ and the reward $R$
    • done and truncated indicate whether the episode has ended

Step 6: Store Experience in Replay Buffer

agent.remember(state, action, reward, next_state, done)
  • Maps to: Step 5 (store in experience buffer $\mathcal{D}$)
    • The agent stores the transition $(s, a, R, s', \text{done})$ in the replay buffer (a sketch of remember() follows below)
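
Given the deque-based buffer noted in the summary table below (self.memory = deque(maxlen=100000)), remember() is typically a one-line append. A sketch under that assumption:

```python
from collections import deque

# Assumed in Agent.__init__ (see the summary table below):
# self.memory = deque(maxlen=100000)

def remember(self, state, action, reward, next_state, done):
    # Store the transition (s, a, R, s', done); the oldest entries
    # are evicted automatically once the deque reaches maxlen
    self.memory.append((state, action, reward, next_state, done))
```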

Step 7: Sample Mini-Batch and Train DQN

agent.replay(batch_size)
  • Maps to: Step 5 (Training using replay buffer)
    • Samples a mini-batch from replay buffer
    • Updates Q-values using Bellman equation
    • Performs gradient descent to minimize loss

Step 8: Periodic Target Network Update

if step_count % target_update_freq == 0:
    agent.update_target_model()
  • Maps to: Step 5 (target network update)
    • Updates the target network ($Q_{\theta^-} \gets Q_{\theta}$) every 100 steps, as sketched below
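
A common way to implement this hard update is to copy the online network's parameters into the target network. A sketch, assuming both networks are torch.nn.Module instances named as in the replay() walkthrough below:

```python
def update_target_model(self):
    # Hard update: theta_minus <- theta
    # (copies every parameter of the online network into the target network)
    self.target_model.load_state_dict(self.model.state_dict())
```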

Step 9: Decay Exploration Rate (ε-Greedy)

agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)
  • Maps to: Step 6 (decay $\epsilon$)
    • Reduces exploration as the agent improves; see the worked example below
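
To get a feel for the schedule, suppose (hypothetically) that epsilon starts at 1.0 with epsilon_decay = 0.995 and epsilon_min = 0.01; multiplicative decay then takes roughly 460 steps to bring exploration below 10%:

```python
epsilon, epsilon_min, epsilon_decay = 1.0, 0.01, 0.995  # hypothetical values

steps = 0
while epsilon > 0.1:
    epsilon = max(epsilon_min, epsilon * epsilon_decay)
    steps += 1
print(steps)  # 460 -- about 460 decay steps to fall below 0.1
```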

Step 10: Check Performance & Save Model

if avg_score >= target_avg_score:
    print(f"Solved in {e} episodes!")
    agent.save("cartpole_dqn.pth")
    break
  • Maps to: Step 7 (Check Termination Condition)
    • If average score reaches 450+, training stops
    • Model is saved
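
Putting the fragments above together, the overall shape of train_agent() looks roughly like the sketch below. The 100-episode averaging window and the exact placement of the ε-decay and target-update calls are assumptions; the actual function may differ in those details.

```python
import gym          # gym >= 0.26 (or gymnasium) for the 5-tuple step() API used here
import numpy as np

def train_agent():
    env = gym.make('CartPole-v1')
    agent = Agent(env.observation_space.shape[0], env.action_space.n)

    batch_size = 64
    target_update_freq = 100
    target_avg_score = 450.0
    scores, step_count, e = [], 0, 0

    while True:                             # train until solved
        e += 1
        state, _ = env.reset()
        total_reward, done, truncated = 0.0, False, False

        while not (done or truncated):
            action = agent.act(state)                                  # epsilon-greedy action
            next_state, reward, done, truncated, _ = env.step(action)  # interact with env
            agent.remember(state, action, reward, next_state, done)    # store transition
            agent.replay(batch_size)                                   # learn from a minibatch
            step_count += 1
            if step_count % target_update_freq == 0:
                agent.update_target_model()                            # sync target network
            state = next_state
            total_reward += reward

        # Decay exploration once per episode (placement is an assumption)
        agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)

        scores.append(total_reward)
        avg_score = np.mean(scores[-100:])  # assumed 100-episode moving average
        if avg_score >= target_avg_score:
            print(f"Solved in {e} episodes!")
            agent.save("cartpole_dqn.pth")
            break
```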

3. Summary of Code-to-Algorithm Mapping

| Algorithm Step | Code in train_agent() |
|---|---|
| Initialize environment & agent | `env = gym.make('CartPole-v1')`, `agent = Agent(state_size, action_size)` |
| Initialize replay buffer | Done inside the Agent class (`self.memory = deque(maxlen=100000)`) |
| Training loop (for each episode) | `while True:` (loops until solved) |
| Reset environment at episode start | `state, _ = env.reset()` |
| Select action using ε-greedy policy | `action = agent.act(state)` |
| Execute action & observe reward | `next_state, reward, done, truncated, _ = env.step(action)` |
| Store transition in experience buffer | `agent.remember(state, action, reward, next_state, done)` |
| Sample batch & train | `agent.replay(batch_size)` |
| Update target network periodically | `if step_count % target_update_freq == 0: agent.update_target_model()` |
| Decay exploration rate (ε) | `agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)` |
| Check termination & save model | `if avg_score >= target_avg_score:` |

This structured breakdown provides a textbook-style mapping of train_agent() to Deep Q-Learning.

Explanation of the replay() Function in Agent Class

The replay() function in the DQN Agent is responsible for training the neural network using experiences stored in the replay buffer. It updates the Q-values using the Bellman equation, which is the foundation of Q-learning.


1. The Bellman Equation in DQN

The goal of Q-learning is to update the Q-value estimates to minimize the error between the predicted Q-value and the target Q-value. The update follows:

$$
Q(s, a) \leftarrow Q(s, a) + \alpha \left[ R + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]
$$

where:

  • $Q(s, a)$ → current Q-value estimate for state $s$ and action $a$
  • $\alpha$ → learning rate (handled by the optimizer in DQN)
  • $R$ → reward received after taking action $a$ in state $s$
  • $\gamma$ → discount factor (how much future rewards matter)
  • $\max_{a'} Q(s', a')$ → maximum estimated Q-value for the next state $s'$ (computed using the target model in stable DQN)

In DQN, instead of updating Q-values directly, we train a neural network to approximate $Q(s, a)$.
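
Concretely, the tabular update above becomes a regression problem over the network parameters $\theta$: the learning rate $\alpha$ is absorbed into the optimizer, and the network is trained to minimize the squared temporal-difference error over mini-batches drawn from the replay buffer $\mathcal{D}$ (writing $d$ for the done flag):

$$
L(\theta) = \mathbb{E}_{(s, a, R, s', d) \sim \mathcal{D}} \Big[ \big( R + \gamma \, (1 - d) \max_{a'} Q_{\theta^-}(s', a') - Q_{\theta}(s, a) \big)^2 \Big]
$$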


2. Mapping replay() Code to Bellman Equation

Now, let’s break down the key steps in the replay() function and map them to the Bellman equation terms.

Step 1: Checking if Replay Buffer has Enough Data

if len(self.memory) < batch_size:
    return
  • Purpose: Ensures that training only starts when the replay buffer has accumulated enough experience.

Step 2: Sampling a Mini-Batch from Replay Buffer

minibatch = random.sample(self.memory, batch_size)
  • Purpose: Randomly selects a batch of experiences $(s, a, R, s', \text{done})$ from memory to break temporal correlations, which stabilizes learning.

Step 3: Converting Mini-Batch to PyTorch Tensors

states = torch.FloatTensor([t[0] for t in minibatch]).to(device)
actions = torch.LongTensor([t[1] for t in minibatch]).to(device)
rewards = torch.FloatTensor([t[2] for t in minibatch]).to(device)
next_states = torch.FloatTensor([t[3] for t in minibatch]).to(device)
dones = torch.FloatTensor([t[4] for t in minibatch]).to(device)
  • Mapping to Bellman equation terms:
    • states → $s$ (current state)
    • actions → $a$ (action taken)
    • rewards → $R$ (reward received)
    • next_states → $s'$ (next state)
    • dones → terminal-state indicator (used to stop counting future rewards once an episode ends); an optional performance tweak for this conversion is sketched below
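
One practical note (an optional tweak, not part of the original code): building a tensor from a Python list of NumPy arrays is slow and triggers a warning in recent PyTorch versions; stacking the batch with NumPy first avoids that, using the same variable names:

```python
import numpy as np
import torch

states      = torch.FloatTensor(np.array([t[0] for t in minibatch])).to(device)
actions     = torch.LongTensor(np.array([t[1] for t in minibatch])).to(device)
rewards     = torch.FloatTensor(np.array([t[2] for t in minibatch])).to(device)
next_states = torch.FloatTensor(np.array([t[3] for t in minibatch])).to(device)
dones       = torch.FloatTensor(np.array([t[4] for t in minibatch])).to(device)
```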

Step 4: Computing Current Q-Values from Main Model (self.model)

current_q = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
  • Maps to: $Q(s, a)$ (current Q-value estimate)
  • Breakdown:
    • self.model(states): Predicts Q-values for all actions in the current state.
    • .gather(1, actions.unsqueeze(1)): Selects the Q-value corresponding to the action that was taken.
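
The gather call is easiest to see on a tiny example with a batch of two states and two actions (the numbers are purely illustrative):

```python
import torch

q_all   = torch.tensor([[0.2, 0.9],    # Q-values of state 1 for actions 0 and 1
                        [0.7, 0.1]])   # Q-values of state 2
actions = torch.tensor([1, 0])         # actions actually taken in the batch

# Pick, for each row, the Q-value of the action that was taken
q_taken = q_all.gather(1, actions.unsqueeze(1)).squeeze(1)
print(q_taken)  # tensor([0.9000, 0.7000])
```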

Step 5: Computing Target Q-Values Using the Target Model (self.target_model)

next_q = self.target_model(next_states).detach().max(1)[0]
target = rewards + (1 - dones) * self.gamma * next_q
  • Maps to: $R + \gamma \max_{a'} Q(s', a')$ (target Q-value)
  • Breakdown:
    • self.target_model(next_states): Predicts Q-values for all actions in the next state $s'$.
    • .detach(): Ensures that gradients are not computed through the target model (stabilizing updates).
    • .max(1)[0]: Selects the maximum Q-value over all possible actions $a'$ in $s'$.
    • (1 - dones) * self.gamma * next_q: Ensures that if the episode has ended (done = 1), the future reward term is zeroed out.
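
A quick numeric check of the target formula, with γ = 0.99 and illustrative values for the reward and the target network's estimate:

```python
gamma = 0.99

# Non-terminal transition: R = 1.0, max_a' Q_target(s', a') = 10.0, done = 0
target = 1.0 + (1 - 0) * gamma * 10.0           # 10.9

# Terminal transition: same reward, but done = 1, so the future term vanishes
target_terminal = 1.0 + (1 - 1) * gamma * 10.0  # 1.0
```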

Step 6: Computing Loss and Updating the Model

loss = F.mse_loss(current_q, target)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
  • Maps to: minimizing the loss function

$$
L = \left( Q_{\theta}(s, a) - \left( R + \gamma \max_{a'} Q_{\theta^-}(s', a') \right) \right)^2
$$
  • Breakdown:
    • F.mse_loss(current_q, target): Computes Mean Squared Error (MSE) between current Q-values and target Q-values.
    • self.optimizer.zero_grad(): Clears previous gradients.
    • loss.backward(): Computes gradients for the neural network.
    • self.optimizer.step(): Updates network weights using gradient descent.
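
As an aside, many DQN implementations replace plain MSE with the Huber (smooth L1) loss to reduce sensitivity to large TD errors; if you want to experiment with that here, the only change would be the loss line (this is a variation, not what the code above does):

```python
import torch.nn.functional as F

# Optional variation: Huber loss instead of MSE
loss = F.smooth_l1_loss(current_q, target)
```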

3. Summary of Mappings

| Bellman Equation Term | Code in replay() |
|---|---|
| $s$ (current state) | `states = torch.FloatTensor([t[0] for t in minibatch])` |
| $a$ (action taken) | `actions = torch.LongTensor([t[1] for t in minibatch])` |
| $R$ (reward received) | `rewards = torch.FloatTensor([t[2] for t in minibatch])` |
| $s'$ (next state) | `next_states = torch.FloatTensor([t[3] for t in minibatch])` |
| $Q(s, a)$ (current Q-value) | `current_q = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)` |
| $\max_{a'} Q(s', a')$ (best Q-value in next state) | `next_q = self.target_model(next_states).detach().max(1)[0]` |
| Target Q-value: $R + \gamma \max_{a'} Q(s', a')$ | `target = rewards + (1 - dones) * self.gamma * next_q` |
| Loss function: $(Q(s, a) - \text{target})^2$ | `loss = F.mse_loss(current_q, target)` |
| Gradient descent update | `self.optimizer.zero_grad(); loss.backward(); self.optimizer.step()` |

4. Key Takeaways

  • The main Q-network (self.model) predicts current Q-values.
  • The target Q-network (self.target_model) provides stable target Q-values.
  • The loss function ensures that the neural network minimizes the difference between current and target Q-values.
  • Using experience replay helps break correlations in training data, stabilizing learning.
  • The Bellman equation is implemented implicitly through these calculations.

This structured mapping shows how the replay() function implements the Q-learning update in DQN.