11. PPO Implementation
Chapter 11 of 24 · 20 min
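The TRL library provides a PPO implementation via `PPOTrainer`. A PPO run coordinates three models: the policy being trained, a frozen reference model used for the KL penalty, and a reward model that scores generated responses. The code below uses TRL's classic `generate()`/`step()` API (TRL versions before 0.12); newer releases replaced it with a `PPOTrainer` driven by `trainer.train()`.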
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
import torch

# This example uses TRL's classic generate()/step() PPO API (TRL < 0.12).

# Load models
# PPO's advantage estimates need a value head, so the policy and its frozen
# reference copy are loaded with AutoModelForCausalLMWithValueHead.
model = AutoModelForCausalLMWithValueHead.from_pretrained("your-sft-model")
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("your-sft-model")
# A reward model returns one scalar per sequence: load it as a sequence
# classifier with a single label, not as a causal LM (whose logits are
# per-token scores over the vocabulary).
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "your-reward-model",
    num_labels=1,
)
reward_model.eval()

# Use the SFT policy's tokenizer. This assumes the reward model shares it;
# if not, tokenize the reward inputs with the reward model's own tokenizer.
tokenizer = AutoTokenizer.from_pretrained("your-sft-model")
tokenizer.pad_token = tokenizer.eos_token
# Sequence classifiers use pad_token_id to locate the last real token in each
# row; without it, scoring a batch of more than one sequence fails.
reward_model.config.pad_token_id = tokenizer.pad_token_id

# Configure PPO
config = PPOConfig(
    model_name="your-sft-model",
    learning_rate=1e-6,
    ppo_epochs=4,
    mini_batch_size=4,
    batch_size=16,  # every step() must receive exactly this many samples
    gradient_accumulation_steps=1,
    init_kl_coef=0.2,  # initial weight of the KL penalty against ref_model
    adap_kl_ctrl=True,  # adapt the KL coefficient toward a target KL
    kl_penalty="kl",  # Options: "kl", "abs", "mse", "full"
)

# Initialize trainer. The classic trainer takes no reward model: rewards are
# computed in the loop below and passed to step().
ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)
# The trainer places the policy via Accelerate; put the reward model alongside it.
device = ppo_trainer.accelerator.device
reward_model.to(device)


def collate_fn(examples):
    """Tokenize a batch of prompts for PPO.

    Returns a dict with "query_tensors": a list of 1-D token-id tensors, one
    per prompt, left unpadded. PPOTrainer.generate()/step() take lists of
    variable-length tensors and handle padding themselves; padding here would
    put pad tokens inside the prompts the policy continues from.
    """
    prompts = [e["prompt"] for e in examples]
    input_ids = tokenizer(prompts)["input_ids"]
    return {"query_tensors": [torch.tensor(ids) for ids in input_ids]}


# `dataset` is assumed to be a prompt dataset whose rows have a "prompt"
# field (e.g. built from HH-RLHF) -- TODO: define it for your data.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,  # a short final batch would fail step()'s batch-size check
)
num_epochs = 1  # passes over the prompt dataset

# PPO must sample from the policy (not decode greedily) to explore.
generation_kwargs = {
    "do_sample": True,
    "top_k": 0,
    "top_p": 1.0,
    "max_new_tokens": 128,
    "pad_token_id": tokenizer.eos_token_id,
}

# Training loop
for epoch in range(num_epochs):
    for batch in dataloader:
        query_tensors = [q.to(device) for q in batch["query_tensors"]]

        # Generate responses (return_prompt=False keeps only the new tokens)
        response_tensors = ppo_trainer.generate(
            query_tensors,
            return_prompt=False,
            **generation_kwargs,
        )

        # Compute rewards
        queries = tokenizer.batch_decode(query_tensors, skip_special_tokens=True)
        responses = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
        # Plain concatenation is only correct if each prompt already ends with
        # the assistant-turn marker the reward model was trained on.
        reward_inputs = tokenizer(
            [q + r for q, r in zip(queries, responses)],
            padding=True,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            scores = reward_model(**reward_inputs).logits.squeeze(-1)
        # step() expects a list holding one scalar tensor per sample.
        rewards = list(scores.float().cpu())

        # PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        print(
            f"epoch {epoch}: mean reward {scores.mean().item():.3f}, "
            f"KL {float(stats['objective/kl']):.3f}"
        )
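Memory optimization: load the reward model in 8-bit. The reward model is only used for inference and is never updated, so quantizing it saves GPU memory, usually with only a small effect on its scores.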
# For large models, load the (frozen) reward model in 8-bit to reduce memory.
# Requires the bitsandbytes package, typically with a CUDA GPU.
from transformers import AutoModelForSequenceClassification, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,  # LLM.int8() outlier threshold; 6.0 is the default
)
# Same model class as in full precision: a single-output sequence classifier.
reward_model = AutoModelForSequenceClassification.from_pretrained(
    "your-reward-model",
    num_labels=1,
    quantization_config=quantization_config,
    device_map="auto",
)
# Needed for batched scoring (see the pad-token note in the main example).
reward_model.config.pad_token_id = tokenizer.pad_token_id
# device_map="auto" has already placed the weights; many transformers versions
# reject .to() on 8-bit models, so move inputs to reward_model.device instead.
Common failure: reward model format mismatch. A reward model is trained on (query, response) text in one specific format. If you join the query and response with a different separator, or omit special tokens it saw during training, its scores will be meaningless.
# Debug: verify reward model input format
# Use exactly the template your reward model was trained on. HH-RLHF
# transcripts use "\n\nHuman: ...\n\nAssistant: ...".
test_input = "\n\nHuman: What is Python?\n\nAssistant: Python is a programming language."
test_ids = tokenizer(test_input, return_tensors="pt").to(reward_model.device)
print("Input length:", test_ids["input_ids"].shape[-1])
# Inspect the tokens to confirm BOS/EOS and separators look as expected.
print("Tokens:", tokenizer.convert_ids_to_tokens(test_ids["input_ids"][0]))
# Check reward range (assumes a single-output sequence-classification model)
with torch.no_grad():
    test_reward = reward_model(**test_ids).logits.squeeze()
print("Test reward:", test_reward.item())
# If reward is NaN or near-zero, check tokenization and model loading
EXERCISE
Run a complete PPO training job on a small model (1–3B parameters) using the HH-RLHF dataset. Monitor and log KL divergence over time, changes in the reward distribution, and changes in policy perplexity. Identify the point of diminishing returns (where additional training stops improving reward) and document the total compute required.