
PPO

Let's break down a simple PPO (Proximal Policy Optimization) implementation in PyTorch and explain how each part maps to the PPO objective function.


1. PPO Objective Function Recap

The core objective function for PPO is:

$$J(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta) A_t,\ \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t \right) \right]$$

Where:

  • $r_t(\theta)$ is the probability ratio between the new and old policy:
  $r_t(\theta) = \frac{\pi_{\theta}(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$
  • $A_t$ is the advantage function, measuring whether an action was better or worse than expected.
  • Clipping constrains the ratio to $[1-\epsilon, 1+\epsilon]$, which keeps each policy update small and stable (see the short numeric sketch below).
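
To make the clipping concrete, here is a minimal sketch (plain PyTorch, with toy numbers chosen purely for illustration; they do not come from the implementation below) of the clipped term for a single transition:

import torch

epsilon = 0.2  # Same clipping parameter used by the agent below

# Toy example: a positive-advantage action whose probability grew by 50% under the new policy
ratio = torch.tensor([1.5])      # r_t(θ) = π_θ(a_t | s_t) / π_θ_old(a_t | s_t)
advantage = torch.tensor([2.0])  # A_t

unclipped = ratio * advantage                                       # 1.5 * 2.0 = 3.0
clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantage  # 1.2 * 2.0 = 2.4
objective = torch.min(unclipped, clipped)                           # min(3.0, 2.4) = 2.4

print(objective.item())  # 2.4: probability mass gained beyond 1+ε earns no extra credit

Because the minimum is taken, pushing the ratio past $1+\epsilon$ (for positive advantages) or below $1-\epsilon$ (for negative advantages) yields no further improvement in the objective, which is what discourages overly large policy updates.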

2. Code Implementation and Mapping to PPO Objective

Here’s the PPO implementation in PyTorch with detailed explanations:

import torch
import torch.nn as nn
import torch.optim as optim
import gym
import numpy as np

# ===========================
# 1. Define the Policy (Actor)
# ===========================
class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)
        self.softmax = nn.Softmax(dim=-1)  # Outputs a probability distribution

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        return self.softmax(self.fc3(x))  # Probability distribution over actions

# Maps to: π_θ(a_t | s_t) → Policy output for action selection

# ===========================
# 2. Define the Value (Critic)
# ===========================
class ValueNetwork(nn.Module):
    def __init__(self, state_dim):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)  # Single output (state value estimate)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)  # Estimated value of the state

# Maps to: V_ϕ(s_t) → Critic estimates how good the state is

# ===========================
# 3. Define the PPO Agent
# ===========================
class PPOAgent:
    def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99, epsilon=0.2):
        self.policy_net = PolicyNetwork(state_dim, action_dim)
        self.value_net = ValueNetwork(state_dim)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=lr)
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Clipping parameter

    # =======================
    # Compute Advantage A_t
    # =======================
    def compute_advantage(self, rewards, values, gamma):
        advantages = []
        G = 0
        for r, v in zip(reversed(rewards), reversed(values)):
            G = r + gamma * G  # Compute discounted returns
            advantages.insert(0, G - v)  # A_t = R_t - V(s_t)
        return torch.tensor(advantages, dtype=torch.float32)

    # Maps to: A_t = R_t - V_ϕ(s_t) → Computes the advantage function

    # =======================
    # PPO Policy and Value Update
    # =======================
    def update(self, states, actions, rewards, old_probs, values):
        advantages = self.compute_advantage(rewards, values, self.gamma)
        states = torch.tensor(np.array(states), dtype=torch.float32)  # Stack the list of state arrays before converting
        actions = torch.tensor(actions, dtype=torch.int64)
        old_probs = torch.tensor(old_probs, dtype=torch.float32)

        # Compute new action probabilities
        new_probs = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze()

        # Compute probability ratio
        ratio = new_probs / old_probs  # r_t(θ) = π_θ(a_t | s_t) / π_θ_old(a_t | s_t)

        # Maps to: r_t(θ) → Probability ratio between new and old policy
        # Note: this simple example takes a single gradient step per batch, so the ratio is
        # (numerically) 1 when the loss is computed and the clip has little effect here.
        # Full PPO reuses each batch for several epochs, which is when the ratio drifts
        # away from 1 and clipping becomes important.

        # Compute surrogate loss with clipping
        clipped_ratio = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon)
        policy_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()

        # Maps to: 
        # J(θ) = E[min(r_t(θ) A_t, clip(r_t(θ), 1-ε, 1+ε) A_t)]
        # Ensures policy updates are within a safe range

        # Compute value loss
        value_targets = advantages + torch.tensor(values, dtype=torch.float32)
        value_loss = nn.MSELoss()(self.value_net(states).squeeze(), value_targets)

        # Maps to: L_V(ϕ) = (V_ϕ(s_t) - R_t)² → Trains the critic

        # Optimize policy network
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Optimize value network
        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

# ===========================
# 4. Training Loop
# ===========================
env = gym.make('CartPole-v1')  # Requires gym >= 0.26 (or gymnasium) for the new reset()/step() return values
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = PPOAgent(state_dim, action_dim)

for episode in range(1000):
    state, _ = env.reset()
    done = False
    rewards, states, actions, old_probs, values = [], [], [], [], []

    while not done:
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        probs = agent.policy_net(state_tensor).detach()  # π_θ(a_t | s_t)
        action = torch.multinomial(probs, 1).item()  # Sample action
        value = agent.value_net(state_tensor).item()  # V_ϕ(s_t)

        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated  # End the episode on termination or time-limit truncation

        # Store experience
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        old_probs.append(probs[0, action].item())
        values.append(value)

        state = next_state

    agent.update(states, actions, rewards, old_probs, values)

env.close()
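
A quick way to sanity-check the trained policy is to roll it out greedily (always taking the most probable action) for a few episodes. This is a minimal sketch and not part of the original code; it assumes the `agent` object from the training loop above is still in scope and re-creates the environment, since `env` was already closed.

# Minimal evaluation sketch (assumption: `agent` from the training loop is still available)
eval_env = gym.make('CartPole-v1')
for _ in range(5):
    state, _ = eval_env.reset()
    done, total_reward = False, 0.0
    while not done:
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            probs = agent.policy_net(state_tensor)
        action = torch.argmax(probs, dim=-1).item()  # Greedy action instead of sampling
        state, reward, terminated, truncated, _ = eval_env.step(action)
        done = terminated or truncated
        total_reward += reward
    print(f"Evaluation return: {total_reward}")
eval_env.close()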

3. How Each Part of the Code Maps to the PPO Objective Function

| Code Section | Mathematical Equivalent | Explanation |
|---|---|---|
| `self.policy_net.forward(state)` | $\pi_{\theta}(a_t \mid s_t)$ | Policy network outputs action probabilities. |
| `self.value_net.forward(state)` | $V_{\phi}(s_t)$ | Value network estimates how good a state is. |
| `ratio = new_probs / old_probs` | $r_t(\theta)=\frac{\pi_{\theta}(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$ | Measures how much the policy has changed. |
| `advantages = self.compute_advantage(rewards, values, self.gamma)` | $A_t = R_t - V_{\phi}(s_t)$ | Computes advantage estimates. |
| `torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon)` | $\text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)$ | Clips the ratio to keep updates stable. |
| `policy_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()` | $\min(r_t(\theta) A_t,\ \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t)$ | Clipped surrogate loss that keeps policy updates controlled. |
| `value_loss = nn.MSELoss()(self.value_net(states).squeeze(), value_targets)` | $L_V(\phi) = (V_{\phi}(s_t) - R_t)^2$ | Trains the critic with an MSE loss. |

Conclusion

  • The policy network (actor) learns which actions to take.
  • The value network (critic) learns how good states are.
  • The clipped surrogate loss function ensures safe and stable updates.

