
veRL

Training Algorithm Overview

What is the Training Algorithm?

The training algorithm used is Proximal Policy Optimization (PPO). PPO is a reinforcement learning algorithm that improves the policy by taking small, controlled steps to increase the expected reward. It ensures that the new policy doesn't deviate too far from the old policy, which helps stabilize training.

Where is it Implemented?

The core PPO implementation is found in:

  • verl/trainer/ppo/ray_trainer.py (RayPPOTrainer class, fit() method)
  • verl/trainer/ppo/core_algos.py (helper functions for PPO components)

These files contain the key logic for policy optimization, value estimation, advantage calculation, and updates.

PPO Algorithm

1. Initialization

  • Initialize policy network (actor) and value network (critic).
  • (Optional) Initialize a frozen reference policy, used for the KL penalty that keeps the actor from drifting too far from its starting point.

2. Data Collection (Rollout)

  • Use the policy to interact with the environment and collect experiences:
    • Prompts
    • Generated responses
    • Log probabilities of actions under the current policy
    • Log probabilities under the reference policy (if used)
    • Rewards (or scores) for responses
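
As a concrete illustration, the per-token log probabilities listed above can be gathered from the model's logits roughly as follows (a minimal PyTorch sketch; tensor names are illustrative, not verl's API):

import torch
import torch.nn.functional as F

def gather_token_log_probs(logits, token_ids):
    # logits: (batch, seq_len, vocab_size); token_ids: (batch, seq_len)
    log_probs = F.log_softmax(logits, dim=-1)
    # pick out the log probability assigned to each generated token
    return torch.gather(log_probs, dim=-1, index=token_ids.unsqueeze(-1)).squeeze(-1)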

3. Advantage Estimation

  • Compute the advantage function for each state-action pair:
    • Generalized Advantage Estimation (GAE), sketched below
    • GRPO (Group Relative Policy Optimization), which normalizes outcome rewards within a group of responses to the same prompt instead of relying on a learned value function
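
A minimal GAE sketch follows; it illustrates the recursion only, omitting the response masking and advantage normalization that a full implementation such as compute_gae_advantage_return would need:

import torch

def compute_gae(rewards, values, gamma=1.0, lam=0.95):
    # rewards, values: (batch, seq_len) token-level tensors
    advantages = torch.zeros_like(rewards)
    last_adv = torch.zeros_like(rewards[:, 0])
    for t in reversed(range(rewards.size(1))):
        # value of the next token position, zero past the end of the sequence
        next_value = values[:, t + 1] if t + 1 < rewards.size(1) else torch.zeros_like(values[:, t])
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        last_adv = delta + gamma * lam * last_adv
        advantages[:, t] = last_adv
    # returns serve as regression targets for the critic
    returns = advantages + values
    return advantages, returns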

4. Policy Update (Actor Update)

  • Optimize the PPO-clip objective function:

    L = \min \left( \text{ratio} \cdot A, \text{clip}(\text{ratio}, 1 - \epsilon, 1 + \epsilon) \cdot A \right)
    

    where:

    • ratio = exp(log_prob_new - log_prob_old)
    • A = advantage estimate
    • ε (clip epsilon) controls the clipping range (e.g., 0.2)
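
In code, the clipped surrogate looks roughly like this (a minimal sketch; response-token masking is omitted and the function name is illustrative, not verl's compute_policy_loss):

import torch

def ppo_clip_loss(log_prob_new, log_prob_old, advantages, clip_eps=0.2):
    # probability ratio between the new and old policy, computed in log space
    ratio = torch.exp(log_prob_new - log_prob_old)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # PPO maximizes the clipped surrogate, so the loss is its negation
    return -torch.min(surr1, surr2).mean()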

5. Value Function Update (Critic Update)

  • Optimize clipped value loss:

    V_{\text{loss}} = \max \left( (V_{\text{new}} - V_{\text{target}})^2, \left( \text{clip}(V_{\text{new}}, V_{\text{old}} - \epsilon, V_{\text{old}} + \epsilon) - V_{\text{target}} \right)^2 \right)
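
A minimal sketch of this clipped value loss (token masking omitted; names are illustrative rather than verl's compute_value_loss):

import torch

def clipped_value_loss(v_new, v_old, v_target, clip_eps=0.2):
    # squared error of the new value prediction against the return target
    loss_unclipped = (v_new - v_target) ** 2
    # keep the new prediction within clip_eps of the old prediction
    v_clipped = v_old + torch.clamp(v_new - v_old, -clip_eps, clip_eps)
    loss_clipped = (v_clipped - v_target) ** 2
    # element-wise maximum of the two errors, averaged over tokens
    return torch.max(loss_unclipped, loss_clipped).mean()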
    

6. KL Divergence Control (Optional)

  • Add a KL divergence penalty to constrain policy updates.
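
A common form of this penalty subtracts a scaled per-token KL estimate from the reward before advantages are computed. The sketch below uses the simple log-ratio estimator; the coefficient and estimator choice are illustrative, not necessarily what apply_kl_penalty uses:

import torch

def kl_penalized_rewards(token_scores, log_prob, ref_log_prob, kl_coef=0.001):
    # per-token KL estimate between the current policy and the reference policy
    kl = log_prob - ref_log_prob
    # tokens that drift away from the reference policy are penalized
    return token_scores - kl_coef * kl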

7. Repeat

  • Repeat steps 2-6 until convergence.

Code Implementation

Core PPO logic in ray_trainer.py (RayPPOTrainer class)

The fit() method contains the main training loop.

1. Data Collection

gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
batch = batch.union(gen_batch_output)
  • Uses actor_rollout_wg to generate sequences (responses) and combines them with the batch.

2. Reference Policy Log Probabilities (Optional)

if self.use_reference_policy:
    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
    batch = batch.union(ref_log_prob)
  • Computes log probabilities under a reference policy.

3. Value Estimation

if self.use_critic:
    values = self.critic_wg.compute_values(batch)
    batch = batch.union(values)
  • Uses the critic network to compute value estimates.

4. Reward and Advantage Calculation

reward_tensor = self.reward_fn(batch)
batch.batch['token_level_scores'] = reward_tensor

if not self.config.actor_rollout_ref.actor.use_kl_loss:
    batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl, kl_penalty=self.config.algorithm.kl_penalty)
    metrics.update(kl_metrics)
else:
    batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

batch = compute_advantage(batch,
                          adv_estimator=self.config.algorithm.adv_estimator,
                          gamma=self.config.algorithm.gamma,
                          lam=self.config.algorithm.lam,
                          num_repeat=self.config.actor_rollout_ref.rollout.n)
  • Computes rewards
  • Applies KL penalty (if applicable)
  • Computes advantages using compute_advantage

5. Critic Update

if self.use_critic:
    critic_output = self.critic_wg.update_critic(batch)
    critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
    metrics.update(critic_output_metrics)
  • Updates the critic network using the CriticWorker.

6. Actor Update

if self.config.trainer.critic_warmup <= self.global_steps:
    actor_output = self.actor_rollout_wg.update_actor(batch)
    actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
    metrics.update(actor_output_metrics)
  • Updates the actor network using ActorRolloutRefWorker.

7. KL Divergence Control

  • apply_kl_penalty() and kl_ctrl handle KL divergence control.

Helper Functions in core_algos.py

This file contains:

  • compute_gae_advantage_return: GAE-based advantage estimation
  • compute_grpo_outcome_advantage: GRPO advantage calculation
  • compute_policy_loss: PPO-clip loss function
  • compute_value_loss: Value function loss
  • compute_entropy_loss: Optional entropy regularization
  • AdaptiveKLController, FixedKLController: KL penalty management
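
For reference, the adaptive-controller pattern (popularized by Ziegler et al. for RLHF) nudges the KL coefficient toward a target KL roughly as follows; this is a sketch of the general scheme, not necessarily identical to verl's AdaptiveKLController:

class AdaptiveKLControllerSketch:
    # Raise the KL coefficient when the observed KL overshoots the target,
    # lower it when the KL undershoots (sketch of the standard scheme).
    def __init__(self, init_kl_coef, target_kl, horizon):
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        # proportional error, clipped to avoid large jumps in the coefficient
        proportional_error = max(min(current_kl / self.target_kl - 1.0, 0.2), -0.2)
        self.kl_coef *= 1.0 + proportional_error * n_steps / self.horizon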

Mapping the Code to the Algorithm

The PPO algorithm steps map directly to the code:

| Algorithm Step | Code Location |
| --- | --- |
| Initialization | RayPPOTrainer.__init__() |
| Data Collection | self.actor_rollout_wg.generate_sequences() |
| Reference Policy Log Probabilities | self.ref_policy_wg.compute_ref_log_prob() |
| Value Estimation | self.critic_wg.compute_values() |
| Reward Calculation | self.reward_fn(batch) |
| Advantage Computation | compute_advantage(batch) |
| Critic Update | self.critic_wg.update_critic() |
| Actor Update | self.actor_rollout_wg.update_actor() |
| KL Penalty | apply_kl_penalty(batch) |

Conclusion

The RayPPOTrainer class orchestrates PPO training, leveraging helper functions from core_algos.py for advantage estimation, loss computation, and KL control. The configuration object (self.config) supplies the hyperparameters, such as the advantage estimator, gamma, lam, and the KL settings, that govern how the training loop executes.
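
For illustration, a subset of the config keys referenced in the code excerpts above might look like this (assuming an OmegaConf-style config as used with Hydra; the values are placeholders, not verl defaults):

from omegaconf import OmegaConf

config = OmegaConf.create({
    "algorithm": {
        "adv_estimator": "gae",   # or a GRPO-style estimator
        "gamma": 1.0,
        "lam": 0.95,
        "kl_penalty": "kl",
    },
    "actor_rollout_ref": {
        "actor": {"use_kl_loss": False},
        "rollout": {"n": 1},
    },
    "trainer": {"critic_warmup": 0},
})

# accessed in the trainer as, e.g., config.algorithm.adv_estimator
print(config.algorithm.gamma)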

This structured approach, small clipped policy updates combined with an optional KL constraint against a reference policy, is what keeps policy learning stable while still improving reward.