PPO: implementation
Let's break down the simple PPO implementation in PyTorch and explain how each part maps to the PPO objective function.
1. PPO Objective Function Recap
The core objective function for PPO is:
$$J(\theta) = \mathbb{E}_t \left[ \min \left( r_t(\theta)\, A_t,\ \text{clip}\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\, A_t \right) \right]$$
Where:
- $r_t(\theta)$ is the probability ratio between the new and old policy:
  $$r_t(\theta) = \frac{\pi_{\theta}(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$$
- $A_t$ is the advantage function, measuring whether an action was better or worse than expected.
- Clipping constrains the ratio $r_t(\theta)$ to $[1-\epsilon, 1+\epsilon]$, which keeps each policy update small and stabilizes training (a numerical sketch of this objective follows the list).
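To make the formula concrete, here is a minimal sketch that evaluates the clipped objective on a few hand-picked ratios and advantages. The numbers are purely illustrative and are not taken from the implementation below:

```python
import torch

# Illustrative probability ratios r_t(θ) and advantages A_t (made-up numbers)
ratio = torch.tensor([0.5, 1.0, 1.5])        # new prob / old prob per sampled action
advantages = torch.tensor([1.0, -2.0, 0.5])  # positive = better than expected
epsilon = 0.2

clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
# PPO objective: elementwise minimum of the unclipped and clipped terms, averaged over samples
objective = torch.min(ratio * advantages, clipped_ratio * advantages).mean()
print(objective)  # ≈ -0.3; the policy loss minimized by gradient descent is -objective
```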
2. Code Implementation and Mapping to PPO Objective
Here’s the PPO implementation in PyTorch with detailed explanations:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import gym
import numpy as np

# ===========================
# 1. Define the Policy (Actor)
# ===========================
class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)
        self.softmax = nn.Softmax(dim=-1)  # Outputs a probability distribution

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        return self.softmax(self.fc3(x))  # Probability distribution over actions

# Maps to: π_θ(a_t | s_t) → Policy output for action selection

# ===========================
# 2. Define the Value (Critic)
# ===========================
class ValueNetwork(nn.Module):
    def __init__(self, state_dim):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)  # Single output (state-value estimate)

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)  # Estimated value of the state

# Maps to: V_ϕ(s_t) → Critic estimates how good the state is

# ===========================
# 3. Define the PPO Agent
# ===========================
class PPOAgent:
    def __init__(self, state_dim, action_dim, lr=0.001, gamma=0.99, epsilon=0.2):
        self.policy_net = PolicyNetwork(state_dim, action_dim)
        self.value_net = ValueNetwork(state_dim)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=lr)
        self.gamma = gamma      # Discount factor
        self.epsilon = epsilon  # Clipping parameter

    # =======================
    # Compute Advantage A_t
    # =======================
    def compute_advantage(self, rewards, values, gamma):
        advantages = []
        G = 0
        for r, v in zip(reversed(rewards), reversed(values)):
            G = r + gamma * G            # Discounted return R_t
            advantages.insert(0, G - v)  # A_t = R_t - V(s_t)
        return torch.tensor(advantages, dtype=torch.float32)

    # Maps to: A_t = R_t - V_ϕ(s_t) → Computes the advantage function

    # =======================
    # PPO Policy and Value Update
    # =======================
    def update(self, states, actions, rewards, old_probs, values):
        advantages = self.compute_advantage(rewards, values, self.gamma)
        states = torch.tensor(np.array(states), dtype=torch.float32)
        actions = torch.tensor(actions, dtype=torch.int64)
        old_probs = torch.tensor(old_probs, dtype=torch.float32)

        # Compute new action probabilities for the actions that were taken
        new_probs = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze()

        # Compute probability ratio
        ratio = new_probs / old_probs  # r_t(θ) = π_θ(a_t | s_t) / π_θ_old(a_t | s_t)
        # Maps to: r_t(θ) → Probability ratio between new and old policy

        # Compute surrogate loss with clipping
        clipped_ratio = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon)
        policy_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()
        # Maps to: J(θ) = E[min(r_t(θ) A_t, clip(r_t(θ), 1-ε, 1+ε) A_t)]
        # Ensures policy updates stay within a safe range

        # Compute value loss (value_targets equal the discounted returns R_t)
        value_targets = advantages + torch.tensor(values, dtype=torch.float32)
        value_loss = nn.MSELoss()(self.value_net(states).squeeze(), value_targets)
        # Maps to: L_V(ϕ) = (V_ϕ(s_t) - R_t)² → Trains the critic

        # Optimize policy network
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # Optimize value network
        self.value_optimizer.zero_grad()
        value_loss.backward()
        self.value_optimizer.step()

# ===========================
# 4. Training Loop
# ===========================
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
agent = PPOAgent(state_dim, action_dim)

for episode in range(1000):
    state, _ = env.reset()
    done = False
    rewards, states, actions, old_probs, values = [], [], [], [], []

    while not done:
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
        probs = agent.policy_net(state_tensor).detach()  # π_θ_old(a_t | s_t)
        action = torch.multinomial(probs, 1).item()      # Sample action
        value = agent.value_net(state_tensor).item()     # V_ϕ(s_t)

        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        # Store experience
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        old_probs.append(probs[0, action].item())
        values.append(value)
        state = next_state

    agent.update(states, actions, rewards, old_probs, values)

env.close()
```
3. How Each Part of the Code Maps to the PPO Objective Function
| Code Section | Mathematical Equivalent | Explanation |
|---|---|---|
| `self.policy_net(state)` | $\pi_{\theta}(a_t \mid s_t)$ | Policy network outputs a probability distribution over actions. |
| `self.value_net(state)` | $V_{\phi}(s_t)$ | Value network estimates how good a state is. |
| `ratio = new_probs / old_probs` | $r_t(\theta)=\frac{\pi_{\theta}(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$ | Measures how much the policy has changed for each sampled action. |
| `advantages = self.compute_advantage(rewards, values, self.gamma)` | $A_t = R_t - V_{\phi}(s_t)$ | Computes advantage estimates from discounted returns. |
| `torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon)` | $\text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)$ | Clips the ratio to keep updates within a safe range. |
| `policy_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()` | $-\mathbb{E}_t\left[\min(r_t(\theta) A_t,\ \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon) A_t)\right]$ | Clipped surrogate loss; the negative sign turns maximization into minimization. |
| `value_loss = nn.MSELoss()(self.value_net(states).squeeze(), value_targets)` | $L_V(\phi) = (V_{\phi}(s_t) - R_t)^2$ | Trains the critic with a mean-squared-error loss against the returns. |
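In practice, the ratio $r_t(\theta)$ is usually computed from log-probabilities rather than by dividing probabilities directly, which is numerically more robust. The sketch below shows that variant using `torch.distributions.Categorical`; the helper `ppo_policy_loss` is a hypothetical refactoring of the `update` step above, not part of the script itself:

```python
import torch
from torch.distributions import Categorical

def ppo_policy_loss(policy_net, states, actions, old_log_probs, advantages, epsilon=0.2):
    """Clipped surrogate loss computed via log-probability ratios (illustrative sketch)."""
    dist = Categorical(probs=policy_net(states))       # π_θ(· | s_t)
    new_log_probs = dist.log_prob(actions)             # log π_θ(a_t | s_t)
    ratio = torch.exp(new_log_probs - old_log_probs)   # r_t(θ) = exp(log π_θ - log π_θ_old)
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    return -torch.min(ratio * advantages, clipped * advantages).mean()
```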
Conclusion
- The policy network (actor) learns which actions to take.
- The value network (critic) learns how good states are.
- The clipped surrogate loss function ensures safe and stable updates.
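To see what the trained actor has learned, a short evaluation loop can run the greedy policy (always pick the most probable action). This is a sketch under the same Gym API assumptions as the training loop above, reusing the `agent` defined there on a freshly created environment:

```python
# Evaluate the trained policy greedily for one episode (illustrative sketch)
eval_env = gym.make('CartPole-v1')
state, _ = eval_env.reset()
done, total_reward = False, 0.0
while not done:
    state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
    with torch.no_grad():
        probs = agent.policy_net(state_tensor)       # π_θ(· | s_t)
    action = torch.argmax(probs, dim=-1).item()      # Greedy action instead of sampling
    state, reward, terminated, truncated, _ = eval_env.step(action)
    done = terminated or truncated
    total_reward += reward
eval_env.close()
print(f"Greedy episode return: {total_reward}")
```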