# cart pole:agent
## The Role of `self.model` and `self.target_model` in DQN
The `Agent` in the code has two models, `self.model` and `self.target_model`, to address a critical stability issue in Deep Q-Networks (DQNs): instability due to moving targets. Let's break down why this is necessary.
## `self.model` and `self.target_model` in the `Agent` Class
| Aspect | `self.model` | `self.target_model` |
|---|---|---|
| Purpose | Main Q-network; learns from experience and is updated via gradient descent. | Target Q-network; provides stable Q-value estimates to prevent instability. |
| Where used in code | `agent.act(state)`: selects actions during exploration and exploitation.<br>`self.model(states).gather(...)`: computes the current Q-values during training.<br>`self.optimizer.step()`: updated in every training step via backpropagation. | `self.target_model(next_states).detach().max(1)[0]`: computes the target Q-values.<br>`agent.update_target_model()`: synchronizes its weights with `self.model` every 100 steps. |
| Update frequency | Updated in every training step. | Updated every 100 steps (or periodically). |
| Impact on training | Adapts quickly to new experiences, but causes instability if also used for targets. | Provides stable target values, reducing oscillations and improving convergence. |
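
To make the table concrete, here is a minimal sketch of how the two networks might be set up inside the `Agent` class. The `QNetwork` architecture, hidden size, and learning rate are assumptions for illustration, not the wiki's exact code; only the two-network layout and `update_target_model()` mirror the table above.

```python
import torch.nn as nn
import torch.optim as optim

class QNetwork(nn.Module):
    """Small MLP mapping a state vector to one Q-value per action (assumed architecture)."""
    def __init__(self, state_size, action_size, hidden_size=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_size, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, action_size),
        )

    def forward(self, x):
        return self.net(x)

class Agent:
    def __init__(self, state_size, action_size, lr=1e-3):
        self.action_size = action_size
        # Main Q-network: trained every step via gradient descent.
        self.model = QNetwork(state_size, action_size)
        # Target Q-network: a copy used only to compute target Q-values.
        self.target_model = QNetwork(state_size, action_size)
        self.update_target_model()  # start both networks with identical weights
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

    def update_target_model(self):
        # Periodic hard sync: copy the main network's weights into the target network.
        self.target_model.load_state_dict(self.model.state_dict())
```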
## The Problem: Unstable Targets in Q-Learning
In standard Q-learning, the update rule for the Q-value of a state-action pair $(s, a)$ is based on the Bellman equation:

$$
Q(s, a) \leftarrow Q(s, a) + \alpha \left[ R + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]
$$

Where:

- $Q(s, a)$ is the current Q-value estimate for state $s$ and action $a$.
- $\alpha$ is the learning rate.
- $R$ is the immediate reward.
- $\gamma$ is the discount factor.
- $s'$ is the next state.
- $\max_{a'} Q(s', a')$ is the maximum Q-value among all possible actions $a'$ in the next state $s'$. This is the target value that we want to move our current $Q(s, a)$ towards.
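
As a quick sanity check of the update rule, here is a tiny worked example with made-up numbers (the values are illustrative only):

```python
# Illustrative values for one tabular Q-learning update.
q_sa       = 1.0   # Q(s, a), current estimate
alpha      = 0.1   # learning rate
reward     = 1.0   # R
gamma      = 0.99  # discount factor
max_q_next = 2.0   # max_a' Q(s', a')

# TD target and the updated estimate from the update rule above.
td_target = reward + gamma * max_q_next        # 1 + 0.99 * 2.0 = 2.98
q_sa_new  = q_sa + alpha * (td_target - q_sa)  # 1.0 + 0.1 * 1.98 ≈ 1.198
print(q_sa_new)
```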
When we use a neural network to approximate the Q-function (as in DQN), the Q-values $Q(s, a)$ are no longer stored in a table but are outputs of the neural network. If we were to directly use the same network to both:

- estimate the current Q-values $Q(s, a)$, and
- generate the target Q-values $\max_{a'} Q(s', a')$

...during training, we run into a problem: the target values are constantly shifting, because the same network is being updated in each training step.
Imagine trying to hit a target that moves in response to your own movements: it becomes very difficult to learn and stabilize. In DQN, this leads to oscillations and instability during training, because the network is trying to adjust its predictions towards a target that is itself changing.
## The Solution: Target Network
To stabilize training, DQN introduces a target network (`self.target_model`). This is a separate neural network that is a copy of the main Q-network (`self.model`), but it is not updated as frequently as the main network.
### How It Works
#### Main Network (`self.model`)

This is the network that is trained in each step using gradient descent. It is used to:

- Select actions during exploration and exploitation (`agent.act(state)`); see the sketch after this list.
- Estimate the Q-values for the current state in the loss calculation (`self.model(states).gather(...)`).
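
For example, an epsilon-greedy `act` method on the `Agent` class might look like the sketch below. The epsilon handling and tensor shapes are assumptions about the surrounding code; the key point is that action selection queries the main network only.

```python
import random
import torch

def act(self, state, epsilon=0.1):
    # Explore: with probability epsilon, pick a random action.
    if random.random() < epsilon:
        return random.randrange(self.action_size)
    # Exploit: pick the action with the highest Q-value from the main network.
    state_t = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        q_values = self.model(state_t)
    return int(q_values.argmax(dim=1).item())
```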
#### Target Network (`self.target_model`)

This network is used to generate the target Q-values in the DQN update rule:

`self.target_model(next_states).detach().max(1)[0]`
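
Putting the two networks together, a single training step over a replay batch might look like the following sketch. The batch tensors (`actions` as a `LongTensor`, `dones` as 0/1 floats), the MSE loss, and the `step_count` bookkeeping are assumptions about the surrounding code; the `gather(...)`, `detach().max(1)[0]`, and 100-step sync match the pieces described above.

```python
import torch.nn.functional as F

def train_step(self, states, actions, rewards, next_states, dones, gamma=0.99):
    # Current Q-values from the main network for the actions actually taken.
    q_current = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # Target Q-values from the target network; detach() keeps gradients
    # from flowing into the target network.
    q_next = self.target_model(next_states).detach().max(1)[0]
    q_target = rewards + gamma * q_next * (1 - dones)

    # Gradient descent step updates the main network only.
    loss = F.mse_loss(q_current, q_target)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    # Periodically sync the target network with the main network.
    self.step_count = getattr(self, "step_count", 0) + 1
    if self.step_count % 100 == 0:
        self.update_target_model()
```

Because the target network changes only at these periodic syncs, the targets stay fixed between syncs, which is exactly what removes the "moving target" problem described above.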