18. RRHF and IPO
Chapter 18 of 24 · 20 min
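RRHF (from "Rank Responses to Align Language Models with Human Feedback") and IPO (Identity Preference Optimization) are alternatives to PPO-based RLHF. Neither runs PPO's online reinforcement-learning loop or trains a value model, yet both still learn from preference data. RRHF uses a ranking loss over reward-scored candidate responses, and IPO uses a pairwise loss measured against a frozen reference model.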
RRHF: Score-Based Ranking
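RRHF samples several candidate responses per prompt (from the policy itself, other models, or human-written answers) and scores each one with a reward model. The policy gives each response its own score, the length-normalized log-probability of that response, and is trained so that these scores follow the reward model's ranking. The loss has two parts. A hinge ranking loss penalizes every pair where a lower-reward response receives a higher policy score than a higher-reward one. A supervised cross-entropy term on the highest-reward response keeps the model generating good outputs: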
```python
import torch
import torch.nn.functional as F

def rrhf_loss(logps, lengths, rewards):
    """
    RRHF loss for k candidate responses to one prompt.

    logps:   (k,) summed token log-probs of each response under the policy
    lengths: (k,) number of tokens in each response
    rewards: (k,) reward-model scores, e.g. reward_model(prompt, response)
    """
    # Policy score p_i: length-normalized log-probability of response i
    scores = logps / lengths
    # worse[i, j] is True when response i has a lower reward than response j
    worse = rewards.unsqueeze(1) < rewards.unsqueeze(0)
    # diffs[i, j] = p_i - p_j
    diffs = scores.unsqueeze(1) - scores.unsqueeze(0)
    # Hinge ranking loss: penalize p_i > p_j whenever r_i < r_j
    rank_loss = F.relu(diffs[worse]).sum()
    # SFT term: negative log-likelihood of the highest-reward response
    best = torch.argmax(rewards)
    sft_loss = -logps[best]
    return rank_loss + sft_loss
```
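RRHF's policy score for a response is its length-normalized log-probability. As a dependency-free sketch of how that score is obtained (the names log_softmax and rrhf_score are illustrative, and plain Python stands in for PyTorch), the helper below takes the model's vocabulary logits at each response position plus the generated token IDs, sums the token log-probabilities, and divides by the response length:

```python
import math

def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def rrhf_score(step_logits, token_ids):
    """Length-normalized log-probability of one response (illustrative helper).

    step_logits: one list of vocabulary logits per response position, i.e. the
                 model's prediction given the prompt and the earlier response tokens
    token_ids:   the response token actually produced at each position
    """
    total = 0.0
    for logits, tok in zip(step_logits, token_ids):
        total += log_softmax(logits)[tok]
    # Divide by length so short responses are not favored automatically
    return total / len(token_ids)

# Toy 3-token vocabulary: uniform logits give every token probability 1/3
print(rrhf_score([[0.0, 0.0, 0.0]] * 4, [0, 1, 2, 0]))
```

Dividing by the length matters because summed log-probabilities are sums of negative numbers, so shorter responses would otherwise score higher simply for having fewer tokens.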
IPO: Regularized Preference Optimization
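IPO targets a failure mode of DPO. When a preference is (nearly) deterministic, DPO's log-sigmoid loss keeps decreasing as the log-likelihood-ratio margin between the chosen and rejected responses grows, so training can push that margin toward infinity and drift far from the reference model, weakening the KL regularization that β is meant to provide. IPO replaces the log-sigmoid with a squared loss that regresses the margin toward a finite target, so the pull toward the reference model stays in force.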
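Written out per preference pair, with π the policy, π_ref the frozen reference, y_w and y_l the chosen and rejected responses to prompt x, β the DPO temperature, and τ the IPO regularization parameter, both losses act on the same margin h:

$$h_\pi(x, y_w, y_l) = \log\frac{\pi(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \log\frac{\pi(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}$$

$$\mathcal{L}_{\mathrm{DPO}} = -\log \sigma\big(\beta\, h_\pi\big), \qquad \mathcal{L}_{\mathrm{IPO}} = \Big(h_\pi - \frac{1}{2\tau}\Big)^2$$

The DPO loss approaches zero only as h_π → ∞, whereas the IPO loss is minimized at the finite value h_π = 1/(2τ).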
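A minimal implementation, written against precomputed sequence log-probabilities (each response's summed token log-probabilities under the trainable policy and under a frozen reference model), with DPO included for comparison:

```python
import torch
import torch.nn.functional as F

def ipo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """
    Identity Preference Optimization.
    beta plays the role of tau in the IPO objective.
    """
    # h = log[pi(y_w)/pi_ref(y_w)] - log[pi(y_l)/pi_ref(y_l)]
    h = (policy_chosen_logps - ref_chosen_logps) - (policy_rejected_logps - ref_rejected_logps)
    # Squared loss pulls h toward the finite target 1 / (2 * beta)
    loss = (h - 1.0 / (2.0 * beta)) ** 2
    return loss.mean()

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """
    Direct Preference Optimization (DPO), for comparison.
    """
    policy_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    # Implicit regularization: the policy is scored only relative to the reference.
    # -log sigmoid(beta * margin); F.logsigmoid is numerically stable
    loss = -F.logsigmoid(beta * (policy_logratios - ref_logratios))
    return loss.mean()
```
Comparing DPO and IPO Convergence
The comparison trains two independent copies of the base model, one per loss, on identical batches and records the KL divergence from the frozen reference after every update. load_base_model, sample_preference_batch, sequence_log_probs, compute_kl_divergence, and plot_convergence are placeholder helpers, not library functions.
```python
import copy
import torch

def compare_convergence(num_steps=1000, lr=1e-6, beta=0.1):
    """Compare DPO vs IPO by tracking KL divergence from the reference."""
    # load_base_model, sample_preference_batch, sequence_log_probs,
    # compute_kl_divergence and plot_convergence are hypothetical helpers
    base = load_base_model()
    # Frozen reference policy shared by both methods
    reference = copy.deepcopy(base).eval()
    for p in reference.parameters():
        p.requires_grad_(False)
    # Independent copies so the two updates do not interfere
    runs = {
        "dpo": (copy.deepcopy(base), dpo_loss),
        "ipo": (copy.deepcopy(base), ipo_loss),
    }
    optimizers = {name: torch.optim.AdamW(m.parameters(), lr=lr) for name, (m, _) in runs.items()}
    divergences = {name: [] for name in runs}
    for step in range(num_steps):
        batch = sample_preference_batch()  # identical data for both methods
        with torch.no_grad():
            ref_chosen, ref_rejected = sequence_log_probs(reference, batch)
        for name, (model, loss_fn) in runs.items():
            pol_chosen, pol_rejected = sequence_log_probs(model, batch)
            loss = loss_fn(pol_chosen, pol_rejected, ref_chosen, ref_rejected, beta=beta)
            optimizers[name].zero_grad()
            loss.backward()
            optimizers[name].step()
            # Track divergence from the reference after each update
            divergences[name].append(compute_kl_divergence(model, reference))
    plot_convergence(divergences["dpo"], divergences["ipo"])
```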
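The divergence gap between the two methods follows from the shape of each loss as a function of the margin h, the policy's chosen-minus-rejected log-probability ratio measured relative to the reference. The dependency-free sketch below evaluates both per-pair losses at a few margins; the function names are illustrative, and beta stands in for IPO's τ as well as DPO's β:

```python
import math

def softplus(z):
    # Numerically stable log(1 + exp(z))
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def dpo_loss_at(h, beta=0.1):
    # -log(sigmoid(beta * h)) == log(1 + exp(-beta * h))
    return softplus(-beta * h)

def ipo_loss_at(h, beta=0.1):
    # Squared distance of the margin from the IPO target 1 / (2 * beta)
    return (h - 1.0 / (2.0 * beta)) ** 2

for h in [0.0, 5.0, 10.0, 20.0, 50.0]:
    print(f"h={h:5.1f}  DPO={dpo_loss_at(h):.4f}  IPO={ipo_loss_at(h):.2f}")
```

With beta = 0.1 the IPO target is 1 / (2 × 0.1) = 5, so the IPO loss rises again once h passes 5. The DPO loss keeps shrinking toward zero as h grows, so under deterministic preferences it always has a gradient that widens the margin further.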
EXERCISE
Implement both DPO and IPO loss functions and compare their behavior on a small dataset. Measure both preference accuracy and KL divergence from the reference model.
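One possible way to compute the two metrics, sketched in plain Python over precomputed sequence log-probabilities (the helper names are illustrative). Preference accuracy is taken here as the fraction of held-out pairs whose log-ratio margin over the reference is positive, meaning the implicit reward of the chosen response exceeds that of the rejected one. KL divergence is estimated by Monte Carlo from responses sampled from the policy, averaging log π(y) − log π_ref(y):

```python
def preference_accuracy(pol_chosen, pol_rejected, ref_chosen, ref_rejected):
    # Margin h for each pair: chosen log-ratio minus rejected log-ratio
    margins = [(pc - rc) - (pr - rr)
               for pc, pr, rc, rr in zip(pol_chosen, pol_rejected, ref_chosen, ref_rejected)]
    # Fraction of pairs the policy ranks correctly relative to the reference
    return sum(m > 0 for m in margins) / len(margins)

def kl_estimate(policy_logps, ref_logps):
    """Monte Carlo estimate of KL(pi || pi_ref).

    Both lists hold log-probs of the SAME responses, which must be sampled
    from the policy pi; single-sample terms can be negative, the mean cannot
    be in expectation.
    """
    return sum(p - r for p, r in zip(policy_logps, ref_logps)) / len(policy_logps)

print(preference_accuracy([-1.0, -2.0], [-3.0, -1.0], [-2.0, -2.0], [-2.0, -2.0]))
print(kl_estimate([-1.0, -2.0], [-2.0, -2.0]))
```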