21. Catastrophic Forgetting
Chapter 21 of 24 · 20 min
Alignment training can cause a model to lose capabilities it had before that training. This phenomenon, known as catastrophic forgetting, is a central challenge whenever an already-trained model is trained further on a new objective.
The Forgetting Mechanism
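Optimizing for alignment updates weights that may also have been important for other tasks. A direct way to measure the resulting damage is to evaluate the original and the aligned model on the same capability tasks and compare their scores: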
```python
def measure_capability_forgetting(model, aligned_model, capability_tasks):
    """
    Measure how much the aligned model has forgotten capabilities.

    `evaluate(model, task)` is assumed to return a score where higher is better.
    """
    forgetting_scores = {}
    for task in capability_tasks:
        original_performance = evaluate(model, task)
        aligned_performance = evaluate(aligned_model, task)
        # Positive values mean the aligned model scores lower than the original
        forgetting = original_performance - aligned_performance
        forgetting_scores[task] = {
            "original": original_performance,
            "aligned": aligned_performance,
            "forgetting": forgetting,
            # Guard against division by zero when the original score is 0
            "relative_loss": (
                forgetting / original_performance if original_performance else float("nan")
            ),
        }
    return forgetting_scores
```
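Because `evaluate` depends on your benchmark harness, the sketch below assumes the scores have already been collected into two dictionaries (task name to score, higher is better) and computes the same absolute and relative forgetting. The helper name `forgetting_report` is hypothetical, and the numbers in the usage lines are placeholders, not measured results.

```python
def forgetting_report(original_scores, aligned_scores):
    """Compute absolute and relative forgetting per task from precomputed scores."""
    report = {}
    for task, original in original_scores.items():
        aligned = aligned_scores[task]
        forgetting = original - aligned  # positive = capability lost
        relative = forgetting / original if original else float("nan")
        report[task] = {
            "original": original,
            "aligned": aligned,
            "forgetting": forgetting,
            "relative_loss": relative,
        }
    return report


# Placeholder accuracy scores (percent), for illustration only
report = forgetting_report({"math": 80.0, "coding": 60.0}, {"math": 72.0, "coding": 60.0})
for task, row in report.items():
    print(f"{task}: forgetting={row['forgetting']:.1f}, relative={row['relative_loss']:.1%}")
```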
Gradient Interference
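Different training objectives compete for the same weights. If the gradient of the alignment loss points away from the gradient of a capability loss, a step that improves alignment tends to undo that capability. The cosine similarity between the two gradients measures this conflict: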
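```python
import torch

def analyze_gradient_interference(task1_gradients, task2_gradients):
    """
    Measure how much gradients from two tasks interfere.
    High interference = high forgetting risk.

    Both arguments are gradient tensors over the same parameters
    (e.g. all parameter gradients concatenated into one vector).
    """
    # Cosine similarity between the flattened gradient vectors
    similarity = torch.nn.functional.cosine_similarity(
        task1_gradients.flatten(),
        task2_gradients.flatten(),
        dim=0,
    )
    # Negative similarity means the tasks pull the weights in opposing directions
    interference = -similarity.item()
    return {
        "gradient_similarity": similarity.item(),
        "interference_score": interference,
        # 0.5 is a heuristic threshold, not a standard value
        "warning": interference > 0.5,
    }
```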
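The function above expects each task's gradient as a single tensor. One way to obtain them, sketched here on a toy linear model with random data and two deliberately opposed regression targets (all assumptions for illustration), is to backpropagate each task's loss separately with `torch.autograd.grad` and concatenate the per-parameter gradients into one vector per task. The helper `flat_grad` is hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy shared model and two synthetic "tasks" (illustrative assumptions only)
model = torch.nn.Linear(4, 1)
x = torch.randn(256, 4)
y_task1 = x.sum(dim=1, keepdim=True)   # task 1 target
y_task2 = -x.sum(dim=1, keepdim=True)  # task 2 target, deliberately opposed


def flat_grad(loss, params):
    """Gradient of `loss` with respect to `params`, flattened into one vector."""
    grads = torch.autograd.grad(loss, params)
    return torch.cat([g.reshape(-1) for g in grads])


params = [p for p in model.parameters() if p.requires_grad]
# Each loss builds its own graph, so no retain_graph is needed
g1 = flat_grad(F.mse_loss(model(x), y_task1), params)
g2 = flat_grad(F.mse_loss(model(x), y_task2), params)

similarity = F.cosine_similarity(g1, g2, dim=0).item()
print(f"cosine similarity: {similarity:.3f}, interference: {-similarity:.3f}")
```

Because the two targets are opposed, the similarity comes out negative here; for real alignment and capability losses the sign and size are exactly what you would want to measure.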
Mitigation Strategies
Mixed Batch Training:
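```python
def mixed_batch_training(model, optimizer, alignment_data, capability_data,
                         num_steps, ratio=0.7, batch_size=32):
    """
    Combine an alignment batch and a capability batch in every step,
    weighting their losses by `ratio` and `1 - ratio`.
    (sample, compute_alignment_loss, compute_capability_loss are placeholder helpers.)
    """
    model.train()
    for _ in range(num_steps):
        optimizer.zero_grad()  # clear gradients left over from the previous step
        # Alignment batch
        alignment_batch = sample(alignment_data, batch_size=batch_size)
        align_loss = compute_alignment_loss(model, alignment_batch)
        # Capability preservation batch
        capability_batch = sample(capability_data, batch_size=batch_size)
        capability_loss = compute_capability_loss(model, capability_batch)
        # Weighted combination: a higher ratio favors alignment over retention
        total_loss = ratio * align_loss + (1 - ratio) * capability_loss
        total_loss.backward()
        optimizer.step()
```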
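To see what `ratio` trades off, consider a deliberately tiny stand-in (an assumption for illustration, not a real model): a single weight `w` whose alignment loss is `(w - a)**2` and whose capability loss is `(w - c)**2`. Gradient descent on the same weighted combination, starting from the capability optimum, settles at `w = ratio * a + (1 - ratio) * c`, so raising `ratio` buys alignment at the cost of capability. The function `toy_mixed_training` and its default values are hypothetical.

```python
def toy_mixed_training(ratio, a=1.0, c=-1.0, lr=0.1, steps=200):
    """
    Minimize ratio * (w - a)**2 + (1 - ratio) * (w - c)**2 by gradient descent.
    `a` is the alignment optimum, `c` the capability optimum (toy values).
    """
    w = c  # start from the "pretrained" weight that is optimal for capability
    for _ in range(steps):
        grad = 2 * ratio * (w - a) + 2 * (1 - ratio) * (w - c)
        w -= lr * grad
    align_loss = (w - a) ** 2
    capability_loss = (w - c) ** 2
    return w, align_loss, capability_loss


for r in (0.0, 0.3, 0.7, 1.0):
    w, align_loss, cap_loss = toy_mixed_training(r)
    print(f"ratio={r:.1f}  w={w:+.3f}  align_loss={align_loss:.3f}  capability_loss={cap_loss:.3f}")
```

In this toy the capability loss at convergence is `ratio**2 * (a - c)**2`, so forgetting grows quadratically with the alignment weight.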
Replay Buffer:
```python
def replay_buffer_training(model, optimizer, buffer, new_data,
                           batch_size=32, buffer_ratio=0.2):
    """
    One training step that mixes new alignment data with replayed
    examples from earlier capability training.
    """
    optimizer.zero_grad()
    combined_batch = []
    # Sample from replay buffer (capability-preserving examples)
    n_replay = int(batch_size * buffer_ratio)
    combined_batch.extend(buffer.sample(n_replay))
    # Fill the rest of the batch with new alignment data
    combined_batch.extend(sample(new_data, batch_size - n_replay))
    # Assumes the loss can score both kinds of example (e.g. a token-level LM loss)
    loss = compute_alignment_loss(model, combined_batch)
    loss.backward()
    optimizer.step()
```
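The function above assumes a `buffer` object with a `sample(n)` method, which the text does not define. A minimal sketch, assuming a fixed-capacity buffer filled by reservoir sampling (one common choice, which gives every example offered so far an equal chance of being retained); the class name `ReplayBuffer` and its methods are hypothetical:

```python
import random


class ReplayBuffer:
    """Fixed-capacity buffer filled with reservoir sampling (Algorithm R)."""

    def __init__(self, capacity, seed=None):
        self.capacity = capacity
        self.items = []
        self.seen = 0  # total number of examples offered to the buffer
        self.rng = random.Random(seed)

    def add(self, example):
        self.seen += 1
        if len(self.items) < self.capacity:
            self.items.append(example)
        else:
            # Keep the new example with probability capacity / seen
            j = self.rng.randrange(self.seen)
            if j < self.capacity:
                self.items[j] = example

    def sample(self, n):
        """Return up to n stored examples, drawn without replacement."""
        return self.rng.sample(self.items, min(n, len(self.items)))


buffer = ReplayBuffer(capacity=100, seed=0)
for i in range(1000):
    buffer.add({"id": i})  # e.g. examples from capability training
print(len(buffer.items), len(buffer.sample(6)))
```

Reservoir sampling keeps memory bounded no matter how much capability data streams past, which matters when the original training set is too large to store.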
Elastic Weight Consolidation (EWC):
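```python
import torch

def ewc_penalty(model, fisher_matrix, optimal_params, lambda_ewc=1000):
    """
    Penalize changes to important parameters (identified by the diagonal Fisher matrix).

    fisher_matrix and optimal_params map parameter names to tensors recorded
    before alignment training. Add the result to the alignment loss.
    """
    penalty = 0.0
    for name, param in model.named_parameters():
        if name in fisher_matrix:
            # Fisher-weighted squared distance from the pre-alignment weights;
            # the 1/2 factor of the original EWC formulation is absorbed into lambda_ewc
            penalty += lambda_ewc * torch.sum(
                fisher_matrix[name] * (param - optimal_params[name]) ** 2
            )
    return penalty
```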
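`ewc_penalty` needs `fisher_matrix` and `optimal_params`, both recorded before alignment training starts. A common approximation, sketched below, is the diagonal empirical Fisher: the squared gradient of the original task's loss with respect to each parameter, averaged over data. With batches of size 1 this is the per-example empirical Fisher; larger batches give a cheaper, coarser estimate. The function name `estimate_diagonal_fisher`, the toy model, the loss, and the random data are all placeholders; in practice the batches would come from the capability data the model should retain.

```python
import torch
import torch.nn.functional as F


def estimate_diagonal_fisher(model, batches, loss_fn):
    """Diagonal empirical Fisher: mean of squared per-batch gradients."""
    fisher = {name: torch.zeros_like(p)
              for name, p in model.named_parameters() if p.requires_grad}
    model.eval()
    for inputs, targets in batches:
        model.zero_grad()
        loss_fn(model(inputs), targets).backward()
        for name, p in model.named_parameters():
            if p.grad is not None:
                fisher[name] += p.grad.detach() ** 2
    for name in fisher:
        fisher[name] /= len(batches)
    model.zero_grad()
    # Snapshot of the current (pre-alignment) weights used as the EWC anchor
    optimal_params = {name: p.detach().clone() for name, p in model.named_parameters()}
    return fisher, optimal_params


torch.manual_seed(0)
model = torch.nn.Linear(4, 2)
# Batches of size 1 give the per-example empirical Fisher
batches = [(torch.randn(1, 4), torch.randint(0, 2, (1,))) for _ in range(8)]
fisher, optimal_params = estimate_diagonal_fisher(model, batches, F.cross_entropy)
print({name: tuple(f.shape) for name, f in fisher.items()})
```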
EXERCISE
Fine-tune a model with an alignment objective for 5 epochs and, after each epoch, measure forgetting on 3 capability tasks (math, coding, factual recall). Plot the tradeoff curve between alignment improvement and capability degradation.
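One way to organize the exercise results before plotting, assuming you record an alignment score and one score per capability task after every epoch (index 0 being the model before alignment training). The helper `tradeoff_points` is hypothetical; it turns those records into (alignment gain, mean capability drop) pairs that any plotting library can draw as the tradeoff curve.

```python
def tradeoff_points(alignment_scores, capability_scores):
    """
    alignment_scores: list of alignment scores, index = epoch (0 = before training).
    capability_scores: dict mapping task name (e.g. "math", "coding",
        "factual_recall") to a list of scores with the same indexing.
    Returns one (alignment_gain, mean_capability_drop) pair per epoch.
    """
    tasks = list(capability_scores)
    points = []
    for epoch, align in enumerate(alignment_scores):
        gain = align - alignment_scores[0]
        # Drop relative to the pre-alignment score; positive = capability lost
        drops = [capability_scores[t][0] - capability_scores[t][epoch] for t in tasks]
        points.append((gain, sum(drops) / len(drops)))
    return points
```

Averaging the drops gives a single curve; plotting each task's drop separately instead shows which capability degrades first.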