10. PPO with KL Control
KL control is the mechanism that keeps the PPO policy from drifting too far from the reference policy. It balances two competing objectives: maximizing reward (which can push the policy toward degenerate solutions) and staying close to the reference (which preserves the model's original capabilities).
The KL penalty term in the objective has the form β * KL[π || π_ref], where β is a hyperparameter controlling the strength of the penalty.
import torch

def compute_kl_penalty(log_probs, ref_log_probs, beta=0.1):
    """
    Compute KL penalty for staying close to reference.
    Works directly with log-probabilities for numerical stability.
    Both inputs are log-probs of tokens sampled from the current policy.
    """
    # KL divergence between current policy and reference
    # KL(P||Q) = sum(P * log(P/Q)) = E_{x~P}[log P(x) - log Q(x)]
    # The tokens were sampled from the current policy, so the mean log-ratio
    # already estimates the expectation; no exp() weighting is applied.
    kl = log_probs - ref_log_probs
    return beta * kl.mean()
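To see how a sample-based KL estimate relates to the exact value, the sketch below compares the two on a toy three-outcome distribution. The probabilities, the helper names `exact_kl` and `sampled_kl_estimates`, and the sample count are illustrative assumptions, not part of any library. It also computes a second estimator, `(r - 1) - log r` with `r = π_ref(x) / π(x)`, which is non-negative for every sample and is sometimes preferred for logging because of its lower variance.

```python
import math
import random

def exact_kl(p, q):
    """Exact KL(p || q) for two categorical distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def sampled_kl_estimates(p, q, n_samples=20000, seed=0):
    """Estimate KL(p || q) from samples drawn from p.

    Returns (k1, k3, k3_min): the mean log-ratio estimate, the mean of the
    always-non-negative estimator, and the smallest single k3 sample.
    """
    rng = random.Random(seed)
    samples = rng.choices(range(len(p)), weights=p, k=n_samples)
    k1_terms, k3_terms = [], []
    for x in samples:
        log_ratio = math.log(p[x]) - math.log(q[x])  # log p(x) - log q(x)
        k1_terms.append(log_ratio)
        # With r = q(x) / p(x): (r - 1) - log r >= 0 for every sample
        k3_terms.append(math.exp(-log_ratio) - 1.0 + log_ratio)
    return (sum(k1_terms) / n_samples, sum(k3_terms) / n_samples, min(k3_terms))

policy_probs = [0.7, 0.2, 0.1]     # illustrative current policy
reference_probs = [0.5, 0.3, 0.2]  # illustrative reference policy

true_kl = exact_kl(policy_probs, reference_probs)
k1, k3, k3_min = sampled_kl_estimates(policy_probs, reference_probs)
print(f"exact={true_kl:.4f}  k1={k1:.4f}  k3={k3:.4f}")
```

Both estimates should land close to the exact value. Individual `k1` terms can be negative even though the KL itself cannot, which is why a noisy batch can report a slightly negative KL.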
Adaptive KL control: Static β is often suboptimal. If the policy is too far from the reference, you want a stronger penalty; if it's close, you can afford more aggressive reward seeking. Adaptive KL control adjusts β during training.
import math

class AdaptiveKLController:
    def __init__(self, init_kl_coef=0.1, target_kl=0.1, horizon=100):
        self.kl_coef = init_kl_coef
        self.target_kl = target_kl
        self.horizon = horizon  # stored, but not used by this simplified update
        self.step_count = 0

    def update(self, current_kl):
        """Update KL coefficient based on observed KL divergence."""
        self.step_count += 1
        # Proportional control on the gap between observed and target KL
        error = float(current_kl) - self.target_kl
        proportional = 0.1 * error
        # Multiplicative update: the coefficient grows when KL is above
        # target and shrinks when it is below, then is clamped to a safe range
        self.kl_coef *= math.exp(proportional)
        self.kl_coef = min(max(self.kl_coef, 1e-4), 1e1)
        return self.kl_coef
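The multiplicative update can be traced by hand on a short sequence of KL readings. The following is a minimal sketch that assumes the same rule (gain 0.1, clamp to [1e-4, 10]) written as a standalone function; `update_beta` and the observed KL values are made up for illustration.

```python
import math

def update_beta(beta, observed_kl, target_kl, gain=0.1, beta_min=1e-4, beta_max=10.0):
    """One multiplicative update of the KL coefficient, clamped to [beta_min, beta_max]."""
    beta = beta * math.exp(gain * (observed_kl - target_kl))
    return min(max(beta, beta_min), beta_max)

target_kl = 0.1
# Hypothetical observed KL values: first above target, then below it
observed = [0.5] * 5 + [0.02] * 5

beta = 0.1
history = [beta]
for kl in observed:
    beta = update_beta(beta, kl, target_kl)
    history.append(beta)

for step, value in enumerate(history):
    print(f"step {step:2d}  beta={value:.5f}")
```

With KL errors of this size and a gain of 0.1, each update changes β by only a few percent, so the controller reacts slowly. A larger gain, or an error normalized by the target, would make it respond faster at the cost of more oscillation.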
KL divergence as a training diagnostic:
# Monitor KL during training
# Assumes policy, ref_policy, dataloader, and target_kl are defined elsewhere,
# and that both models return an object with a .log_probs tensor.
import numpy as np
import torch

kl_estimates = []
for batch in dataloader:
    with torch.no_grad():
        # Log-probs of tokens sampled from the current policy
        current_log_probs = policy(batch).log_probs
        ref_log_probs = ref_policy(batch).log_probs
        # Mean log-ratio estimates KL(policy || reference)
        kl = (current_log_probs - ref_log_probs).mean()
    kl_estimates.append(kl.item())
    # Check whether KL is staying near the target
    avg_kl = np.mean(kl_estimates[-100:])  # Last 100 batches
    if avg_kl > 2 * target_kl:
        print("WARNING: KL divergence too high, policy drifting")
    elif avg_kl < 0.5 * target_kl:
        print("WARNING: KL divergence too low, not learning enough")
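The rolling-window check can also be packaged as a small helper so that it is testable without a model. This is a sketch under assumptions: `KLMonitor`, its parameter names, and the KL readings fed to it are hypothetical, and the 2x and 0.5x thresholds simply mirror the ones used in the monitoring loop.

```python
from collections import deque

class KLMonitor:
    """Rolling-window check of KL estimates against a target (hypothetical helper)."""

    def __init__(self, target_kl, window=100, high_factor=2.0, low_factor=0.5):
        self.target_kl = target_kl
        self.values = deque(maxlen=window)  # old entries drop out automatically
        self.high_factor = high_factor
        self.low_factor = low_factor

    def record(self, kl):
        """Add one batch-level KL estimate and return 'high', 'low', or 'ok'."""
        self.values.append(float(kl))
        avg = sum(self.values) / len(self.values)
        if avg > self.high_factor * self.target_kl:
            return "high"
        if avg < self.low_factor * self.target_kl:
            return "low"
        return "ok"

# Small window so the effect of averaging is visible in a short run
monitor = KLMonitor(target_kl=0.05, window=3)
readings = [0.05, 0.06, 0.30, 0.30, 0.30, 0.01, 0.01, 0.01]
statuses = [monitor.record(kl) for kl in readings]
print(statuses)
```

Because the status is based on the window average, it lags the raw readings: after KL drops, the monitor keeps reporting "high" until the large values have left the window.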
Target KL selection: The target KL should be small enough to prevent capability loss but large enough to allow meaningful improvement. Typical values:
- 0.01-0.03: Conservative, preserves capabilities at cost of slower alignment
- 0.05-0.10: Moderate, balanced
- 0.10-0.20: Aggressive, risk of capability loss but faster alignment
Implement the adaptive KL controller and run it alongside a static KL controller. Compare the training curves: reward over time, KL divergence over time, and capability metrics (perplexity on held-out text). The adaptive controller should achieve similar or better reward while maintaining KL closer to target.
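Two of the comparison metrics are simple to compute from logged values. The sketch below assumes you have per-token natural-log probabilities on held-out text and a list of logged KL values per run; the function names `perplexity` and `mean_abs_kl_gap` are illustrative, not from any library.

```python
import math

def perplexity(token_log_probs):
    """Perplexity = exp(-mean log-probability) over held-out tokens (natural log)."""
    if not token_log_probs:
        raise ValueError("need at least one token log-probability")
    return math.exp(-sum(token_log_probs) / len(token_log_probs))

def mean_abs_kl_gap(kl_curve, target_kl):
    """Average distance between logged KL values and the target."""
    return sum(abs(kl - target_kl) for kl in kl_curve) / len(kl_curve)

# A model that assigns probability 1/4 to every held-out token has perplexity 4
uniform_log_probs = [math.log(0.25)] * 8
print(perplexity(uniform_log_probs))

# Illustrative KL curve; a smaller gap means KL stayed closer to the target
print(mean_abs_kl_gap([0.1, 0.2, 0.0], target_kl=0.1))
```

Computing `mean_abs_kl_gap` for both the static and the adaptive run gives a single number for "KL closer to target", and a rise in held-out perplexity relative to the reference model is one signal of capability loss.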