cart pole:agent - chunhualiao/public-docs GitHub Wiki


The Role of self.model and self.target_model in DQN

The Agent class in the code keeps two networks, self.model and self.target_model. They address a critical problem in Deep Q-Networks (DQNs): training instability caused by moving targets. Let's break down why this is necessary.

Comparison of self.model and self.target_model in Agent Class

| Aspect | self.model | self.target_model |
|---|---|---|
| Purpose | Main Q-network; learns from experience and is updated via gradient descent. | Target Q-network; provides stable Q-value estimates to prevent instability. |
| Where used in code | agent.act(state): selects actions during exploration and exploitation; self.model(states).gather(...): computes the current Q-values during training; self.optimizer.step(): updates its weights in every training step via backpropagation. | self.target_model(next_states).detach().max(1)[0]: computes the target Q-values; agent.update_target_model(): synchronizes its weights with self.model every 100 steps. |
| Update frequency | Every training step. | Every 100 steps (or another fixed period). |
| Impact on training | Adapts quickly to new experiences, but would cause instability if it were also used to compute targets. | Provides stable target values, reducing oscillations and improving convergence. |

The Problem: Unstable Targets in Q-Learning

In standard Q-learning, the update rule for the Q-value of a state-action pair $(s, a)$ is based on the Bellman equation:

$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ R + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$

Where:

  • $Q(s, a)$ is the current Q-value estimate for state $s$ and action $a$.
  • $\alpha$ is the learning rate.
  • $R$ is the immediate reward.
  • $\gamma$ is the discount factor.
  • $s'$ is the next state.
  • $\max_{a'} Q(s', a')$ is the highest Q-value over all possible actions $a'$ in the next state $s'$.

The sum $R + \gamma \max_{a'} Q(s', a')$ is the target value (often called the TD target) that the current $Q(s, a)$ is moved toward; the bracketed difference is the TD error, and $\alpha$ controls what fraction of it is applied in each update.
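The update rule can be traced with a small tabular example. This is a minimal sketch assuming a dictionary Q-table keyed by (state, action); the state names, action names and numbers are invented for illustration, with alpha = 0.1 and gamma = 0.9.

```python
def q_learning_update(q, s, a, r, s_next, actions, alpha=0.1, gamma=0.9):
    """Apply one tabular Q-learning update to q[(s, a)] in place and return it."""
    best_next = max(q[(s_next, a2)] for a2 in actions)  # max_a' Q(s', a')
    td_target = r + gamma * best_next                   # R + gamma * max_a' Q(s', a')
    td_error = td_target - q[(s, a)]                    # how far the estimate is from the target
    q[(s, a)] += alpha * td_error                       # move a fraction alpha toward the target
    return q[(s, a)]


actions = ["left", "right"]
q = {("s0", "left"): 1.0, ("s0", "right"): 0.5,
     ("s1", "left"): 2.0, ("s1", "right"): 0.0}

new_value = q_learning_update(q, "s0", "left", r=1.0, s_next="s1", actions=actions)
# target = 1.0 + 0.9 * 2.0 = 2.8, TD error = 1.8, so Q(s0, left) moves from 1.0 to about 1.18
print(new_value)
```

Only the visited entry changes; a table has no shared parameters, so the entries for s1 that define the target stay exactly where they were.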

When we use a neural network to approximate the Q-function (as in DQN), the Q-values $Q(s, a)$ are no longer stored in a table; they are outputs of the network. If we were to use the same network directly to both:

  1. Estimate the current Q-values $Q(s, a)$
  2. Generate the next-state values $\max_{a'} Q(s', a')$ that enter the target

...during training, we would run into a problem: the target values shift constantly, because every training step updates the very network that produces them.

Imagine trying to hit a target that moves every time you move: it becomes very difficult to learn and stabilize. In DQN, this leads to oscillations and instability in training, because the network keeps adjusting its predictions toward a target that is itself changing.
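The effect can be reproduced with a toy function approximator. In this sketch (all names and numbers are hypothetical), Q(s, a) = w[a] * s shares its weights across states, so a gradient step that raises Q(s, a) also raises Q(s', a) and therefore the target. The same transition is replayed five times: the target computed from the network being trained keeps rising, while the target computed from a frozen copy stays the same.

```python
def q_values(w, s):
    # Hypothetical linear Q-function with one feature: Q(s, a) = w[a] * s
    return [wa * s for wa in w]


def td_target(w, r, s_next, gamma=0.99):
    # R + gamma * max_a' Q(s', a'), computed with whichever weights are passed in
    return r + gamma * max(q_values(w, s_next))


def sgd_step(w, s, a, target, lr=0.1):
    # One gradient step on 0.5 * (Q(s, a) - target)^2 with respect to w[a]
    w = list(w)
    w[a] -= lr * (w[a] * s - target) * s
    return w


s, a, r, s_next = 1.0, 0, 1.0, 1.1  # one made-up transition, replayed five times
w_single = [0.5, 0.2]               # one network computes both Q(s, a) and the target
w_online = [0.5, 0.2]               # main network ...
w_frozen = list(w_online)           # ... and a frozen copy used only for targets

single_targets, frozen_targets = [], []
for _ in range(5):
    t = td_target(w_single, r, s_next)  # target moves because w_single just changed
    single_targets.append(t)
    w_single = sgd_step(w_single, s, a, t)

    t = td_target(w_frozen, r, s_next)  # target fixed because w_frozen never changes
    frozen_targets.append(t)
    w_online = sgd_step(w_online, s, a, t)

print(single_targets)  # strictly increasing
print(frozen_targets)  # the same value five times
```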

The Solution: Target Network

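To stabilize training, DQN introduces a target network (self.target_model). This is a separate neural network with the same architecture as the main Q-network (self.model). It is not trained by gradient descent; instead, its weights are copied from the main network only periodically, so the targets it produces stay fixed between copies.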
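The relationship between the two networks can be sketched without a deep-learning framework. This is a minimal sketch, not the wiki's actual Agent: it assumes a hypothetical linear Q-function (LinearQ) in place of the neural network and a hypothetical after_train_step hook that counts training steps. Only the names model, target_model, act and update_target_model, and the 100-step interval, come from the text above.

```python
import copy
import random


class LinearQ:
    """Hypothetical stand-in for the Q-network: Q(s, a) = w[a] . s."""

    def __init__(self, state_dim, n_actions, seed=0):
        rng = random.Random(seed)
        # One weight vector per action, initialised with small random values.
        self.w = [[rng.uniform(-0.1, 0.1) for _ in range(state_dim)]
                  for _ in range(n_actions)]

    def __call__(self, state):
        # One Q-value per action for a single state.
        return [sum(wi * si for wi, si in zip(row, state)) for row in self.w]


class Agent:
    def __init__(self, state_dim, n_actions, epsilon=0.1, sync_every=100):
        self.model = LinearQ(state_dim, n_actions)     # trained every step
        self.target_model = copy.deepcopy(self.model)  # separate copy, synced periodically
        self.n_actions = n_actions
        self.epsilon = epsilon
        self.sync_every = sync_every
        self.steps = 0

    def act(self, state):
        # Epsilon-greedy selection uses the main network only.
        if random.random() < self.epsilon:
            return random.randrange(self.n_actions)
        q = self.model(state)
        return max(range(self.n_actions), key=lambda a: q[a])

    def update_target_model(self):
        # Hard update: overwrite the target weights with a copy of the main weights.
        self.target_model.w = copy.deepcopy(self.model.w)

    def after_train_step(self):
        # Hypothetical hook: call once after each gradient step on self.model.
        self.steps += 1
        if self.steps % self.sync_every == 0:
            self.update_target_model()


# Usage: the target network ignores changes to the main network until a sync.
agent = Agent(state_dim=4, n_actions=2, epsilon=0.0)
agent.model.w[0][0] += 1.0  # pretend a gradient step changed the main network
for _ in range(99):
    agent.after_train_step()
print(agent.target_model.w == agent.model.w)  # False: not synced yet
agent.after_train_step()                      # the 100th step triggers the copy
print(agent.target_model.w == agent.model.w)  # True
```

In PyTorch the hard copy is commonly written as self.target_model.load_state_dict(self.model.state_dict()); copy.deepcopy plays that role in this sketch.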

How It Works:

Main Network (self.model)

This is the network that is trained in each step using gradient descent. It's used to:

  • Select actions during exploration and exploitation (agent.act(state)).
  • Estimate the Q-values for the current state in the loss calculation (self.model(states).gather(...)).
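The loss needs, for each transition in a replay batch, the main network's Q-value of the action that was actually taken. The exact arguments to gather are elided in the text, so the following is a framework-free analogue rather than the original code: batch_q stands for the model's output (one row per state, one column per action), and the helper name and numbers are made up.

```python
def select_q_values(batch_q, actions):
    """For each row of Q-values, pick the entry of the action actually taken.

    Framework-free analogue of gathering along the action dimension.
    """
    return [row[a] for row, a in zip(batch_q, actions)]


batch_q = [[0.2, 0.8], [1.5, -0.3], [0.0, 0.4]]  # model output: 3 states x 2 actions
actions = [1, 0, 0]                              # actions taken in the replay batch
current_q = select_q_values(batch_q, actions)
print(current_q)  # [0.8, 1.5, 0.0]
```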

Target Network (self.target_model)

This network is used to generate the target Q-values in the DQN update rule:

self.target_model(next_states).detach().max(1)[0]
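The line above takes, for each next state, the largest Q-value along the action dimension (.max(1) returns values and indices, and [0] keeps the values), and .detach() stops gradients from flowing into the target network. A framework-free sketch of the resulting targets and a loss follows; the batch values, gamma = 0.9, the choice of mean-squared error, and the helper names td_targets and mse are assumptions for illustration, and terminal-state handling is omitted.

```python
def td_targets(rewards, next_q_target, gamma=0.99):
    """R + gamma * max_a' Q_target(s', a') for each transition in a batch.

    next_q_target holds the target network's outputs for the next states.
    They are plain numbers here, so nothing flows back into the target
    network, which is the role .detach() plays in the PyTorch line.
    """
    return [r + gamma * max(row) for r, row in zip(rewards, next_q_target)]


def mse(pred, target):
    # Mean squared error between current Q-values and targets.
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


rewards = [1.0, 1.0, 0.0]
next_q_target = [[0.5, 1.0], [2.0, 0.0], [0.3, 0.1]]      # target network on next states
targets = td_targets(rewards, next_q_target, gamma=0.9)  # about [1.9, 2.8, 0.27]
current_q = [0.8, 1.5, 0.0]                              # main network, actions taken
loss = mse(current_q, targets)                           # about 0.991
print(targets, loss)
```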