09. PPO Theory
Proximal Policy Optimization (PPO) has been the workhorse RL algorithm in many production RLHF systems. It optimizes the policy directly to maximize expected reward, while limiting how far each update can move the policy so that a single bad step cannot degrade it catastrophically.
The PPO objective extends the standard policy gradient with a clipped surrogate objective:
```python
import torch

def ppo_objective(ratio, advantages, epsilon=0.2):
    """
    ratio: pi_theta(a|s) / pi_theta_old(a|s) - probability ratio
    advantages: estimated advantage of taking action a in state s
    epsilon: clipping parameter (typically 0.1-0.2)
    """
    # Unclipped surrogate objective
    unclipped = ratio * advantages
    # Clipped objective - removes the incentive to push the ratio outside [1 - eps, 1 + eps]
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
    clipped = clipped_ratio * advantages
    # Elementwise minimum of clipped and unclipped: a pessimistic
    # lower bound on the unclipped surrogate objective.
    # Negated so that minimizing this loss maximizes the objective.
    return -torch.min(unclipped, clipped).mean()
```
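The clipping keeps the probability ratio close to 1, so a single update cannot move the policy far from the one that collected the data. This matters because PPO reuses each batch for several gradient steps: without clipping, those repeated steps can push the ratio far from 1, the surrogate stops being a reliable estimate of improvement, and the policy can collapse to a degenerate distribution.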
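A quick numerical check of what the clip does to the gradient, using four hand-picked ratio/advantage pairs (the values are arbitrary and chosen only for illustration):

```python
import torch

# Two samples with a positive advantage, two with a negative advantage.
ratio = torch.tensor([1.5, 1.1, 0.5, 0.9], requires_grad=True)
advantages = torch.tensor([1.0, 1.0, -1.0, -1.0])
epsilon = 0.2

unclipped = ratio * advantages
clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
objective = torch.min(unclipped, clipped)
objective.sum().backward()

print(objective.detach())  # per-sample surrogate values
print(ratio.grad)          # gradient of the surrogate w.r.t. each ratio
```

Only the two in-range samples (ratios 1.1 and 0.9) receive a gradient. The sample at 1.5 with a positive advantage has already gained as much as the clip allows, and the sample at 0.5 with a negative advantage has already been pushed down as far as the clip allows, so both contribute zero gradient. The effect is one-sided: a ratio of 0.5 with a positive advantage would stay unclipped and keep its full gradient, because moving back toward 1 is never penalized.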
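KL penalty approach: The original PPO paper also describes a KL penalty, on the divergence between the old and new policies, as an alternative to clipping. RLHF systems use a related but distinct term: a penalty on divergence from the frozen reference (SFT) policy, which discourages the policy from drifting toward outputs that exploit the reward model. A common form folds this penalty into the reward: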
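```python
import torch

def ppo_objective_with_kl(log_probs, ref_log_probs, rewards, beta=0.1):
    """
    log_probs: per-token log probs of sampled responses under the current policy, shape (batch, seq_len)
    ref_log_probs: per-token log probs under the frozen reference (SFT) policy, same shape
    rewards: reward model score per response, shape (batch,)
    beta: KL coefficient
    """
    # Sample-based estimate of KL(pi || pi_ref) for each response
    kl = (log_probs - ref_log_probs).sum(dim=-1)
    # KL-penalized reward, treated as a fixed training signal (no gradient flows through it)
    shaped_rewards = (rewards - beta * kl).detach()
    # Policy-gradient (REINFORCE) loss: raise the log-likelihood of high-reward responses
    loss = -(shaped_rewards * log_probs.sum(dim=-1)).mean()
    return loss
```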
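This sketch uses a plain policy-gradient loss. TRL's PPOTrainer, for example, uses the same kind of KL-shaped reward (each token gets a reward of minus beta times its log-ratio to the reference, and the reward-model score is added at the final token), but optimizes it with the clipped objective and a value function. In RLHF the KL penalty therefore complements clipping rather than replacing it. The KL-regularized objective, maximize E[r] - beta * KL(pi || pi_ref), is also the starting point for DPO: DPO's implicit reward beta * log(pi / pi_ref) follows from the closed-form optimum of this objective. That is a shared objective, not an equivalence between the two training procedures.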
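The link between the KL-regularized objective and DPO's implicit reward can be written out. This is the standard derivation, for a single prompt x, reward r, and KL coefficient β:

```latex
\begin{aligned}
&\max_{\pi}\; \mathbb{E}_{y \sim \pi(\cdot \mid x)}\big[r(x, y)\big]
  - \beta\, \mathrm{KL}\big(\pi(\cdot \mid x) \,\|\, \pi_{\mathrm{ref}}(\cdot \mid x)\big) \\[4pt]
&\pi^{*}(y \mid x) = \frac{1}{Z(x)}\, \pi_{\mathrm{ref}}(y \mid x)\, \exp\!\Big(\frac{r(x, y)}{\beta}\Big),
  \qquad Z(x) = \sum_{y} \pi_{\mathrm{ref}}(y \mid x)\, \exp\!\Big(\frac{r(x, y)}{\beta}\Big) \\[4pt]
&r(x, y) = \beta \log \frac{\pi^{*}(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} + \beta \log Z(x)
\end{aligned}
```

Because log Z(x) depends only on the prompt, it cancels when two responses to the same prompt are compared under a Bradley-Terry preference model, which is why DPO can fit β log(π/π_ref) directly to preference pairs without ever computing Z(x).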
A simplified PPO training step that combines the clipped objective with a KL penalty toward the reference model:
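```python
import torch

# Simplified PPO loss for one batch of sampled responses
def ppo_step(response_log_probs, old_log_probs, ref_log_probs, response_rewards,
             epsilon=0.2, beta=0.1):
    # response_log_probs: log pi_theta(y|x) per sampled response (requires grad)
    # old_log_probs: log pi_theta_old(y|x) recorded when the responses were sampled
    # ref_log_probs: log pi_ref(y|x) under the frozen reference model
    # response_rewards: reward model scores, one per response
    # 1. Probability ratio pi_theta / pi_theta_old
    ratio = torch.exp(response_log_probs - old_log_probs.detach())
    # 2. Advantages: normalized rewards as a simple baseline (no critic here)
    advantages = (response_rewards - response_rewards.mean()) / (response_rewards.std() + 1e-8)
    # 3. Clipped surrogate loss
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
    policy_loss = -torch.min(unclipped, clipped).mean()
    # 4. KL penalty toward the reference model, added as a separate loss term;
    #    positive sign, so drifting away from the reference increases the loss
    kl = response_log_probs - ref_log_probs.detach()
    kl_loss = beta * kl.mean()
    # 5. Total loss
    return policy_loss + kl_loss
```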
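Value function: The simplified step above uses normalized rewards as advantages. Full PPO typically also trains a value function (critic) V(s) that estimates the expected return from each state; advantages are then measured relative to V(s), which reduces the variance of the policy-gradient estimate. TRL's PPOTrainer trains this critic internally, but understanding it helps with debugging.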
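One common way to turn critic estimates into advantages is Generalized Advantage Estimation (GAE), which the original PPO paper uses. The sketch below handles a single episode and assumes per-step rewards (in RLHF, typically the per-token KL penalty plus the reward-model score at the last token), a terminal final step, and illustrative defaults for gamma and lam; compute_gae is a hypothetical helper, not a TRL function.

```python
import torch

def compute_gae(rewards, values, gamma=1.0, lam=0.95):
    """
    rewards: per-step rewards, shape (T,)
    values:  critic estimates V(s_t) for t = 0..T-1, shape (T,), detached from the graph
    Assumes the episode terminates after step T-1 (bootstrap value of 0).
    Returns (advantages, returns), each of shape (T,).
    """
    T = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0
        # TD error: how much better step t went than the critic predicted
        delta = rewards[t] + gamma * next_value - values[t]
        # Exponentially weighted sum of future TD errors
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae
    # Regression targets for training the critic
    returns = advantages + values
    return advantages, returns
```

The parameter lam trades bias for variance: lam=0 uses only the one-step TD error (low variance, biased by critic errors), while lam=1 with gamma=1 reduces to the full return minus V(s).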
Implement a minimal PPO training loop in PyTorch without using TRL. Use a simple environment (like a bandit with a known reward distribution) to verify your implementation. Check that the policy improves over time and that the KL divergence from the initial policy stays bounded. Plot the KL divergence over training steps, and compare runs with different epsilon values, to see how clipping affects the update magnitude.
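One possible starting point for the exercise, sketched under assumptions of our own: a 4-armed bandit with hypothetical mean rewards [0.1, 0.3, 0.5, 0.9] plus Gaussian noise, a policy that is just a vector of logits, and whitened rewards as advantages (a one-step bandit needs no critic). train_bandit_ppo and all of its hyperparameters are illustrative. Because the action space is tiny, the KL from the initial policy is computed exactly rather than estimated from samples.

```python
import torch

def train_bandit_ppo(arm_means, iterations=50, batch_size=256, epochs=4,
                     epsilon=0.2, lr=0.05, noise_std=0.1, seed=0):
    torch.manual_seed(seed)
    arm_means = torch.tensor(arm_means)
    num_arms = arm_means.shape[0]
    # Policy: a single vector of logits over arms (uniform at initialization)
    logits = torch.zeros(num_arms, requires_grad=True)
    init_log_probs = torch.log_softmax(logits.detach(), dim=-1)
    optimizer = torch.optim.Adam([logits], lr=lr)
    kl_history = []

    for _ in range(iterations):
        # Rollout: sample actions from the current (soon to be "old") policy
        with torch.no_grad():
            old_dist = torch.distributions.Categorical(logits=logits)
            actions = old_dist.sample((batch_size,))
            old_log_probs = old_dist.log_prob(actions)
            rewards = arm_means[actions] + noise_std * torch.randn(batch_size)
            # Whitened rewards as advantages
            advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # Several epochs of clipped updates on the same batch
        for _ in range(epochs):
            dist = torch.distributions.Categorical(logits=logits)
            ratio = torch.exp(dist.log_prob(actions) - old_log_probs)
            unclipped = ratio * advantages
            clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
            loss = -torch.min(unclipped, clipped).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Exact KL(pi_theta || pi_init) over the four arms
        with torch.no_grad():
            log_probs = torch.log_softmax(logits, dim=-1)
            kl = (log_probs.exp() * (log_probs - init_log_probs)).sum()
            kl_history.append(kl.item())

    final_probs = torch.softmax(logits.detach(), dim=-1)
    return final_probs, kl_history

probs, kls = train_bandit_ppo([0.1, 0.3, 0.5, 0.9])
print("final arm probabilities:", probs.tolist())
print("KL from initial policy, every 10 steps:", [round(k, 3) for k in kls[::10]])
```

To produce the plot, pass kls to matplotlib (plt.plot(kls)). Rerunning with a different epsilon, or with the clamp removed, is a quick way to see how the clip limits how far the policy moves per iteration. With a uniform initial policy over four arms, the KL can never exceed log 4, so in this toy setting it is bounded by construction; the interesting quantity is how quickly it grows.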