09. Distillation of Reasoning
Distilling reasoning capabilities into smaller, faster models enables cost-effective deployment. The challenge is preserving reasoning quality while reducing model size. R1's training methodology offers insights into effective distillation approaches.
Why Reasoning Distillation is Hard
Standard model distillation transfers capabilities through behavioral cloning: the smaller model learns to match the larger model's outputs. Reasoning is harder to distill because the reasoning chain itself matters, not just the final answer.
A model that matches R1's answers but not its reasoning may fail when presented with novel problems. The reasoning process is the capability; the answers are just a symptom.
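To make the distinction concrete, here is a minimal sketch of the two kinds of supervised targets a student could be trained on. The `TeacherTrace` type and the prompt/target dictionary layout are illustrative assumptions. The `<think>` tag convention is also assumed here, chosen to mirror the tagged form in which R1 emits its reasoning.

```python
from dataclasses import dataclass


@dataclass
class TeacherTrace:
    problem: str
    reasoning: str  # the teacher's chain of thought
    answer: str     # the teacher's final answer


def answer_only_example(trace: TeacherTrace) -> dict:
    # Behavioral cloning on outputs alone: the student never sees how the answer was reached
    return {"prompt": trace.problem, "target": trace.answer}


def reasoning_example(trace: TeacherTrace) -> dict:
    # Reasoning distillation: the full chain is part of the supervised target
    # (the <think> tags are an assumed formatting convention)
    target = f"<think>\n{trace.reasoning}\n</think>\n{trace.answer}"
    return {"prompt": trace.problem, "target": target}


trace = TeacherTrace(
    problem="What is 17 * 24?",
    reasoning="17 * 24 = 17 * 20 + 17 * 4 = 340 + 68 = 408.",
    answer="408",
)
print(answer_only_example(trace)["target"])
print(reasoning_example(trace)["target"])
```

Only the second target teaches the student the intermediate steps it needs to reproduce on a novel problem.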
RL-Based Distillation
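R1 was trained with GRPO (Group Relative Policy Optimization), which samples a group of outputs per prompt and uses each output's reward, normalized by the group's mean and standard deviation, as its advantage. This suggests an RL-based distillation approach. Note that the released R1-Distill models were produced with supervised fine-tuning on roughly 800K R1-generated samples, without an RL stage; the pipeline below is an RL-based alternative: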
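```python
# RL-based distillation pipeline (sketch). teacher/student/reward are assumed
# wrappers: generate(...) returns a list of reasoning chains, score(chain)
# returns a float, and sft_update/policy_update are hypothetical training hooks.
import statistics


class ReasoningDistillation:
    def __init__(self, teacher_model, student_model, reward_model, group_size=8):
        self.teacher = teacher_model
        self.student = student_model
        self.reward = reward_model
        self.group_size = group_size

    def distill(self, problems, iterations=1000):
        """Distill reasoning from teacher to student using RL."""
        for problem in problems:
            # Generate reasoning chains from the teacher and score them
            teacher_chains = self.teacher.generate(
                problem, num_samples=self.group_size, return_reasoning=True
            )
            chain_scores = [self.reward.score(c) for c in teacher_chains]
            # Warm start: imitate the highest-scoring teacher chain
            best = teacher_chains[chain_scores.index(max(chain_scores))]
            self.student.sft_update(problem, best)
            # Train the student to generate high-scoring chains (GRPO-style)
            for _ in range(iterations):
                group = self.student.generate(problem, num_samples=self.group_size)
                rewards = [self.reward.score(c) for c in group]
                # Group-relative advantage: reward normalized within the group
                mean = statistics.fmean(rewards)
                std = statistics.pstdev(rewards) or 1.0
                advantages = [(r - mean) / std for r in rewards]
                # Policy-gradient update weighted by each chain's advantage
                self.student.policy_update(problem, group, advantages)

    def evaluate_distillation(self, test_problems):
        """Agreement rate: fraction of problems where student and teacher answers match."""
        if not test_problems:
            return 0.0
        matches = [self.teacher.solve(p) == self.student.solve(p) for p in test_problems]
        return sum(matches) / len(matches)
```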
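The reward model does not have to be a learned network. R1's RL training relied on rule-based rewards: an accuracy reward for verifiably correct final answers and a format reward for placing the reasoning between `<think>` and `</think>` tags. The sketch below is a minimal scorer in that spirit. The `\boxed{}` answer convention, exact-string matching, and the 0.5 format weight are assumptions for illustration, not R1's actual implementation. Unlike `score(chain)` above, it also takes the problem's reference answer.

```python
import re

THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL)
BOXED_RE = re.compile(r"\\boxed\{([^{}]*)\}")


def format_reward(completion: str) -> float:
    # 1.0 if the completion opens with a <think>...</think> block, else 0.0
    return 1.0 if THINK_RE.match(completion.strip()) else 0.0


def accuracy_reward(completion: str, reference: str) -> float:
    # Compare the last \boxed{...} answer against the reference string
    answers = BOXED_RE.findall(completion)
    if not answers:
        return 0.0
    return 1.0 if answers[-1].strip() == reference.strip() else 0.0


def rule_based_reward(completion: str, reference: str, format_weight: float = 0.5) -> float:
    # Weighted sum of the two rule-based signals; format_weight is an arbitrary choice
    return accuracy_reward(completion, reference) + format_weight * format_reward(completion)


sample = "<think>17 * 24 = 340 + 68 = 408</think> The answer is \\boxed{408}."
print(rule_based_reward(sample, "408"))
```

Because both signals are computed by rules, they are cheap to evaluate and cannot be gamed the way a learned reward model sometimes can.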
Distillation Dataset Composition
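Effective distillation requires diverse training data. Several of these sources are standard benchmarks, so train on their training splits where those exist and keep the test splits held out for evaluation: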
- Math problems: Use GSM8K, MATH, AIME datasets for structured reasoning
- Code generation: Use HumanEval, MBPP for algorithmic reasoning
- Logical deduction: Use ARC-Challenge, LogiQA for multi-step logic
- General reasoning: Use BigBench Hard for diverse problem types
The mixture matters: too much math causes the student to over-specialize, while too little math fails to transfer reasoning capability.
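One way to control the mixture is to pick each training example's source by weight. The sketch below assumes each source has already been loaded into a list of examples. The source names, placeholder example IDs, and the 40/30/20/10 weights are illustrative, not recommended values.

```python
import random
from collections import Counter


def sample_mixture(sources: dict[str, list], weights: dict[str, float], n: int, seed: int = 0) -> list:
    """Draw n examples, choosing each example's source in proportion to its weight."""
    rng = random.Random(seed)
    names = list(weights)
    picks = rng.choices(names, weights=[weights[s] for s in names], k=n)
    # Sample uniformly (with replacement) within each chosen source
    return [(name, rng.choice(sources[name])) for name in picks]


# Placeholder pools; in practice these would hold real problems from each dataset
sources = {
    "math": ["math-0", "math-1", "math-2"],
    "code": ["code-0", "code-1"],
    "logic": ["logic-0"],
    "general": ["general-0"],
}
weights = {"math": 0.4, "code": 0.3, "logic": 0.2, "general": 0.1}

batch = sample_mixture(sources, weights, n=1000)
counts = Counter(name for name, _ in batch)
print(counts)
```

Keeping the weights explicit makes it easy to sweep them and check whether the student over-specializes on math.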
Quantization-Aware Distillation
A refined approach is quantization-aware distillation: during distillation, the student runs its forward pass with quantized weights, so training adapts to the quantization error it will face at deployment.
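```python
# Quantization-aware distillation: fake-quantize student weights, match the teacher
import torch
import torch.nn.functional as F
from torch.func import functional_call

def fake_quant(w, bits):  # symmetric per-tensor; straight-through gradient to w
    qmax = 2 ** (bits - 1) - 1
    scale = w.detach().abs().max().clamp(min=1e-8) / qmax
    q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale
    return w + (q - w).detach()

def quant_aware_distill(teacher, student, data, optimizer, quant_bits=4, T=2.0):
    for batch in data:
        with torch.no_grad():
            teacher_logits = teacher(batch)
        # Quantize weight matrices only; biases and norm parameters stay full precision
        params = {n: fake_quant(p, quant_bits) if p.dim() > 1 else p
                  for n, p in student.named_parameters()}
        student_logits = functional_call(student, params, (batch,))
        loss = F.kl_div(F.log_softmax(student_logits / T, dim=-1),  # KL(teacher || student)
                        F.log_softmax(teacher_logits / T, dim=-1),
                        reduction="batchmean", log_target=True) * T * T
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```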
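The noise the student learns to tolerate is the round-trip error between a weight and its quantized value. This pure-Python sketch uses symmetric per-tensor quantization (an assumption; production schemes are usually per-channel or per-group) to show how that error grows as the bit width drops. The weight values are made up.

```python
def quantize_roundtrip(values: list[float], bits: int) -> list[float]:
    """Symmetric per-tensor quantization to signed `bits`-bit integers, then dequantization."""
    qmax = 2 ** (bits - 1) - 1
    scale = max(abs(v) for v in values) / qmax or 1.0  # guard against an all-zero tensor
    return [max(-qmax - 1, min(qmax, round(v / scale))) * scale for v in values]


def max_abs_error(values: list[float], bits: int) -> float:
    return max(abs(v - q) for v, q in zip(values, quantize_roundtrip(values, bits)))


weights = [0.42, -1.37, 0.05, 0.88, -0.61, 1.10]
for bits in (8, 4):
    print(bits, "bits -> max error", round(max_abs_error(weights, bits), 4))
```

Since the largest weight maps exactly onto the top integer level, nothing is clipped and the error is bounded by half the step size, which grows roughly 16x when moving from 8 to 4 bits.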
Evaluating Distilled Models
Metrics for distilled reasoning quality:
```python
def evaluate_reasoning_quality(student, teacher, test_set):
    """Thorough evaluation of a distilled model.

    compute_accuracy, compute_chain_similarity, compute_calibration, and
    generate_ood_problems are placeholder helpers to be supplied by the reader.
    """
    metrics = {}
    # Answer accuracy
    metrics["accuracy"] = compute_accuracy(student, teacher, test_set)
    # Reasoning chain similarity (only if the teacher's chains are visible)
    student_chains = student.generate_chain(test_set)
    teacher_chains = teacher.generate_chain(test_set)
    metrics["chain_similarity"] = compute_chain_similarity(student_chains, teacher_chains)
    # Calibration: does the student know when it is uncertain?
    metrics["calibration"] = compute_calibration(student, test_set)
    # Out-of-distribution reliability: problems unlike those in the distillation data
    ood_set = generate_ood_problems(test_set)
    metrics["ood_accuracy"] = compute_accuracy(student, teacher, ood_set)
    return metrics
```
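Two of the placeholder metrics can be sketched concretely. Expected calibration error (ECE) is one standard way to implement `compute_calibration`, assuming the student reports a confidence for each answer. For `compute_chain_similarity`, token-level `difflib.SequenceMatcher` similarity is a deliberately crude stand-in (an assumption) that rewards surface overlap rather than logical equivalence.

```python
from difflib import SequenceMatcher


def expected_calibration_error(confidences: list[float], correct: list[bool], n_bins: int = 10) -> float:
    """ECE: bin predictions by confidence, then average |accuracy - confidence| weighted by bin size."""
    if len(confidences) != len(correct) or not confidences:
        raise ValueError("need equal-length, non-empty inputs")
    bins: list[list[int]] = [[] for _ in range(n_bins)]
    for i, c in enumerate(confidences):
        bins[min(int(c * n_bins), n_bins - 1)].append(i)  # c == 1.0 goes in the top bin
    n = len(confidences)
    ece = 0.0
    for idx in bins:
        if idx:
            acc = sum(correct[i] for i in idx) / len(idx)
            conf = sum(confidences[i] for i in idx) / len(idx)
            ece += len(idx) / n * abs(acc - conf)
    return ece


def chain_similarity(student_chain: str, teacher_chain: str) -> float:
    """Crude proxy: token-level sequence similarity in [0, 1] (1.0 means identical tokens)."""
    return SequenceMatcher(None, student_chain.split(), teacher_chain.split()).ratio()
```

A well-calibrated student has an ECE near zero: when it claims 80% confidence, it is right about 80% of the time.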
Deployment Considerations
Distilled models work well for:
- Answering easy problems quickly (fast path)
- Pre-filtering complex problems before routing to full R1
- Development and testing environments
They struggle with:
- Novel problem types not in distillation data
- Multi-step reasoning requiring >5 intermediate steps
- High-stakes applications where failure cost is high
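These guidelines translate naturally into a cascade: answer with the distilled model on the fast path and escalate to full R1 when the student is unsure or the application is high-stakes. This is a minimal sketch. It assumes each model is wrapped as a callable returning an answer with a confidence score. The `Answer` type, the 0.8 threshold, and the `high_stakes` flag are illustrative assumptions.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Answer:
    text: str
    confidence: float  # self-reported or estimated confidence in [0, 1]


def cascade(problem: str,
            fast_model: Callable[[str], Answer],
            full_model: Callable[[str], Answer],
            threshold: float = 0.8,
            high_stakes: bool = False) -> tuple[str, Answer]:
    """Try the distilled model first; escalate when it is unsure or stakes are high."""
    if high_stakes:
        return "full", full_model(problem)
    draft = fast_model(problem)
    if draft.confidence >= threshold:
        return "fast", draft
    return "full", full_model(problem)


# Toy stand-ins for the distilled model and full R1
def fast(p: str) -> Answer:
    return Answer("4", 0.95) if p == "2+2" else Answer("?", 0.3)


def full(p: str) -> Answer:
    return Answer("full-model answer", 0.99)


print(cascade("2+2", fast, full)[0])
print(cascade("hard proof", fast, full)[0])
```

The escalation rate is only as good as the student's calibration, which is one reason to measure it during evaluation.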
Select a distilled R1 variant (e.g., R1-Distill-Qwen). Run it against your hardest reasoning tasks. Identify its failure modes and determine whether each stems from capacity limitations or distribution mismatch.
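A simple way to start the triage is to record, for each student failure, the problem's category and whether the teacher solved it. The rule below is a heuristic assumption, not a standard taxonomy. Failures on categories absent from the distillation mixture count as distribution mismatch. In-distribution failures that the teacher solves count as capacity limitations. Problems the teacher also fails are set aside. The `Failure` record and category names are hypothetical.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass
class Failure:
    problem_id: str
    category: str         # e.g., "math", "code", "logic"
    teacher_solved: bool  # did the full teacher model solve this problem?


def categorize(failures: list[Failure], distill_categories: set[str]) -> Counter:
    """Heuristic triage of student failures into likely causes."""
    tally = Counter()
    for f in failures:
        if f.category not in distill_categories:
            tally["distribution_mismatch"] += 1  # problem type never seen in distillation data
        elif f.teacher_solved:
            tally["capacity"] += 1               # in-distribution, teacher succeeds, student fails
        else:
            tally["hard_for_both"] += 1          # teacher fails too: not a distillation gap
    return tally


failures = [
    Failure("p1", "math", True),
    Failure("p2", "geometry", True),
    Failure("p3", "code", False),
    Failure("p4", "math", True),
]
print(categorize(failures, {"math", "code", "logic"}))
```

A large `capacity` count suggests a bigger student; a large `distribution_mismatch` count suggests broadening the distillation mixture.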