TinyZero: Training

TinyZero implements the PPO algorithm with:

  • Experience collection through batch generation.
  • Advantage estimation (GAE or GRPO); a GAE sketch follows this list.
  • Policy updates with KL divergence control.
  • Value function updates.
  • Sequence length balancing for efficient parallel training.
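
The advantage-estimation step is only named above, so here is a minimal stand-alone GAE sketch over token-level rewards. The function name compute_gae, its signature, and the default gamma/lam are illustrative assumptions rather than veRL's exact implementation (which additionally whitens the advantages).

import torch

def compute_gae(token_rewards, values, eos_mask, gamma=1.0, lam=1.0):
    # token_rewards, values, eos_mask: tensors of shape (batch, response_len)
    # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    # A_t     = delta_t + gamma * lam * A_{t+1}
    advantages = torch.zeros_like(token_rewards)
    last_gae = torch.zeros_like(token_rewards[:, 0])
    for t in reversed(range(token_rewards.shape[1])):
        next_values = values[:, t + 1] if t + 1 < values.shape[1] else torch.zeros_like(values[:, t])
        delta = token_rewards[:, t] + gamma * next_values - values[:, t]
        last_gae = delta + gamma * lam * last_gae
        advantages[:, t] = last_gae
    returns = advantages + values        # regression targets for the value function
    return advantages * eos_mask, returns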

Model Weight Adjustment in veRL During RL Training

Based on the code in core_algos.py and fsdp_workers.py, model weights in veRL are updated through PPO (Proximal Policy Optimization) training. The key components of this process are explained below.

1. Policy Gradient Loss

The policy gradient loss is computed using the PPO clipped objective to ensure stable updates.

import torch
import verl.utils.torch_functional as verl_F

def compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange):
    # Probability ratio between the updated policy and the behavior (old) policy
    negative_approx_kl = log_prob - old_log_prob
    ratio = torch.exp(negative_approx_kl)

    # PPO clipped objective: unclipped vs. clipped surrogate losses
    pg_losses = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)

    # Take the element-wise max to implement PPO's pessimistic bound,
    # averaging only over valid (non-padding) response tokens
    pg_loss = verl_F.masked_mean(torch.max(pg_losses, pg_losses2), eos_mask)
    return pg_loss

2. Value Function Loss

The value function loss ensures the critic network accurately estimates the expected returns.

def compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value):
    # Clip new value predictions so they stay close to the old value estimates
    vpredclipped = verl_F.clip_by_value(vpreds, values - cliprange_value, values + cliprange_value)
    vf_losses1 = (vpreds - returns)**2
    vf_losses2 = (vpredclipped - returns)**2
    # Pessimistic (max) combination, averaged over valid response tokens
    vf_loss = 0.5 * verl_F.masked_mean(torch.max(vf_losses1, vf_losses2), eos_mask)
    return vf_loss
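
As a quick sanity check, the two losses above can be exercised with random tensors. Every shape and value below is made up for illustration, and the snippet assumes the definitions above (and the verl_F helpers they rely on) are importable.

batch_size, response_len = 2, 8
old_log_prob = torch.randn(batch_size, response_len)
log_prob = old_log_prob + 0.05 * torch.randn(batch_size, response_len)  # slightly shifted policy
advantages = torch.randn(batch_size, response_len)
values = torch.randn(batch_size, response_len)
returns = values + advantages
vpreds = values + 0.1 * torch.randn(batch_size, response_len)
eos_mask = torch.ones(batch_size, response_len)  # 1 = real response token, 0 = padding

pg_loss = compute_policy_loss(old_log_prob, log_prob, advantages, eos_mask, cliprange=0.2)
vf_loss = compute_value_loss(vpreds, returns, values, eos_mask, cliprange_value=0.5)
print(pg_loss.item(), vf_loss.item())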

3. KL Penalty

A KL penalty is applied to prevent the policy from deviating too much from the reference policy.

def kl_penalty(logprob, ref_logprob, kl_penalty):
    # Per-token divergence between the current policy and the frozen reference policy;
    # the full source supports additional estimator types beyond the two shown here
    if kl_penalty == "kl":
        return logprob - ref_logprob
    elif kl_penalty == "abs":
        return (logprob - ref_logprob).abs()
    raise NotImplementedError(f"Unknown kl_penalty type: {kl_penalty}")
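
In practice this per-token KL term is folded into the reward signal before advantages are computed, scaled by a coefficient. The sketch below shows that common recipe; kl_coef and the reward layout are illustrative assumptions rather than veRL's exact code.

def apply_kl_penalty_to_rewards(token_rewards, logprob, ref_logprob, kl_coef=0.001):
    # token_rewards: (batch, response_len); the task reward is typically non-zero only
    # at the final response token, while the KL term applies to every token
    kld = kl_penalty(logprob, ref_logprob, kl_penalty="kl")
    return token_rewards - kl_coef * kld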

4. Weight Updates

The ActorRolloutRefWorker class handles model weight updates using FSDP (Fully Sharded Data Parallel).

def update_actor(self, data: DataProto):
    # Bring parameters and optimizer state back onto the GPU if they were offloaded
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.actor_module_fsdp)
    if self._is_offload_optimizer:
        load_fsdp_optimizer(optimizer=self.actor_optimizer)

    # Run the PPO update over the collected batch
    metrics = self.actor.update_policy(data=data)

    # Step the learning-rate scheduler once per PPO update
    self.actor_lr_scheduler.step()

    # (The full implementation also re-offloads parameters and optimizer state
    # and returns the metrics to the trainer.)
    return metrics

5. Optimization Process

  • Uses the AdamW optimizer with configurable learning rate and weight decay (a plain-PyTorch sketch of this setup follows the list).
  • Supports learning rate warmup through a scheduler.
  • Implements gradient clipping and optimization in micro-batches.
  • Handles distributed training through FSDP.
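
A minimal, self-contained sketch of what the optimization recipe above looks like in plain PyTorch. The stand-in model, data, loss, and hyperparameter values are placeholders, not veRL's configuration, and FSDP wrapping is omitted.

import torch
from torch.optim.lr_scheduler import LambdaLR

# Stand-in model and micro-batches; in veRL these are the FSDP-wrapped actor and PPO data
actor_module = torch.nn.Linear(16, 16)
micro_batches = [torch.randn(4, 16) for _ in range(4)]

# Placeholder hyperparameters (veRL reads these from its config)
lr, weight_decay, warmup_steps, max_grad_norm = 1e-6, 0.01, 10, 1.0

optimizer = torch.optim.AdamW(actor_module.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / warmup_steps))

optimizer.zero_grad()
for mb in micro_batches:                      # gradient accumulation over micro-batches
    loss = actor_module(mb).pow(2).mean()     # stand-in for the PPO policy/value loss
    (loss / len(micro_batches)).backward()    # average gradients across micro-batches

torch.nn.utils.clip_grad_norm_(actor_module.parameters(), max_grad_norm)  # gradient clipping
optimizer.step()
scheduler.step()                              # learning-rate warmup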

6. Training Flow

The training process follows these steps:

  1. Generate sequences using the current policy.
  2. Compute advantages using GAE (Generalized Advantage Estimation) or GRPO; a GRPO sketch follows this list.
  3. Calculate policy and value losses.
  4. Apply the KL penalty to keep the policy close to the reference model.
  5. Update model weights through backpropagation.
  6. Track metrics like policy loss, value loss, and KL divergence.
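
Step 2 names GRPO as an alternative to GAE. Instead of a learned value function, GRPO normalizes each response's scalar reward against the other samples drawn for the same prompt. The sketch below is illustrative; the function name, batch layout, and epsilon are assumptions, not veRL's exact implementation.

import torch

def compute_grpo_advantages(outcome_rewards, group_size, eps=1e-6):
    # outcome_rewards: (batch,) one scalar reward per sampled response, with the batch
    # laid out as consecutive groups of `group_size` responses to the same prompt
    r = outcome_rewards.view(-1, group_size)
    advantages = (r - r.mean(dim=1, keepdim=True)) / (r.std(dim=1, keepdim=True) + eps)
    return advantages.view(-1)   # broadcast to every token of the corresponding response

# Example: 2 prompts, 3 sampled responses each
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0, 1.0, 0.0])
print(compute_grpo_advantages(rewards, group_size=3))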

Techniques to Stabilize Training

The system employs several techniques to ensure stable and efficient training:

  • PPO clipping to limit policy updates.
  • Value function clipping.
  • KL divergence control (an adaptive-coefficient sketch follows this list).
  • Gradient clipping.
  • Learning rate scheduling.
  • Distributed training optimization through FSDP.
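
KL control is commonly implemented with an adaptive coefficient that grows when the measured KL exceeds a target and shrinks when it falls below. The sketch below follows the standard adaptive-KL recipe from PPO-based LM fine-tuning; the class name, default values, and update rule are not necessarily veRL's exact code.

class AdaptiveKLController:
    """Scale the KL coefficient so the measured KL tracks a target value."""

    def __init__(self, init_kl_coef=0.2, target_kl=6.0, horizon=10000):
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        # Proportional error, clipped so the coefficient changes gently
        proportional_error = max(min(current_kl / self.target_kl - 1.0, 0.2), -0.2)
        self.kl_coef *= 1.0 + proportional_error * n_steps / self.horizon

# Usage: after each PPO update, feed back the mean KL of the batch
controller = AdaptiveKLController()
controller.update(current_kl=8.0, n_steps=256)
print(controller.kl_coef)   # increases slightly because the measured KL exceeded the target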

Conclusion

This implementation allows for efficient training of large language models while maintaining stable updates through PPO's conservative policy iteration approach. The use of FSDP and advanced optimization techniques ensures scalability and robustness in distributed training environments.