14. Iterated Training
Chapter 14 of 24 · 20 min
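Alignment training rarely succeeds in a single pass. Iterated training, which cycles between training and evaluation, allows model behavior to be refined progressively: each evaluation round surfaces failures that shape the data used in the next training round.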
The Training Loop Architecture
┌──────────────┐    ┌──────────────┐    ┌──────────────┐
│   SFT on     │───▶│    Reward    │───▶│     PPO      │
│ high-quality │    │    model     │    │   training   │
│     data     │    │   training   │    │              │
└──────────────┘    └──────────────┘    └──────────────┘
       ▲                                       │
       │            ┌──────────────┐           │
       └────────────│  Evaluation  │◀──────────┘
                    │   & filter   │
                    └──────────────┘
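The "Evaluation & filter" box is what closes the loop: responses from the current policy are scored, and only those that pass become SFT data for the next round. A minimal sketch, assuming a hypothetical `score` function that maps a (prompt, response) pair to a number and a hypothetical threshold of 0.8; in practice the score might come from the reward model, human raters, or automated safety checks.

```python
from typing import Callable, List, Tuple

def evaluate_and_filter(
    samples: List[Tuple[str, str]],
    score: Callable[[str, str], float],
    threshold: float = 0.8,
) -> List[Tuple[str, str]]:
    """Keep (prompt, response) pairs whose score meets the threshold.

    `score` and `threshold` are illustrative assumptions: any scorer that
    maps a prompt/response pair to a number can be plugged in.
    """
    return [(p, r) for p, r in samples if score(p, r) >= threshold]


# Hypothetical scorer for demonstration only: favors responses of 10+ words.
def toy_score(prompt: str, response: str) -> float:
    return min(len(response.split()) / 10, 1.0)


samples = [
    ("How do I reset my password?",
     "Open Settings, choose Account, then select Reset password and follow the emailed link."),
    ("Summarize this article.", "No."),
]
next_round_sft_data = evaluate_and_filter(samples, toy_score)
print(len(next_round_sft_data))  # 1
```

Only the filtered pairs feed the next SFT stage, so a stricter threshold trades data volume for data quality.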
Iteration Monitoring
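Track metrics across iterations so that degradation, such as a reward model whose accuracy stays low or a drop in task accuracy, is caught before it compounds: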
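import copy

# sft_train, reward_train, ppo_train, and evaluate_alignment are placeholders
# for your own training and evaluation routines.

class TrainingIterator:
    def __init__(self, base_model, config):
        self.model = base_model
        self.reward_model = None
        self.config = config
        self.iteration = 0
        self.metrics_history = []

    def step(self, training_data):
        # Snapshot the policy so this iteration can be rolled back
        # (for large models, saving a checkpoint is cheaper than deepcopy)
        previous_model = copy.deepcopy(self.model)

        # One iteration: SFT, then a reward model, then PPO against it.
        # The reward model is a separate model; it must not overwrite the policy.
        self.model = sft_train(self.model, training_data)
        self.reward_model = reward_train(self.model, training_data)
        self.model = ppo_train(self.model, self.reward_model, training_data)

        # Evaluate; expected to return at least reward_accuracy and task_accuracy
        metrics = evaluate_alignment(self.model, self.reward_model)
        self.metrics_history.append(metrics)

        # Check for divergence and roll back to the pre-iteration policy
        if self.detect_degradation():
            print(f"WARNING: Degradation detected at iteration {self.iteration}")
            self.model = previous_model

        self.iteration += 1
        return self.model

    def detect_degradation(self):
        # Need at least three iterations to see a trend
        if len(self.metrics_history) < 3:
            return False
        recent = self.metrics_history[-3:]
        # Reward model accuracy below 0.6 for three consecutive iterations
        if all(m["reward_accuracy"] < 0.6 for m in recent):
            return True
        # Capability regression: task accuracy fell more than 0.05 over the window
        if recent[-1]["task_accuracy"] < recent[0]["task_accuracy"] - 0.05:
            return True
        return False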
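The two degradation checks can be exercised on their own, on synthetic metric histories, before they are wired into a real loop. The sketch below restates them as a standalone function with the thresholds (0.6 reward accuracy, 0.05 task-accuracy drop) exposed as parameters; the metric values are invented purely to trigger each branch.

```python
from typing import Dict, List

def detect_degradation(
    history: List[Dict[str, float]],
    min_reward_accuracy: float = 0.6,
    max_task_drop: float = 0.05,
) -> bool:
    """Return True if the last three iterations show degradation."""
    if len(history) < 3:
        return False  # not enough iterations to see a trend
    recent = history[-3:]
    # Reward model accuracy below threshold for three consecutive iterations
    if all(m["reward_accuracy"] < min_reward_accuracy for m in recent):
        return True
    # Task accuracy fell by more than max_task_drop across the window
    return recent[-1]["task_accuracy"] < recent[0]["task_accuracy"] - max_task_drop


# Invented histories, one healthy and one with a capability drop
healthy = [
    {"reward_accuracy": 0.72, "task_accuracy": 0.81},
    {"reward_accuracy": 0.74, "task_accuracy": 0.82},
    {"reward_accuracy": 0.75, "task_accuracy": 0.82},
]
capability_drop = [
    {"reward_accuracy": 0.72, "task_accuracy": 0.81},
    {"reward_accuracy": 0.74, "task_accuracy": 0.78},
    {"reward_accuracy": 0.75, "task_accuracy": 0.74},
]
print(detect_degradation(healthy))          # False
print(detect_degradation(capability_drop))  # True
```

Exposing the thresholds as parameters makes it easy to tune them per task instead of hard-coding 0.6 and 0.05.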
Early Stopping Criteria
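Not every change in metrics signals failure: some iterations improve alignment without improving capabilities, so a stopping rule should weigh several signals together. The function below stops when the alignment target is reached, when task accuracy regresses, when training becomes unstable, or when alignment gains plateau: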
def should_continue_training(metrics_history):
    metrics = metrics_history[-1]  # latest iteration's metrics

    # Stop once the alignment target is reached
    if metrics["alignment_score"] > 0.95:
        return False, "Alignment target reached"
    # Stop if capabilities degrade significantly
    if metrics["task_accuracy"] < 0.70:
        return False, "Capability regression"
    # Stop if training becomes unstable (high reward variance)
    if metrics["reward_var"] > 2.0:
        return False, "Training instability"
    # Stop if alignment plateaus: gain below 0.01 across the last five iterations
    if len(metrics_history) >= 5:
        recent_improvement = (
            metrics_history[-1]["alignment_score"] - metrics_history[-5]["alignment_score"]
        )
        if recent_improvement < 0.01:
            return False, "Diminishing returns"
    return True, "Continue training"
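The source does not define `reward_var`. One plausible reading, assumed here, is the variance of reward-model scores over a fixed evaluation batch, where a growing spread can indicate unstable optimization. A sketch under that assumption, using the standard library's `statistics.pvariance`; the reward values are made up.

```python
import statistics
from typing import List

def reward_variance(rewards: List[float]) -> float:
    """Population variance of reward scores on one evaluation batch.

    Treating `reward_var` as this quantity is an assumption; a fixed
    threshold such as 2.0 only makes sense if reward scores stay on a
    comparable scale from one iteration to the next.
    """
    return statistics.pvariance(rewards)


stable_batch = [0.9, 1.1, 1.0, 0.8, 1.2]
unstable_batch = [-2.5, 3.0, 0.1, 2.8, -1.9]
print(round(reward_variance(stable_batch), 4))  # 0.02
print(reward_variance(unstable_batch) > 2.0)    # True
```

Because the threshold is scale-dependent, normalizing rewards (or comparing against the variance seen in early iterations) is worth considering before reusing 2.0 as-is.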
Data Reweighting Across Iterations
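As the model matures, the emphasis of the training mix should shift: early iterations upweight safety-critical examples, and late iterations upweight examples that call for nuanced, helpful responses: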
def compute_iteration_weights(examples, iteration):
    # compute_quality_weights is a placeholder returning one base weight per example
    base_weights = compute_quality_weights(examples)

    if iteration < 3:
        # Early iterations: focus on basic safety
        safety_multiplier, helpfulness_multiplier = 2.0, 1.0
    elif iteration < 6:
        # Middle iterations: balance safety and helpfulness
        safety_multiplier, helpfulness_multiplier = 1.0, 1.0
    else:
        # Late iterations: emphasize nuanced responses
        safety_multiplier, helpfulness_multiplier = 1.0, 1.5

    # Start from the quality weight, then apply iteration-specific adjustments
    for ex, base_weight in zip(examples, base_weights):
        ex["weight"] = base_weight
        if ex["safety_critical"]:
            ex["weight"] *= safety_multiplier
        if ex["nuanced"]:
            ex["weight"] *= helpfulness_multiplier
    return examples
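The weights only take effect once something consumes them. One common option, assumed here rather than taken from the source, is to sample each iteration's training batch with probability proportional to `weight`. The sketch below uses the standard library's `random.choices` with a fixed seed so the draw is reproducible; the example IDs and weights are hypothetical.

```python
import random
from collections import Counter
from typing import Dict, List

def sample_batch(examples: List[Dict], batch_size: int, seed: int = 0) -> List[Dict]:
    """Draw a batch with replacement, proportional to each example's weight."""
    rng = random.Random(seed)
    weights = [ex["weight"] for ex in examples]
    return rng.choices(examples, weights=weights, k=batch_size)


# Hypothetical examples after an early-iteration reweighting (safety x2.0)
examples = [
    {"id": "refuse-harmful", "weight": 2.0},
    {"id": "general-qa", "weight": 1.0},
    {"id": "nuanced-advice", "weight": 1.0},
]
batch = sample_batch(examples, batch_size=1000)
counts = Counter(ex["id"] for ex in batch)
print(counts)  # roughly half the draws should be "refuse-harmful"
```

Sampling with replacement keeps the batch size fixed regardless of how the weights shift between iterations; multiplying per-example losses by `weight` is an alternative that uses every example once.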
EXERCISE
Implement a simple iterated training loop that trains for 3 iterations, monitoring reward model accuracy and task performance. Visualize how metrics change across iterations.
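A possible starting point for the exercise, assuming stand-in functions in place of real training: `fake_train` and `fake_evaluate` below are hypothetical stubs that nudge made-up metrics upward, so the loop, the metrics history, and a plain-text chart can be checked before real SFT, reward-model, and PPO calls are substituted.

```python
from typing import Callable, Dict, List

Metrics = Dict[str, float]

def run_iterations(
    model: Dict[str, float],
    train: Callable[[Dict[str, float]], Dict[str, float]],
    evaluate: Callable[[Dict[str, float]], Metrics],
    n_iterations: int = 3,
) -> List[Metrics]:
    """Alternate training and evaluation, recording metrics each round."""
    history: List[Metrics] = []
    for _ in range(n_iterations):
        model = train(model)
        history.append(evaluate(model))
    return history


def ascii_chart(history: List[Metrics], key: str, width: int = 40) -> str:
    """Render one metric (assumed to lie in [0, 1]) as a text bar per iteration."""
    lines = []
    for i, m in enumerate(history, start=1):
        bar = "#" * round(m[key] * width)
        lines.append(f"iter {i} {key:<16} {m[key]:.2f} {bar}")
    return "\n".join(lines)


# Hypothetical stubs: the "model" is just a dict of skill levels.
def fake_train(model: Dict[str, float]) -> Dict[str, float]:
    return {k: min(v + 0.05, 1.0) for k, v in model.items()}

def fake_evaluate(model: Dict[str, float]) -> Metrics:
    return {"reward_accuracy": model["rm"], "task_accuracy": model["task"]}


history = run_iterations({"rm": 0.60, "task": 0.70}, fake_train, fake_evaluate)
for key in ("reward_accuracy", "task_accuracy"):
    print(ascii_chart(history, key))
```

Replacing `ascii_chart` with a plotting library is straightforward once the history has the right shape; keeping `train` and `evaluate` as parameters means the same loop runs unchanged when the stubs are swapped for real routines.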