# cart pole:train
> Prompt: explain `train_agent()` of the `Agent` class and map the code to the algorithm in textbook style.
## Explanation of the `train_agent()` Function in the `Agent` Class

The `train_agent()` function is responsible for training a Deep Q-Network (DQN) agent to solve the CartPole-v1 environment with reinforcement learning. It follows the standard DQN algorithm by integrating key components such as
- experience replay,
- ε-greedy exploration, and
- target network updates.
Below, we break the function down into algorithmic steps and map each part of the code to the corresponding step of a textbook-style algorithm.
### 1. Algorithm: Deep Q-Learning with Experience Replay

The DQN training loop follows this algorithm:
**Initialization:**

1. Initialize the environment $env$ and the agent $\mathcal{A}$ with neural network $Q_{\theta}$
2. Set up the replay memory $\mathcal{D}$ (experience buffer)
3. Initialize the target network $Q_{\theta^-}$ with weights $\theta^- \gets \theta$

**Training loop (for each episode $e$):**

4. Reset the environment and obtain the initial state $s_0$
5. Repeat until the episode terminates:
   - Select action $a_t$ using the ε-greedy policy
   - Execute action $a_t$, observe reward $R_t$ and next state $s_{t+1}$
   - Store the transition $(s_t, a_t, R_t, s_{t+1}, \text{done})$ in the replay buffer $\mathcal{D}$
   - Sample a mini-batch from the replay buffer
   - Compute the target Q-value with the Bellman equation: $Q_{\text{target}} = R_t + \gamma \max_{a'} Q_{\theta^-}(s_{t+1}, a')$
   - Update the network parameters $\theta$ by minimizing the loss $L = \left(Q_{\theta}(s_t, a_t) - Q_{\text{target}}\right)^2$
   - Periodically update the target network: $\theta^- \gets \theta$
6. Decay the exploration rate $\epsilon$
7. Check the termination condition (solved once the average reward exceeds a threshold)
**ε-greedy policy:**

$$
a_t =
\begin{cases}
\text{random action}, & \text{with probability } \epsilon \\
\arg\max_a Q_{\theta}(s_t, a), & \text{otherwise}
\end{cases}
$$
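The `act()` method itself is not shown in this walkthrough. Below is a minimal sketch of how such a method might implement the ε-greedy policy; the attribute names `epsilon`, `action_size`, `model`, and the module-level `device` are assumptions based on how they are used in the snippets later on, not the actual implementation.

```python
import random
import torch

def act(self, state):
    # Exploration: with probability epsilon, take a uniformly random action
    if random.random() < self.epsilon:
        return random.randrange(self.action_size)
    # Exploitation: take the greedy action arg max_a Q_theta(s, a)
    state_t = torch.FloatTensor(state).unsqueeze(0).to(device)  # shape (1, state_size)
    with torch.no_grad():
        q_values = self.model(state_t)                           # shape (1, action_size)
    return q_values.argmax(dim=1).item()
```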
### 2. Mapping `train_agent()` Code to the Algorithm

Now, let's break down the code and map it to the algorithm steps.
#### Step 1: Initialize Environment and Agent

```python
env = gym.make('CartPole-v1')
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = Agent(state_size, action_size)
```

- Maps to: Steps 1–2 of the algorithm
- The environment (`env`) is created
- The state size and action size are determined from the environment
- The DQN agent $\mathcal{A}$ is initialized
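The `Agent` constructor is not shown in this walkthrough. As a plausible sketch of the pieces it must set up, assuming the replay-buffer size given in the mapping table below (`deque(maxlen=100000)`) and illustrative values for the remaining hyperparameters and network architecture:

```python
from collections import deque
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Agent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=100000)   # replay buffer D (size from the table below)
        self.gamma = 0.99                    # discount factor (illustrative)
        self.epsilon = 1.0                   # initial exploration rate (illustrative)
        self.epsilon_min = 0.01              # illustrative
        self.epsilon_decay = 0.995           # illustrative
        # Main Q-network and target Q-network with identical architectures
        self.model = self._build_network().to(device)
        self.target_model = self._build_network().to(device)
        self.target_model.load_state_dict(self.model.state_dict())  # theta^- <- theta
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3)

    def _build_network(self):
        # Small MLP mapping a state to one Q-value per action (architecture is assumed)
        return nn.Sequential(
            nn.Linear(self.state_size, 128), nn.ReLU(),
            nn.Linear(128, 128), nn.ReLU(),
            nn.Linear(128, self.action_size),
        )
```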
Step 2: Initialize Training Parameters
batch_size = 64
target_update_freq = 100 # Update target network every 100 steps
scores = []
step_count = 0
target_avg_score = 450.0
- Maps to: Step 3
- Experience replay batch size set to 64
- Target network update frequency set to 100 steps
- Tracking episode scores and step count
#### Step 3: Training Loop - Episode Execution

```python
while True:  # Continue training until solved
    e += 1
    state, _ = env.reset()
    total_reward = 0
    done = False
    truncated = False
```

- Maps to: Step 4
- The environment is reset at the beginning of each episode
- `total_reward` is reset to track the episode's performance
- `done` and `truncated` ensure that episodes terminate properly
#### Step 4: Select Action Using the ε-Greedy Policy

```python
action = agent.act(state)
```

- Maps to: Step 5 (action selection)
- The agent chooses an action based on the ε-greedy policy
- Exploration (probability $\epsilon$): choose a random action
- Exploitation (probability $1-\epsilon$): choose the action with the highest Q-value
#### Step 5: Execute Action, Observe Next State & Reward

```python
next_state, reward, done, truncated, _ = env.step(action)
```

- Maps to: Step 5 (observe $s'$ and $R$)
- The agent executes the action in the environment
- It observes the next state $s'$ and reward $R$
- `done` / `truncated` indicate whether the episode has ended
#### Step 6: Store Experience in Replay Buffer

```python
agent.remember(state, action, reward, next_state, done)
```

- Maps to: Step 5 (store in experience buffer $\mathcal{D}$)
- The agent stores the transition $(s, a, R, s', \text{done})$ in the replay buffer
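`remember()` itself is not shown here; given the `deque`-based buffer mentioned in the mapping table below, it is presumably a one-line append along these lines (a sketch, not the actual code):

```python
def remember(self, state, action, reward, next_state, done):
    # Append the transition (s, a, R, s', done) to the replay buffer D;
    # the deque's maxlen evicts the oldest transition once the buffer is full.
    self.memory.append((state, action, reward, next_state, done))
```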
#### Step 7: Sample Mini-Batch and Train DQN

```python
agent.replay(batch_size)
```

- Maps to: Step 5 (training from the replay buffer)
- Samples a mini-batch from the replay buffer
- Updates Q-values using the Bellman equation
- Performs gradient descent to minimize the loss (see the `replay()` walkthrough below)
#### Step 8: Periodic Target Network Update

```python
if step_count % target_update_freq == 0:
    agent.update_target_model()
```

- Maps to: Step 5 (target-network update)
- Updates the target network ($Q_{\theta^-} \gets Q_{\theta}$) every 100 steps
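`update_target_model()` typically just copies the online network's weights into the target network. A likely sketch (the actual method may differ, e.g. it could use a soft/Polyak update instead):

```python
def update_target_model(self):
    # Hard update: theta^- <- theta
    self.target_model.load_state_dict(self.model.state_dict())
```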
#### Step 9: Decay Exploration Rate (ε-Greedy)

```python
agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)
```

- Maps to: Step 6 (decay $\epsilon$)
- Gradually reduces exploration as the agent improves, never going below `epsilon_min`
#### Step 10: Check Performance & Save Model

```python
if avg_score >= target_avg_score:
    print(f"Solved in {e} episodes!")
    agent.save("cartpole_dqn.pth")
    break
```

- Maps to: Step 7 (check termination condition)
- If the average score reaches 450 or more, training stops
- The trained model is saved to `cartpole_dqn.pth`
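Putting the snippets together, the body of `train_agent()` plausibly looks like the sketch below. The bookkeeping around `total_reward`, `state = next_state`, `scores`, and `avg_score` is assumed (the original only shows the solved check), and the 100-episode moving average and the per-episode placement of the ε decay are common choices rather than confirmed details.

```python
import numpy as np

e = 0
while True:                                    # train until solved
    e += 1
    state, _ = env.reset()
    total_reward = 0
    done = False
    truncated = False

    while not (done or truncated):             # one episode
        action = agent.act(state)                                   # epsilon-greedy action
        next_state, reward, done, truncated, _ = env.step(action)   # interact with the env
        agent.remember(state, action, reward, next_state, done)     # store the transition
        agent.replay(batch_size)                                    # learn from a mini-batch

        step_count += 1
        if step_count % target_update_freq == 0:
            agent.update_target_model()                             # theta^- <- theta

        state = next_state
        total_reward += reward

    # Decay exploration after each episode (placement is an assumption;
    # some implementations decay per environment step instead)
    agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)

    scores.append(total_reward)
    avg_score = np.mean(scores[-100:])         # moving average over the last 100 episodes (assumed)
    if avg_score >= target_avg_score:
        print(f"Solved in {e} episodes!")
        agent.save("cartpole_dqn.pth")
        break
```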
### 3. Summary of Code-to-Algorithm Mapping

| Algorithm Step | Code in `train_agent()` |
|---|---|
| Initialize environment & agent | `env = gym.make('CartPole-v1')`, `agent = Agent(state_size, action_size)` |
| Initialize replay buffer | Done inside the `Agent` class (`self.memory = deque(maxlen=100000)`) |
| Training loop (for each episode) | `while True:` (infinite loop until solved) |
| Reset environment at episode start | `state, _ = env.reset()` |
| Select action using ε-greedy policy | `action = agent.act(state)` |
| Execute action & observe reward | `next_state, reward, done, truncated, _ = env.step(action)` |
| Store transition in experience buffer | `agent.remember(state, action, reward, next_state, done)` |
| Sample batch & train | `agent.replay(batch_size)` |
| Update target network periodically | `if step_count % target_update_freq == 0: agent.update_target_model()` |
| Decay exploration rate (ε) | `agent.epsilon = max(agent.epsilon_min, agent.epsilon * agent.epsilon_decay)` |
| Check termination & save model | `if avg_score >= target_avg_score:` |
This structured breakdown provides a textbook-style mapping of `train_agent()` to Deep Q-Learning. 🚀
## Explanation of the `replay()` Function in the `Agent` Class

The `replay()` function in the DQN agent is responsible for training the neural network on experiences stored in the replay buffer. It updates the Q-values using the Bellman equation, which is the foundation of Q-learning.
### 1. The Bellman Equation in DQN

The goal of Q-learning is to update the Q-value estimates so as to minimize the error between the predicted Q-value and the target Q-value. The update follows:

$$
Q(s, a) \leftarrow Q(s, a) + \alpha \left[ R + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]
$$

where:

- $Q(s, a)$ → current Q-value estimate for state $s$ and action $a$
- $\alpha$ → learning rate (handled by the optimizer in DQN)
- $R$ → reward received after taking action $a$ in state $s$
- $\gamma$ → discount factor (how much future rewards matter)
- $\max_{a'} Q(s', a')$ → maximum estimated Q-value for the next state $s'$ (computed with the target model in stable DQN)

In DQN, instead of updating Q-values in a table directly, we train a neural network to approximate $Q(s, a)$.
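For concreteness, using the `agent` created earlier and the module-level `device` assumed above: a CartPole observation is 4 numbers and there are 2 discrete actions, so the network maps $\mathbb{R}^4 \to \mathbb{R}^2$ and $Q(s, a)$ is simply the output entry for the chosen action (an illustration, not code from the original):

```python
import torch

state = torch.tensor([[0.02, -0.15, 0.03, 0.22]])   # shape (1, 4): example observation
with torch.no_grad():
    q_values = agent.model(state.to(device))         # shape (1, 2): one Q-value per action
q_s_a = q_values[0, 1]                                # Q(s, a) for action a = 1 (push cart right)
```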
### 2. Mapping `replay()` Code to the Bellman Equation

Now, let's break down the key steps in the `replay()` function and map them to the terms of the Bellman equation.
#### Step 1: Checking if the Replay Buffer Has Enough Data

```python
if len(self.memory) < batch_size:
    return
```
- Purpose: Ensures that training only starts when the replay buffer has accumulated enough experience.
#### Step 2: Sampling a Mini-Batch from the Replay Buffer

```python
minibatch = random.sample(self.memory, batch_size)
```
- Purpose: Randomly selects a batch of experiences (s, a, r, s', done) from memory to break temporal correlation, which stabilizes learning.
#### Step 3: Converting the Mini-Batch to PyTorch Tensors

```python
states = torch.FloatTensor([t[0] for t in minibatch]).to(device)
actions = torch.LongTensor([t[1] for t in minibatch]).to(device)
rewards = torch.FloatTensor([t[2] for t in minibatch]).to(device)
next_states = torch.FloatTensor([t[3] for t in minibatch]).to(device)
dones = torch.FloatTensor([t[4] for t in minibatch]).to(device)
```

- Mapping to Bellman equation terms:
  - `states` → $s$ (current state)
  - `actions` → $a$ (action taken)
  - `rewards` → $R$ (reward received)
  - `next_states` → $s'$ (next state)
  - `dones` → terminal-state indicator (used to stop counting future rewards when an episode ends)
#### Step 4: Computing Current Q-Values from the Main Model (`self.model`)

```python
current_q = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)
```

- Maps to: $Q(s, a)$ (current Q-value estimate)
- Breakdown:
  - `self.model(states)`: predicts Q-values for all actions in each current state.
  - `.gather(1, actions.unsqueeze(1))`: selects, per sample, the Q-value of the action that was actually taken.
  - `.squeeze(1)`: removes the extra dimension so the result is a 1-D batch of Q-values.
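A tiny example may make the `.gather` call concrete (values are made up):

```python
import torch

q_all = torch.tensor([[1.0, 2.5],      # Q-values for the 2 CartPole actions, batch of 3 states
                      [0.3, -0.1],
                      [4.0, 3.9]])
actions = torch.tensor([1, 0, 0])      # action taken in each transition

# gather(1, ...) picks, per row, the column given by the action index
chosen_q = q_all.gather(1, actions.unsqueeze(1)).squeeze(1)
print(chosen_q)                        # tensor([2.5000, 0.3000, 4.0000])
```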
#### Step 5: Computing Target Q-Values Using the Target Model (`self.target_model`)

```python
next_q = self.target_model(next_states).detach().max(1)[0]
target = rewards + (1 - dones) * self.gamma * next_q
```

- Maps to: $R + \gamma \max_{a'} Q(s', a')$ (target Q-value)
- Breakdown:
  - `self.target_model(next_states)`: predicts Q-values for all actions in the next state $s'$.
  - `.detach()`: ensures that gradients are not propagated through the target computation (stabilizing updates).
  - `.max(1)[0]`: selects the maximum Q-value over all possible actions $a'$ in $s'$.
  - `(1 - dones) * self.gamma * next_q`: zeroes out the future-reward term when the episode has ended (`done = 1`), so only the immediate reward remains.
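A small numeric example of the target computation, with made-up values and $\gamma = 0.99$ assumed (the actual `self.gamma` is not shown in the snippets):

```python
import torch

gamma   = 0.99
rewards = torch.tensor([1.0, 1.0, 1.0])
dones   = torch.tensor([0.0, 0.0, 1.0])     # third transition ends its episode
next_q  = torch.tensor([30.0, 12.0, 50.0])  # max_a' Q_target(s', a') from the target network

target = rewards + (1 - dones) * gamma * next_q
print(target)  # tensor([30.7000, 12.8800,  1.0000]) -- terminal transition keeps only R
```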
#### Step 6: Computing the Loss and Updating the Model

```python
loss = F.mse_loss(current_q, target)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
```

- Maps to: minimizing the loss function

$$
L = \left( Q(s, a) - \left(R + \gamma \max_{a'} Q(s', a')\right) \right)^2
$$

- Breakdown:
  - `F.mse_loss(current_q, target)`: computes the mean squared error (MSE) between the current and target Q-values.
  - `self.optimizer.zero_grad()`: clears the gradients from the previous step.
  - `loss.backward()`: backpropagates to compute gradients for the network parameters.
  - `self.optimizer.step()`: updates the network weights via gradient descent.
### 3. Summary of Mappings

| Bellman Equation Term | Code in `replay()` |
|---|---|
| $s$ (current state) | `states = torch.FloatTensor([t[0] for t in minibatch])` |
| $a$ (action taken) | `actions = torch.LongTensor([t[1] for t in minibatch])` |
| $R$ (reward received) | `rewards = torch.FloatTensor([t[2] for t in minibatch])` |
| $s'$ (next state) | `next_states = torch.FloatTensor([t[3] for t in minibatch])` |
| $Q(s, a)$ (current Q-value) | `current_q = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)` |
| $\max_{a'} Q(s', a')$ (best Q-value in next state) | `next_q = self.target_model(next_states).detach().max(1)[0]` |
| Target Q-value: $R + \gamma \max_{a'} Q(s', a')$ | `target = rewards + (1 - dones) * self.gamma * next_q` |
| Loss function: $(Q(s, a) - \text{target})^2$ | `loss = F.mse_loss(current_q, target)` |
| Gradient descent update | `self.optimizer.zero_grad(); loss.backward(); self.optimizer.step()` |
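For reference, the snippets above assemble into a `replay()` method along the following lines. This is a reconstruction from the fragments shown (it assumes a module-level `device`), not necessarily the exact original code:

```python
import random
import torch
import torch.nn.functional as F

def replay(self, batch_size):
    # 1. Wait until the buffer holds at least one mini-batch
    if len(self.memory) < batch_size:
        return

    # 2. Sample a random mini-batch of transitions (s, a, R, s', done)
    minibatch = random.sample(self.memory, batch_size)

    # 3. Convert the batch to tensors
    states      = torch.FloatTensor([t[0] for t in minibatch]).to(device)
    actions     = torch.LongTensor([t[1] for t in minibatch]).to(device)
    rewards     = torch.FloatTensor([t[2] for t in minibatch]).to(device)
    next_states = torch.FloatTensor([t[3] for t in minibatch]).to(device)
    dones       = torch.FloatTensor([t[4] for t in minibatch]).to(device)

    # 4. Q(s, a) from the main network
    current_q = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # 5. R + gamma * max_a' Q_target(s', a') from the target network
    next_q = self.target_model(next_states).detach().max(1)[0]
    target = rewards + (1 - dones) * self.gamma * next_q

    # 6. Minimize the squared error and update theta
    loss = F.mse_loss(current_q, target)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
```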
### 4. Key Takeaways

- The main Q-network (`self.model`) predicts the current Q-values.
- The target Q-network (`self.target_model`) provides stable target Q-values.
- The loss function drives the neural network to minimize the difference between current and target Q-values.
- Experience replay breaks correlations in the training data, stabilizing learning.
- The Bellman equation is implemented implicitly through these calculations.
This structured mapping should help you understand how the `replay()` function implements Q-learning in DQN. 🚀