04. DPO Implementation with TRL

Chapter 4 of 24 · 15 min

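The Hugging Face TRL library provides a clean implementation of DPO training. Its main class is DPOTrainer, which tokenizes the preference pairs, manages the frozen reference model, computes the DPO loss, and runs the gradient updates. Converting a raw dataset into the prompt, chosen, and rejected columns it expects is usually your job, as in the example below.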

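from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

# Load a model that has been through SFT, keeping its checkpoint dtype (often bfloat16)
model = AutoModelForCausalLM.from_pretrained("your-sft-model", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("your-sft-model")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # batching needs a pad token

# Raw hh-rlhf rows hold two full transcripts ("chosen" and "rejected");
# DPOTrainer expects separate "prompt", "chosen" and "rejected" columns
dataset = load_dataset("Anthropic/hh-rlhf", split="train")

ASSISTANT_MARKER = "\n\nAssistant:"

# Split each transcript at its LAST assistant turn: everything up to and including
# the marker is the prompt (the full multi-turn context), the rest is the response
def format_preference(example):
    chosen, rejected = example["chosen"], example["rejected"]
    chosen_cut = chosen.rfind(ASSISTANT_MARKER) + len(ASSISTANT_MARKER)
    rejected_cut = rejected.rfind(ASSISTANT_MARKER) + len(ASSISTANT_MARKER)
    return {
        "prompt": chosen[:chosen_cut],
        "chosen": chosen[chosen_cut:],
        "rejected": rejected[rejected_cut:],
    }

dataset = dataset.map(format_preference, remove_columns=dataset.column_names)

# DPOConfig extends TrainingArguments with DPO-specific fields such as beta;
# plain TrainingArguments does not accept beta
training_args = DPOConfig(
    output_dir="./dpo_output",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=1e-6,
    num_train_epochs=3,
    beta=0.1,  # strength of the implicit KL penalty toward the reference model
    remove_unused_columns=False,
    optim="adamw_torch",
)

# No ref_model is passed, so DPOTrainer creates a frozen copy of `model` as the reference
dpo_trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,  # older TRL releases call this argument `tokenizer`
)

dpo_trainer.train()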
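The split point matters because many hh-rlhf transcripts are multi-turn: splitting at the first "Human: " keeps only the opening question and throws away the rest of the context. Below is a small pure-Python sketch of splitting at the last assistant marker, applied to a made-up two-turn transcript in the hh-rlhf layout (the helper name and the strings are illustrative, not part of TRL).

```python
ASSISTANT_MARKER = "\n\nAssistant:"

def split_at_last_assistant_turn(transcript: str) -> tuple[str, str]:
    """Return (prompt, response), splitting just after the final assistant marker."""
    cut = transcript.rfind(ASSISTANT_MARKER)
    if cut == -1:
        raise ValueError("transcript has no assistant turn")
    cut += len(ASSISTANT_MARKER)
    return transcript[:cut], transcript[cut:]

# Toy transcripts: a shared multi-turn prefix, different final replies
chosen = "\n\nHuman: Hi\n\nAssistant: Hello!\n\nHuman: Name a prime.\n\nAssistant: 7 is prime."
rejected = "\n\nHuman: Hi\n\nAssistant: Hello!\n\nHuman: Name a prime.\n\nAssistant: 9 is prime."

prompt, chosen_reply = split_at_last_assistant_turn(chosen)
_, rejected_reply = split_at_last_assistant_turn(rejected)

print(repr(prompt))          # keeps both human turns and the first assistant reply
print(repr(chosen_reply))    # ' 7 is prime.'
print(repr(rejected_reply))  # ' 9 is prime.'
```

Raising on a missing marker is a deliberate choice: silently producing a wrong split is exactly the kind of formatting bug that later shows up only as a loss that refuses to move.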

The beta parameter controls the KL penalty strength. Higher values make the policy stay closer to the reference model; lower values allow more deviation in exchange for potentially higher rewards.
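Concretely, the sigmoid DPO loss scales the difference between the policy-versus-reference log-probability ratios of the chosen and rejected responses by beta before applying the sigmoid. The sketch below computes that per-pair loss from summed sequence log-probabilities in pure Python; the function name and the numbers are illustrative, and it is not TRL's internal code.

```python
import math

def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    if x > 0:
        return x + math.log1p(math.exp(-x))
    return math.log1p(math.exp(x))

def dpo_pair_loss(policy_chosen_logp: float, policy_rejected_logp: float,
                  ref_chosen_logp: float, ref_rejected_logp: float,
                  beta: float = 0.1) -> float:
    """Sigmoid DPO loss -log(sigmoid(margin)) for one preference pair."""
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    return softplus(-margin)  # -log(sigmoid(m)) == log(1 + exp(-m))

# The policy has shifted 5 nats toward the chosen response relative to the reference
print(dpo_pair_loss(-10.0, -20.0, -12.0, -17.0, beta=0.1))  # ≈ 0.474
print(dpo_pair_loss(-10.0, -20.0, -12.0, -17.0, beta=0.5))  # ≈ 0.079
```

With beta=0.5 the same log-ratio gap already yields a small loss, so the optimizer has little incentive to move the policy further from the reference; with beta=0.1 it keeps pushing. That is the mechanism behind "higher beta keeps the policy closer to the reference."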

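Common failure mode: GPU memory for the reference model. Unless you train a PEFT adapter or precompute the reference log-probabilities, DPOTrainer holds a second, frozen copy of the model as the reference. The reference only runs forward passes without gradient tracking, so it needs no gradients or optimizer state, but its weights alone are substantial: an 8B-parameter model takes about 16 GB in bfloat16 and about 32 GB in float32, before activations. Load the model in bfloat16 (or its native dtype) rather than float32, and if memory is still tight, set precompute_ref_log_probs=True in DPOConfig or train a PEFT adapter, in which case TRL uses the base model with the adapter disabled as the reference.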
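A back-of-the-envelope estimate of the reference model's weight memory is just parameter count times bytes per parameter. The sketch below (helper name and dtype table are illustrative) covers weights only; activations, the CUDA context, and allocator overhead come on top.

```python
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}

def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Memory for the weights alone, in decimal GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

for dtype in ("float32", "bfloat16"):
    print(dtype, weight_memory_gb(8e9, dtype), "GB")
# float32 32.0 GB
# bfloat16 16.0 GB
```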

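# Debugging: confirm the reference model sits on the same device as the policy
if dpo_trainer.ref_model is not None:  # None with PEFT or precompute_ref_log_probs=True
    print(dpo_trainer.ref_model.device)  # Expect cuda:0 on a single-GPU run
print(dpo_trainer.model.device)          # Expect cuda:0 on a single-GPU run
# A reference model left on the CPU makes every step very slow, and a device mismatch
# usually fails with a "tensors on different devices" error in the reference forward pass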

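Another common issue is data-formatting bugs. DPOTrainer tokenizes the prompt once and appends the separately tokenized chosen and rejected responses to it, so the prompt must end exactly where both responses begin and must not be repeated inside them. A formatter that drops conversation context, leaves a response empty, or leaks the prompt into the responses rarely raises an error; instead it tends to surface as NaN losses or a loss that never moves.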
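Cheap checks on the formatted rows catch most of these bugs before any GPU time is spent. The sketch below is hypothetical (not a TRL utility) and assumes the hh-rlhf layout, where every prompt should end with "\n\nAssistant:"; adjust the marker for other datasets.

```python
ASSISTANT_MARKER = "\n\nAssistant:"

def find_formatting_problems(example: dict) -> list[str]:
    """Return problems found in one formatted preference row (empty list means OK)."""
    problems = []
    for key in ("prompt", "chosen", "rejected"):
        if not isinstance(example.get(key), str) or not example[key].strip():
            problems.append(f"{key} is missing or empty")
    if problems:
        return problems
    if not example["prompt"].endswith(ASSISTANT_MARKER):
        problems.append("prompt does not end at an assistant turn")
    for key in ("chosen", "rejected"):
        # A response containing a new human turn means the split kept conversation context
        if "\n\nHuman:" in example[key]:
            problems.append(f"{key} still contains conversation context")
    if example["chosen"] == example["rejected"]:
        problems.append("chosen and rejected are identical")
    return problems

bad = {"prompt": "Hi", "chosen": " Sure.", "rejected": " Sure."}
print(find_formatting_problems(bad))
# ['prompt does not end at an assistant turn', 'chosen and rejected are identical']
```

On a Hugging Face dataset you could, for example, inspect a sample of rows with this function or drop offenders with dataset.filter(lambda ex: not find_formatting_problems(ex)).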

EXERCISE

Take a small dataset (10-20 examples), implement DPO training with DPOTrainer, and monitor the loss curve. Then intentionally corrupt the rejected response to be identical to the chosen response and observe what happens to the loss. Document the loss behavior and what it implies about gradient magnitudes when the training signal is ambiguous.
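One way to set up the corrupted run is to build it from the clean rows without mutating them, so both datasets can be trained and compared side by side. The helper and the toy rows below are hypothetical; either list can be turned into a training set with datasets.Dataset.from_list before it is passed to DPOTrainer.

```python
def corrupt_rejected(examples: list[dict]) -> list[dict]:
    """Copy each preference row with its rejected response replaced by the chosen one."""
    return [{**ex, "rejected": ex["chosen"]} for ex in examples]

small = [
    {"prompt": "\n\nHuman: Name a prime.\n\nAssistant:", "chosen": " 7.", "rejected": " 9."},
    {"prompt": "\n\nHuman: Capital of France?\n\nAssistant:", "chosen": " Paris.", "rejected": " Lyon."},
]
corrupted = corrupt_rejected(small)

print(repr(corrupted[0]["rejected"]))  # ' 7.'
print(repr(small[0]["rejected"]))      # ' 9.' (the clean rows are unchanged)
```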