19. ORPO
Chapter 19 of 24 · 20 min
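Odds Ratio Preference Optimization (ORPO) is a newer alignment technique that combines supervised fine-tuning (SFT) and preference alignment into a single training stage. The model is anchored by an SFT loss on the preferred responses rather than by a frozen copy of itself, so ORPO needs no separate reference model.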
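Written out, the objective from the original ORPO paper is a standard SFT loss plus a weighted odds-ratio term. Here $x$ is the prompt, $y_w$ and $y_l$ are the chosen and rejected responses, $m$ is the response length in tokens, $P_\theta(y \mid x)$ is the model's length-normalized (average per-token) likelihood, $\sigma$ is the logistic sigmoid, and $\lambda$ weights the preference term:

$$
P_\theta(y \mid x) = \exp\!\Big(\frac{1}{m}\sum_{t=1}^{m} \log P_\theta(y_t \mid x, y_{<t})\Big), \qquad
\mathrm{odds}_\theta(y \mid x) = \frac{P_\theta(y \mid x)}{1 - P_\theta(y \mid x)}
$$

$$
\mathcal{L}_{\mathrm{OR}} = -\log \sigma\!\left(\log \frac{\mathrm{odds}_\theta(y_w \mid x)}{\mathrm{odds}_\theta(y_l \mid x)}\right), \qquad
\mathcal{L}_{\mathrm{SFT}} = -\log P_\theta(y_w \mid x)
$$

$$
\mathcal{L}_{\mathrm{ORPO}} = \mathbb{E}_{(x,\,y_w,\,y_l)}\big[\mathcal{L}_{\mathrm{SFT}} + \lambda\,\mathcal{L}_{\mathrm{OR}}\big]
$$

Since $\sigma(\log r) = r / (1 + r)$, the odds-ratio term equals $-\log\big(\mathrm{OR} / (1 + \mathrm{OR})\big)$, where $\mathrm{OR}$ is the ratio of the two odds.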
The Odds Ratio Foundation
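ORPO measures preference with an odds ratio. If the model assigns a response probability p, its odds are p / (1 - p). The odds ratio divides the odds of the chosen response by the odds of the rejected one. ORPO takes p from the average per-token log-probability of the response, not from the full-sequence log-probability: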
```python
import torch
import torch.nn.functional as F

def log_odds(avg_logps):
    """
    log(p / (1 - p)) computed from log p, where p = exp(avg_logps).
    avg_logps are average per-token log-probs, so they are < 0 and p is in (0, 1).
    """
    # log(1 - p) computed as log1p(-exp(log p)), with no ad hoc epsilon terms
    return avg_logps - torch.log1p(-torch.exp(avg_logps))

def compute_odds_ratio(logps_chosen, logps_rejected):
    """
    Odds ratio odds(chosen) / odds(rejected).
    Odds ratio > 1 means the model favors the chosen response.
    """
    return torch.exp(log_odds(logps_chosen) - log_odds(logps_rejected))

def odds_ratio_loss(logps_chosen, logps_rejected, beta=0.1):
    """
    Per-example ORPO loss: SFT (NLL) loss on the chosen response
    plus beta times the odds-ratio loss. beta is the weight called
    lambda in the ORPO paper.
    """
    # Preference term: -log sigmoid(log OR), i.e. -log(OR / (1 + OR)),
    # evaluated in log space for numerical stability
    log_or = log_odds(logps_chosen) - log_odds(logps_rejected)
    preference_loss = -F.logsigmoid(log_or)
    # SFT term: negative log-likelihood of the chosen response.
    # There is no KL term and no reference model. This term keeps the model
    # fitting good responses while the preference term pushes rejected ones down.
    sft_loss = -logps_chosen
    return sft_loss + beta * preference_loss
```
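A small worked example with plain floats (no PyTorch) makes the arithmetic concrete. It assumes the inputs are average per-token log-probabilities. The probabilities 0.5 and 0.2 are made-up values, λ = 0.1 is an arbitrary choice, and the helper names `log_odds` and `odds_ratio_term` are illustrative, not taken from any library.

```python
import math

def log_odds(avg_logp):
    # log(p / (1 - p)) from log p; valid for avg_logp < 0
    return avg_logp - math.log1p(-math.exp(avg_logp))

def odds_ratio_term(avg_logp_chosen, avg_logp_rejected):
    # -log sigmoid(log OR), written as log(1 + exp(-log OR))
    log_or = log_odds(avg_logp_chosen) - log_odds(avg_logp_rejected)
    return math.log1p(math.exp(-log_or))

logp_chosen = math.log(0.5)    # chosen: p = 0.5, odds = 0.5 / 0.5 = 1.0
logp_rejected = math.log(0.2)  # rejected: p = 0.2, odds = 0.2 / 0.8 = 0.25
odds_ratio = math.exp(log_odds(logp_chosen) - log_odds(logp_rejected))  # 1.0 / 0.25 = 4.0
or_term = odds_ratio_term(logp_chosen, logp_rejected)  # -log(4 / 5) ≈ 0.2231
total = -logp_chosen + 0.1 * or_term                   # 0.6931 + 0.0223 ≈ 0.7155
print(f"OR = {odds_ratio:.2f}, odds-ratio term = {or_term:.4f}, total = {total:.4f}")
```

The SFT term (0.6931) dominates here. With λ = 0.1, the odds-ratio term only adjusts the gradient, and it shrinks as the odds ratio grows.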
ORPO Training Implementation
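The step below assumes a Hugging Face-style causal language model whose forward pass returns `.logits`. It also assumes each prompt and response in the batch are already tokenized and concatenated, with labels set to -100 on prompt and padding tokens. The batch key names are illustrative.

```python
import torch
import torch.nn.functional as F

def average_logps(model, input_ids, attention_mask, labels):
    """Mean per-token log-prob of the response; labels are -100 on prompt and padding."""
    logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # Position t predicts token t + 1, so shift logits and labels against each other
    logits, labels = logits[:, :-1, :], labels[:, 1:]
    mask = labels != -100
    token_logps = torch.gather(
        F.log_softmax(logits, dim=-1), 2, labels.clamp(min=0).unsqueeze(-1)
    ).squeeze(-1)
    return (token_logps * mask).sum(-1) / mask.sum(-1).clamp(min=1)

def orpo_training_step(model, batch, optimizer):
    """Single ORPO training step. Note that no reference model appears anywhere."""
    logps_chosen = average_logps(model, batch["chosen_input_ids"],
                                 batch["chosen_attention_mask"], batch["chosen_labels"])
    logps_rejected = average_logps(model, batch["rejected_input_ids"],
                                   batch["rejected_attention_mask"], batch["rejected_labels"])
    # Per-example ORPO losses, averaged over the batch
    loss = odds_ratio_loss(logps_chosen, logps_rejected).mean()
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Metrics (no gradients needed)
    with torch.no_grad():
        accuracy = (logps_chosen > logps_rejected).float().mean()
        odds_ratio = compute_odds_ratio(logps_chosen, logps_rejected).mean()
    return {"loss": loss.item(), "accuracy": accuracy.item(), "odds_ratio": odds_ratio.item()}
```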
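Odds only behave differently from plain probabilities when the probability is not vanishingly small. That is one practical reason to feed ORPO length-normalized log-probs rather than summed ones. The numbers below are hypothetical: a 200-token response with an average per-token log-prob of -1.0.

```python
import math

def log_odds(logp):
    # log(p / (1 - p)) computed from log p (requires logp < 0)
    return logp - math.log1p(-math.exp(logp))

num_tokens = 200   # hypothetical response length
avg_logp = -1.0    # hypothetical average per-token log-prob
summed_logp = num_tokens * avg_logp  # full-sequence log-prob: -200.0

# Full sequence: p is about 1.4e-87, so log-odds and log p differ by only about p.
# The odds transform then adds nothing, and the score mostly tracks response length.
p_seq = math.exp(summed_logp)
gap_seq = -math.log1p(-p_seq)   # log-odds minus log p

# Length-normalized: p = exp(-1) ≈ 0.368, and log-odds ≈ -0.541 versus log p = -1.0
p_avg = math.exp(avg_logp)
gap_avg = -math.log1p(-p_avg)   # ≈ 0.459 nats

print(f"sequence-level gap: {gap_seq:.3g}, average-level gap: {gap_avg:.3f}")
```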
Why ORPO Eliminates the Reference Model
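DPO needs a frozen reference model, usually the SFT checkpoint, because its implicit reward is the log-ratio between the policy and that reference. This is how DPO enforces its KL constraint toward the starting model. ORPO has no such constraint to enforce. Its NLL term keeps the model fitting the chosen responses, which is the job a separate SFT stage would otherwise do. Its odds-ratio term lowers the rejected responses relative to the chosen ones. Both terms use only the policy's own log-probabilities:

```python
import torch
import torch.nn.functional as F

# DPO: four log-probs per pair (policy and frozen reference, chosen and rejected)
def dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    # The reference enters through policy/reference log-ratios (the implicit KL constraint)
    chosen_margin = policy_chosen - ref_chosen
    rejected_margin = policy_rejected - ref_rejected
    return -F.logsigmoid(beta * (chosen_margin - rejected_margin))

# ORPO: two log-probs per pair, both from the policy being trained
def orpo_loss_terms(logps_chosen, logps_rejected):
    # NLL on the chosen response anchors the model to good outputs
    sft_loss = -logps_chosen
    # Odds-ratio term lowers the rejected response relative to the chosen one
    log_or = (logps_chosen - torch.log1p(-torch.exp(logps_chosen))) - (
        logps_rejected - torch.log1p(-torch.exp(logps_rejected)))
    return sft_loss, -F.logsigmoid(log_or)
```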
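A toy run shows that everything the ORPO loss needs comes from the model being trained. The sketch below assumes a hypothetical three-token "vocabulary" whose logits are the only parameters. It treats token 0 as the chosen response and token 1 as the rejected one, and uses finite-difference gradients so that no autodiff library is needed. λ = 0.1, the learning rate, and the step count are arbitrary choices.

```python
import math

def log_softmax(z):
    # Numerically stable log-softmax over a list of logits
    m = max(z)
    lse = m + math.log(sum(math.exp(v - m) for v in z))
    return [v - lse for v in z]

def log_odds(logp):
    return logp - math.log1p(-math.exp(logp))

def orpo_toy_loss(z, chosen=0, rejected=1, lam=0.1):
    logp = log_softmax(z)
    nll = -logp[chosen]                                   # SFT term
    log_or = log_odds(logp[chosen]) - log_odds(logp[rejected])
    return nll + lam * math.log1p(math.exp(-log_or))      # + lam * (-log sigmoid(log OR))

def numeric_grad(f, z, eps=1e-6):
    # Central finite differences, one coordinate at a time
    grad = []
    for i in range(len(z)):
        up, down = list(z), list(z)
        up[i] += eps
        down[i] -= eps
        grad.append((f(up) - f(down)) / (2 * eps))
    return grad

z = [0.0, 0.0, 0.0]  # uniform start: every token has p = 1/3
p_start = [math.exp(v) for v in log_softmax(z)]
for _ in range(100):
    g = numeric_grad(orpo_toy_loss, z)
    z = [v - 0.5 * gv for v, gv in zip(z, g)]
p_end = [math.exp(v) for v in log_softmax(z)]
print([round(p, 3) for p in p_start], "->", [round(p, 3) for p in p_end])
```

The chosen token's probability rises, and the rejected token ends below the unrelated token 2. The NLL term alone would lower tokens 1 and 2 equally, so that gap comes from the odds-ratio term.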
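Comparison with DPO
| Aspect | DPO | ORPO |
|---|---|---|
| Reference model | Required (frozen, usually the SFT checkpoint) | Not needed |
| Training stages | Usually two (SFT, then DPO) | One (SFT and preference terms in a single loss) |
| Memory | Policy plus frozen reference weights (no gradients or optimizer state for the reference) | Policy only |
| Forward passes per preference pair | 4 (policy and reference on chosen and rejected; 2 if reference log-probs are precomputed) | 2 (policy on chosen and rejected) |
| Main hyperparameter | β, strength of the implicit KL constraint | λ, weight of the odds-ratio term |
| Convergence speed, hyperparameter sensitivity | Task-dependent; compare empirically (see the exercise) | Task-dependent; compare empirically (see the exercise) |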
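The memory cost depends on what the reference model actually costs. The sketch below is a back-of-the-envelope estimate. It assumes mixed-precision Adam, using the common accounting of 16 bytes per trainable parameter, and a frozen reference stored in bf16. It ignores activations and framework overhead, which can be substantial. The 7B parameter count is hypothetical.

```python
def training_memory_gb(num_params, with_reference):
    """Rough weight + optimizer memory in GB, ignoring activations and overhead.

    Assumes 16 bytes per trainable parameter (2 bf16 weights + 2 bf16 grads
    + 12 for fp32 master weights, Adam momentum, and variance) and
    2 bytes per parameter for a frozen bf16 reference model.
    """
    policy_bytes = 16 * num_params
    reference_bytes = 2 * num_params if with_reference else 0
    return (policy_bytes + reference_bytes) / 1e9

n = 7e9  # hypothetical 7B-parameter model
dpo = training_memory_gb(n, with_reference=True)    # 112 + 14 = 126 GB
orpo = training_memory_gb(n, with_reference=False)  # 112 GB
print(f"DPO: {dpo:.0f} GB, ORPO: {orpo:.0f} GB, ratio {dpo / orpo:.3f}")
```

Under these assumptions, the frozen reference adds about 12.5% to weight and optimizer memory rather than doubling it. It still costs extra forward passes at every step unless its log-probs are precomputed.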
EXERCISE
Implement ORPO training and compare it to DPO on a small alignment task. Measure training time, memory usage, and final preference accuracy.
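A minimal measurement harness for the exercise might look like the sketch below. `measure` and `preference_accuracy` are hypothetical helpers. `tracemalloc` only sees Python-heap allocations. For GPU memory in PyTorch, call `torch.cuda.reset_peak_memory_stats()` before training and `torch.cuda.max_memory_allocated()` after. The scores passed to `preference_accuracy` are assumed to be something like each response's average log-prob on a held-out set.

```python
import time
import tracemalloc

def measure(train_fn, *args, **kwargs):
    """Run train_fn once; return (result, wall-clock seconds, peak Python-heap bytes)."""
    tracemalloc.start()
    start = time.perf_counter()
    try:
        result = train_fn(*args, **kwargs)
    finally:
        elapsed = time.perf_counter() - start
        _, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
    return result, elapsed, peak

def preference_accuracy(chosen_scores, rejected_scores):
    """Fraction of pairs where the model scores the chosen response higher."""
    pairs = list(zip(chosen_scores, rejected_scores))
    return sum(c > r for c, r in pairs) / len(pairs)

# Usage with a stand-in "training" function and made-up held-out scores
result, seconds, peak_bytes = measure(lambda n: sum(range(n)), 1_000_000)
print(result, f"{seconds:.3f}s", peak_bytes)
print(preference_accuracy([-0.9, -1.2, -0.4], [-1.5, -1.0, -2.0]))  # 2 of 3 pairs
```

Running the same harness around a DPO run and an ORPO run on identical data and hardware gives directly comparable numbers.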