veRL:training - chunhualiao/public-docs GitHub Wiki
Training Algorithm Overview
What is the Training Algorithm?
The training algorithm used is Proximal Policy Optimization (PPO). PPO is a reinforcement learning algorithm that improves the policy by taking small, controlled steps to increase the expected reward. It ensures that the new policy doesn't deviate too far from the old policy, which helps stabilize training.
Where is it Implemented?
The core PPO implementation is found in:
- verl/trainer/ppo/ray_trainer.py (RayPPOTrainer class, fit() method)
- verl/trainer/ppo/core_algos.py (helper functions for PPO components)
These files contain the key logic for policy optimization, value estimation, advantage calculation, and updates.
PPO Algorithm
1. Initialization
- Initialize policy network (actor) and value network (critic).
- (Optional) Initialize reference policy for stability.
2. Data Collection (Rollout)
- Use the policy to interact with the environment and collect experiences:
- Prompts
- Generated responses
- Log probabilities of actions under the current policy
- Log probabilities under the reference policy (if used)
- Rewards (or scores) for responses
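For illustration, a single rollout step might yield data shaped like the following. This is a hypothetical, simplified view; the actual trainer carries these tensors inside veRL's DataProto structure together with masks and metadata.

```python
import torch

# Hypothetical, simplified rollout batch (illustrative shapes and names only).
batch_size, prompt_len, response_len, vocab_size = 4, 16, 32, 32000

rollout_batch = {
    # token ids of the prompts and the sampled responses
    "prompts":       torch.randint(0, vocab_size, (batch_size, prompt_len)),
    "responses":     torch.randint(0, vocab_size, (batch_size, response_len)),
    # log-probabilities of each response token under the current (old) policy
    "old_log_probs": torch.randn(batch_size, response_len),
    # log-probabilities under the frozen reference policy (optional)
    "ref_log_probs": torch.randn(batch_size, response_len),
    # one scalar score per response, e.g. from a reward model or rule-based checker
    "scores":        torch.rand(batch_size),
}
```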
3. Advantage Estimation
- Compute the advantage function for each state-action pair:
- Generalized Advantage Estimation (GAE), sketched below
- GRPO (Group Relative Policy Optimization), which computes outcome-based advantages relative to a group of responses sampled for the same prompt
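A minimal GAE sketch over response tokens (illustrative only, not the exact compute_gae_advantage_return from core_algos.py, which also handles the response mask and advantage whitening):

```python
import torch

def gae_advantages(rewards, values, gamma=1.0, lam=0.95):
    """Minimal GAE sketch. rewards and values are [batch, seq_len] tensors
    over response tokens."""
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    seq_len = rewards.shape[1]
    # walk backwards through the sequence, accumulating discounted TD errors
    for t in reversed(range(seq_len)):
        next_value = values[:, t + 1] if t + 1 < seq_len else 0.0
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        last_gae = delta + gamma * lam * last_gae
        advantages[:, t] = last_gae
    returns = advantages + values   # targets for the critic update
    return advantages, returns
```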
4. Policy Update (Actor Update)
- Optimize the PPO-clip objective function:
L = \min \left( \text{ratio} \cdot A, \; \text{clip}(\text{ratio}, 1 - \epsilon, 1 + \epsilon) \cdot A \right)
where:
- ratio = exp(log_prob_new - log_prob_old)
- A = advantage estimate
- epsilon (clip_epsilon) controls the clipping range (e.g., 0.2)
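A minimal sketch of this loss in PyTorch (illustrative only, not the exact compute_policy_loss from core_algos.py, which also handles token-level masking):

```python
import torch

def ppo_clip_loss(log_prob_new, log_prob_old, advantages, clip_epsilon=0.2):
    """Per-token PPO-clip loss sketch (no masking)."""
    ratio = torch.exp(log_prob_new - log_prob_old)
    surrogate1 = ratio * advantages
    surrogate2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages
    # PPO maximizes min(surrogate1, surrogate2); minimizing the negative is equivalent.
    return -torch.mean(torch.min(surrogate1, surrogate2))
```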
5. Value Function Update (Critic Update)
- Optimize the clipped value loss:
V_{\text{loss}} = \max \left( (V_{\text{new}} - V_{\text{target}})^2, \left( \text{clip}(V_{\text{new}}, V_{\text{old}} - \epsilon, V_{\text{old}} + \epsilon) - V_{\text{target}} \right)^2 \right)
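A minimal sketch of this clipped value loss (illustrative only; the in-repo compute_value_loss operates on masked token-level tensors):

```python
import torch

def clipped_value_loss(v_new, v_old, v_target, clip_epsilon=0.2):
    """Clipped value loss sketch (no masking)."""
    # keep the new value estimate within +/- epsilon of the old one
    v_clipped = v_old + torch.clamp(v_new - v_old, -clip_epsilon, clip_epsilon)
    loss_unclipped = (v_new - v_target) ** 2
    loss_clipped = (v_clipped - v_target) ** 2
    # take the pessimistic (larger) loss per element, then average
    return torch.mean(torch.max(loss_unclipped, loss_clipped))
```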
6. KL Divergence Control (Optional)
- Add a KL divergence penalty to constrain policy updates.
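One common form of this penalty subtracts a per-token KL estimate from the token-level rewards. A hedged sketch with illustrative names (veRL's apply_kl_penalty works on the batch structure shown later and supports different KL estimators):

```python
import torch

def kl_penalized_rewards(token_scores, log_probs, ref_log_probs, kl_coef=0.001):
    """Shape rewards by subtracting a per-token KL estimate, scaled by a coefficient."""
    kl = log_probs - ref_log_probs            # simple per-token log-ratio KL estimate
    token_rewards = token_scores - kl_coef * kl
    return token_rewards, kl.mean()
```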
7. Repeat
- Repeat steps 2-6 until convergence.
Code Implementation
Core PPO Logic in ray_trainer.py (RayPPOTrainer class)
The fit() method contains the main training loop.
1. Data Collection
```python
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
batch = batch.union(gen_batch_output)
```
- Uses actor_rollout_wg to generate sequences (responses) and combines them with the batch.
2. Reference Policy Log Probabilities (Optional)
```python
if self.use_reference_policy:
    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
    batch = batch.union(ref_log_prob)
```
- Computes log probabilities under a reference policy.
3. Value Estimation
```python
if self.use_critic:
    values = self.critic_wg.compute_values(batch)
    batch = batch.union(values)
```
- Uses the critic network to compute value estimates.
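Conceptually, the critic is a transformer backbone with a scalar value head over each token's hidden state. A hypothetical sketch (not veRL's actual critic class):

```python
import torch
import torch.nn as nn

# Hypothetical sketch: a scalar "value head" maps each token's hidden state
# (produced by the critic's transformer backbone) to a per-token value estimate.
batch_size, seq_len, hidden_size = 2, 6, 8
value_head = nn.Linear(hidden_size, 1)

hidden_states = torch.randn(batch_size, seq_len, hidden_size)  # stand-in for backbone output
values = value_head(hidden_states).squeeze(-1)                 # shape: [batch, seq_len]
```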
4. Reward and Advantage Calculation
```python
reward_tensor = self.reward_fn(batch)
batch.batch['token_level_scores'] = reward_tensor

if not self.config.actor_rollout_ref.actor.use_kl_loss:
    batch, kl_metrics = apply_kl_penalty(batch, kl_ctrl=self.kl_ctrl,
                                         kl_penalty=self.config.algorithm.kl_penalty)
    metrics.update(kl_metrics)
else:
    batch.batch['token_level_rewards'] = batch.batch['token_level_scores']

batch = compute_advantage(batch,
                          adv_estimator=self.config.algorithm.adv_estimator,
                          gamma=self.config.algorithm.gamma,
                          lam=self.config.algorithm.lam,
                          num_repeat=self.config.actor_rollout_ref.rollout.n)
```
- Computes rewards with reward_fn (see the illustration below)
- Applies a KL penalty (if applicable)
- Computes advantages using compute_advantage
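As an illustration of the token-level scores (a simplified, hypothetical example; veRL's reward manager typically places each sequence-level score at the last response token, with all other positions zero):

```python
import torch

# Simplified illustration: only the final token of each response carries the
# sequence-level score; all other positions are zero.
batch_size, response_len = 2, 5
scores = torch.tensor([0.0, 1.0])                     # e.g. 1.0 = answer judged correct
token_level_scores = torch.zeros(batch_size, response_len)
token_level_scores[:, -1] = scores
# token_level_scores:
# [[0., 0., 0., 0., 0.],
#  [0., 0., 0., 0., 1.]]
```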
5. Critic Update
```python
if self.use_critic:
    critic_output = self.critic_wg.update_critic(batch)
    critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
    metrics.update(critic_output_metrics)
```
- Updates the critic network using the CriticWorker.
6. Actor Update
```python
if self.config.trainer.critic_warmup <= self.global_steps:
    actor_output = self.actor_rollout_wg.update_actor(batch)
    actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
    metrics.update(actor_output_metrics)
```
- Updates the actor network using the ActorRolloutRefWorker. The update is skipped until global_steps reaches trainer.critic_warmup, giving the critic time to stabilize before the policy starts changing.
7. KL Divergence Control
- apply_kl_penalty() and kl_ctrl handle KL divergence control.
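For intuition, a hedged sketch of an adaptive KL controller in the spirit of the AdaptiveKLController listed below (following the scheme of Ziegler et al., 2019; the exact in-repo signature may differ):

```python
class AdaptiveKLControllerSketch:
    """Hedged sketch: raise the KL penalty coefficient when the observed KL
    exceeds the target, lower it when the KL falls below the target."""

    def __init__(self, init_kl_coef=0.2, target_kl=6.0, horizon=10000):
        self.value = init_kl_coef      # current KL penalty coefficient
        self.target = target_kl        # desired KL divergence
        self.horizon = horizon         # controls how quickly the coefficient adapts

    def update(self, current_kl, n_steps):
        # proportional error, clipped to [-0.2, 0.2] for stability
        proportional_error = max(min(current_kl / self.target - 1.0, 0.2), -0.2)
        self.value *= 1.0 + proportional_error * n_steps / self.horizon
```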
Helper Functions in core_algos.py
This file contains:
- compute_gae_advantage_return: GAE-based advantage estimation
- compute_grpo_outcome_advantage: GRPO advantage calculation
- compute_policy_loss: PPO-clip loss function
- compute_value_loss: Value function loss
- compute_entropy_loss: Optional entropy regularization
- AdaptiveKLController, FixedKLController: KL penalty management
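A hedged sketch of a GRPO-style outcome advantage (the in-repo compute_grpo_outcome_advantage operates on token-level tensors with masks; this version uses one scalar score per response):

```python
import torch

def grpo_outcome_advantage_sketch(scores, group_size, eps=1e-6):
    """Responses sampled for the same prompt form a group; each response's
    advantage is its score normalized by the group mean and std."""
    grouped = scores.view(-1, group_size)                 # [num_prompts, group_size]
    mean = grouped.mean(dim=1, keepdim=True)
    std = grouped.std(dim=1, keepdim=True)
    advantages = (grouped - mean) / (std + eps)
    return advantages.view(-1)                            # back to [num_responses]

# Example: two prompts, each with 3 sampled responses
scores = torch.tensor([1.0, 0.0, 0.0, 1.0, 1.0, 0.0])
print(grpo_outcome_advantage_sketch(scores, group_size=3))
```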
Mapping the Code to the Algorithm
The PPO algorithm steps map directly to the code:
Algorithm Step | Code Location |
---|---|
Initialization | RayPPOTrainer.__init__() |
Data Collection | self.actor_rollout_wg.generate_sequences() |
Reference Policy Log Probabilities | self.ref_policy_wg.compute_ref_log_prob() |
Value Estimation | self.critic_wg.compute_values() |
Reward Calculation | self.reward_fn(batch) |
Advantage Computation | compute_advantage(batch) |
Critic Update | self.critic_wg.update_critic() |
Actor Update | self.actor_rollout_wg.update_actor() |
KL Penalty | apply_kl_penalty(batch) |
Conclusion
The RayPPOTrainer class orchestrates PPO training, leveraging helper functions from core_algos.py for advantage estimation, loss computation, and KL control. The configuration (config) supplies the hyperparameters that govern how the training loop executes.
This structured approach ensures stability and performance improvements in policy learning.