# TinyZero: training
The training loop implements the PPO algorithm with:
- Experience collection through batch generation.
- Advantage estimation (GAE or GRPO; a GRPO-style advantage sketch follows this list).
- Policy updates with KL divergence control.
- Value function updates.
- Sequence length balancing for efficient parallel training.
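As referenced above, here is a minimal sketch of GRPO-style (group-relative) advantage estimation: each response's scalar reward is normalized against the mean and standard deviation of the other responses sampled for the same prompt. This is a simplified illustration, not veRL's exact code; the function name, the `group_ids` layout, and the assumption of several responses per prompt are mine.

```python
import torch

def grpo_outcome_advantages(rewards, group_ids, eps=1e-6):
    """rewards: [num_responses] scalar rewards; group_ids: [num_responses] prompt index per response.
    Assumes several responses were sampled per prompt."""
    advantages = torch.zeros_like(rewards)
    for gid in group_ids.unique():
        mask = group_ids == gid
        group = rewards[mask]
        # Normalize each response's reward against the other responses to the same prompt
        advantages[mask] = (group - group.mean()) / (group.std() + eps)
    return advantages
```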
## Model Weight Adjustment in veRL During RL Training
Based on the code in `core_algos.py` and `fsdp_workers.py`, model weights in veRL are adjusted through PPO (Proximal Policy Optimization) training. The key components involved in this process are explained below.
### 1. Policy Gradient Loss
The policy gradient loss is computed using the PPO clipped objective to ensure stable updates.
```python
import torch
import verl.utils.torch_functional as verl_F  # masked_mean / clip_by_value helpers

def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange):
    # Importance ratio between the current and old policy, computed in log space
    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)
    # PPO clipped surrogate objective
    pg_losses = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    # Take the element-wise max to implement PPO's pessimistic bound,
    # then average over the response tokens selected by eos_mask
    pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
    return pg_loss
```
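A minimal usage sketch for the function above; the tensor shapes and values here are illustrative, not taken from the veRL code:

```python
batch, resp_len = 2, 4
old_log_prob = torch.randn(batch, resp_len)
log_prob = (old_log_prob + 0.05 * torch.randn(batch, resp_len)).requires_grad_()
advantages = torch.randn(batch, resp_len)
eos_mask = torch.ones(batch, resp_len)  # 1 for valid response tokens, 0 for padding
loss = compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange=0.2)
loss.backward()  # gradients flow back into log_prob (and hence into the policy network)
```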
### 2. Value Function Loss
The value function loss ensures the critic network accurately estimates the expected returns.
```python
def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
    # Clip new value predictions to stay within cliprange_value of the old values
    vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
    vf_losses1 = (vpreds - returns) ** 2
    vf_losses2 = (vpredclipped - returns) ** 2
    # 0.5 * max(...) mirrors the clipped policy objective on the critic side
    vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
    return vf_loss
```
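Clipping the new value predictions to a band of width `cliprange_value` around the old values, and then taking the element-wise maximum of the clipped and unclipped squared errors, keeps the critic from chasing noisy return estimates too far in a single update; the factor 0.5 is the conventional scaling for squared-error losses.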
### 3. KL Penalty
A KL penalty is applied to prevent the policy from deviating too much from the reference policy.
```python
def kl_penalty(logprob, ref_logprob, kl_penalty):
    # Per-token penalty measuring divergence from the frozen reference policy;
    # the third argument selects the penalty variant
    if kl_penalty == "kl":
        return logprob - ref_logprob
    elif kl_penalty == "abs":
        return (logprob - ref_logprob).abs()
    # Other penalty variants in the full implementation are omitted here
    raise NotImplementedError(f"Unknown kl_penalty type: {kl_penalty}")
```
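A sketch of how such a per-token KL penalty is typically folded into the reward before advantage estimation; the names `token_level_scores` and `kl_coef` are illustrative, not taken from the wiki text:

```python
def apply_kl_to_rewards(token_level_scores, logprob, ref_logprob, kl_coef=0.001):
    # Penalize each generated token in proportion to its divergence from the reference policy
    kld = kl_penalty(logprob, ref_logprob, "kl")
    return token_level_scores - kl_coef * kld
```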
### 4. Weight Updates

The `ActorRolloutRefWorker` class handles model weight updates using FSDP (Fully Sharded Data Parallel).
```python
def update_actor(self, data: DataProto):
    # Load FSDP parameters and optimizer state back onto the GPU if they were offloaded
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.actor_module_fsdp)
    if self._is_offload_optimizer:
        load_fsdp_optimizer(optimizer=self.actor_optimizer)
    # Run the PPO update on the actor (policy) model
    metrics = self.actor.update_policy(data=data)
    # Step the learning-rate scheduler once per update
    self.actor_lr_scheduler.step()
    return metrics  # e.g., policy loss, KL, gradient norm, for logging
```
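The offload checks above reflect a memory-saving pattern: when parameter or optimizer offloading is enabled, FSDP state lives on the CPU between updates and is moved back to the GPU only for the duration of the policy update, trading transfer time for GPU memory.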
### 5. Optimization Process
- Uses the AdamW optimizer with configurable learning rate and weight decay (a generic setup is sketched after this list).
- Supports learning rate warmup through a scheduler.
- Implements gradient clipping and optimization in micro-batches.
- Handles distributed training through FSDP.
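A generic PyTorch sketch of the optimizer, warmup scheduler, and gradient-clipping pieces; this is not veRL's exact wiring, the hyperparameter values are illustrative, and micro-batching and FSDP are omitted:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

def build_optimizer(model, lr=1e-6, weight_decay=0.01, warmup_steps=100):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Linear warmup from 0 to the base learning rate over warmup_steps updates
    scheduler = LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))
    return optimizer, scheduler

def optimizer_step(model, optimizer, scheduler, max_grad_norm=1.0):
    # Clip gradients before the update, then advance the warmup schedule
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
```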
### 6. Training Flow
The training process follows these steps:
- Generate sequences using the current policy.
- Compute advantages using GAE (Generalized Advantage Estimation) or GRPO (see the GAE sketch after this list).
- Calculate policy and value losses.
- Apply KL penalty to prevent large policy updates.
- Update model weights through backpropagation.
- Track metrics like policy loss, value loss, and KL divergence.
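A minimal GAE sketch for the advantage-estimation step, assuming per-token rewards and value estimates of shape `[batch, seq_len]`; `gamma` and `lam` are the usual discount and GAE parameters, and this is a simplified illustration rather than veRL's exact implementation:

```python
def compute_gae(rewards, values, gamma=1.0, lam=0.95):
    # Backward recursion: A_t = delta_t + gamma * lam * A_{t+1},
    # with delta_t = r_t + gamma * V_{t+1} - V_t
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros(rewards.size(0))
    for t in reversed(range(rewards.size(1))):
        next_value = values[:, t + 1] if t + 1 < rewards.size(1) else torch.zeros(rewards.size(0))
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        last_gae = delta + gamma * lam * last_gae
        advantages[:, t] = last_gae
    # Returns serve as the critic's regression target
    returns = advantages + values
    return advantages, returns
```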
## Techniques to Stabilize Training
The system employs several techniques to ensure stable and efficient training:
- PPO clipping to limit policy updates.
- Value function clipping.
- KL divergence control (a common adaptive-coefficient scheme is sketched after this list).
- Gradient clipping.
- Learning rate scheduling.
- Distributed training optimization through FSDP.
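The KL coefficient itself is often adapted during training so that the measured KL stays near a target value. Below is a sketch in the style of Ziegler et al.'s adaptive controller; the class name and default values are illustrative, and veRL's exact controller may differ (a fixed coefficient is also common):

```python
class AdaptiveKLController:
    """Scale the KL coefficient up or down so the observed KL tracks a target value."""

    def __init__(self, init_kl_coef=0.2, target_kl=6.0, horizon=10000):
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        # Proportional error, clipped so a single noisy estimate cannot swing the coefficient
        proportional_error = max(min(current_kl / self.target_kl - 1.0, 0.2), -0.2)
        self.kl_coef *= 1.0 + proportional_error * n_steps / self.horizon
```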
## Conclusion
This implementation allows for efficient training of large language models while maintaining stable updates through PPO's conservative policy iteration approach. The use of FSDP and advanced optimization techniques ensures scalability and robustness in distributed training environments.