RUNLOCALAI · v38

Independently operated catalog for local-AI hardware and software. Hand-written verdicts. Source-cited claims. Reproducible commands when we have them.

OP·Fredoline Eruo
© 2026 runlocalai.co · Independently operated
RLHF, DPO, and PPO

21. Catastrophic Forgetting

Chapter 21 of 24 · 20 min
KEY INSIGHT

Catastrophic forgetting is not a bug to eliminate but a trade-off to manage. Any alignment update can degrade some capabilities, so the question is whether that degradation is an acceptable price for the alignment gained.

Alignment training can cause a model to lose capabilities it had before training. This loss, known as catastrophic forgetting, is a central challenge whenever a pretrained model is trained further.

The Forgetting Mechanism

Optimizing for alignment updates weights that other tasks may depend on. A simple way to quantify the damage is to score the model on a set of capability tasks before and after alignment:

def measure_capability_forgetting(model, aligned_model, capability_tasks):
    """
    Measure how much capability the aligned model has lost on each task.

    Assumes `evaluate(model, task)` is defined elsewhere and returns a
    higher-is-better score such as accuracy.
    """
    forgetting_scores = {}

    for task in capability_tasks:
        original_performance = evaluate(model, task)
        aligned_performance = evaluate(aligned_model, task)

        # Positive values mean the aligned model got worse.
        forgetting = original_performance - aligned_performance
        forgetting_scores[task] = {
            "original": original_performance,
            "aligned": aligned_performance,
            "forgetting": forgetting,
            # Avoid dividing by zero when the base model scores 0.
            "relative_loss": (
                forgetting / original_performance if original_performance else 0.0
            ),
        }

    return forgetting_scores
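
To make the output concrete, here is a minimal sketch that applies the same bookkeeping to hard-coded scores instead of real evaluations. The accuracies are invented for illustration, and `forgetting_report`, `flag_regressions`, and the 5% budget are hypothetical names and values, not part of the course code or any library.

```python
def forgetting_report(original_scores, aligned_scores):
    """Per-task absolute and relative forgetting from two score dictionaries."""
    report = {}
    for task, original in original_scores.items():
        aligned = aligned_scores[task]
        forgetting = original - aligned
        # Guard against division by zero for tasks the base model cannot do.
        relative = forgetting / original if original else 0.0
        report[task] = {"forgetting": forgetting, "relative_loss": relative}
    return report


def flag_regressions(report, max_relative_loss=0.05):
    """Return the tasks whose relative loss exceeds the tolerated budget."""
    return sorted(
        task for task, row in report.items()
        if row["relative_loss"] > max_relative_loss
    )


# Made-up accuracies, for illustration only.
original_scores = {"math": 0.60, "coding": 0.50, "factual_recall": 0.80}
aligned_scores = {"math": 0.54, "coding": 0.49, "factual_recall": 0.80}

report = forgetting_report(original_scores, aligned_scores)
regressed = flag_regressions(report)
print(regressed)
```

With these numbers only the math task exceeds the budget: it loses about 10% of its original score, while coding loses about 2% and factual recall is unchanged. Stating the budget up front turns "is the degradation acceptable?" into a check that can be run after every alignment run.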

Gradient Interference

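Different training objectives compete for the same weights. When their gradients point in opposing directions, a step that helps one objective hurts the other: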

import torch

def analyze_gradient_interference(task1_gradients, task2_gradients):
    """
    Measure how much the gradients of two tasks interfere.
    High interference = high forgetting risk.

    Each argument is a tensor holding one task's gradient; both must have
    the same number of elements.
    """
    # Cosine similarity between the flattened gradient vectors, in [-1, 1]
    similarity = torch.nn.functional.cosine_similarity(
        task1_gradients.flatten(),
        task2_gradients.flatten(),
        dim=0
    ).item()

    # Negative similarity = the tasks pull the weights in opposing directions
    interference = -similarity

    return {
        "gradient_similarity": similarity,
        "interference_score": interference,
        # 0.5 is a heuristic threshold, not a calibrated one
        "warning": interference > 0.5
    }
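
Why negative cosine similarity signals forgetting risk can be seen on a toy problem. The sketch below uses only the standard library and two invented quadratic losses in place of real alignment and capability objectives; the targets, learning rate, and helper names are assumptions chosen for illustration.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def quadratic_loss(theta, target):
    """Squared distance to a task's preferred parameters."""
    return sum((t - g) ** 2 for t, g in zip(theta, target))


def quadratic_grad(theta, target):
    """Gradient of quadratic_loss with respect to theta."""
    return [2 * (t - g) for t, g in zip(theta, target)]


theta = [0.0, 0.0]
alignment_target = [1.0, 0.0]    # hypothetical optimum for the alignment task
capability_target = [-1.0, 0.5]  # hypothetical optimum for a capability task

g_align = quadratic_grad(theta, alignment_target)
g_cap = quadratic_grad(theta, capability_target)
similarity = cosine_similarity(g_align, g_cap)

# Take one gradient step on the alignment loss only.
lr = 0.1
stepped = [t - lr * g for t, g in zip(theta, g_align)]

align_before = quadratic_loss(theta, alignment_target)
align_after = quadratic_loss(stepped, alignment_target)
cap_before = quadratic_loss(theta, capability_target)
cap_after = quadratic_loss(stepped, capability_target)

print(f"cosine similarity: {similarity:.3f}")
print(f"alignment loss:  {align_before:.2f} -> {align_after:.2f}")
print(f"capability loss: {cap_before:.2f} -> {cap_after:.2f}")
```

The two gradients here have negative cosine similarity, so the step that lowers the alignment loss raises the capability loss. In a real model the same comparison would be made between flattened gradient vectors of the two objectives.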

Mitigation Strategies

Mixed Batch Training:

def mixed_batch_training(model, alignment_data, capability_data, optimizer,
                         num_steps, ratio=0.7):
    """
    Combine an alignment batch and a capability batch in every update.

    `ratio` is the weight on the alignment loss. `sample`,
    `compute_alignment_loss`, and `compute_capability_loss` are assumed
    to be defined elsewhere.
    """
    for _ in range(num_steps):
        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Alignment batch
        alignment_batch = sample(alignment_data, batch_size=32)
        align_loss = compute_alignment_loss(model, alignment_batch)

        # Capability preservation batch
        capability_batch = sample(capability_data, batch_size=32)
        capability_loss = compute_capability_loss(model, capability_batch)

        # Weighted combination
        total_loss = ratio * align_loss + (1 - ratio) * capability_loss
        total_loss.backward()

        optimizer.step()
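
The effect of `ratio` is easiest to see on a one-parameter toy problem. In the sketch below, which is an illustration rather than a model of real training, the capability loss is `(theta + 1)**2` and the alignment loss is `(theta - 1)**2`, so the two objectives want opposite values of `theta`. The function name, step count, and learning rate are assumptions.

```python
def train_with_ratio(ratio, steps=200, lr=0.1):
    """Gradient descent on ratio * align_loss + (1 - ratio) * capability_loss."""
    theta = -1.0  # start at the capability optimum, like a pretrained model
    for _ in range(steps):
        grad_align = 2 * (theta - 1.0)  # d/dtheta of (theta - 1)^2
        grad_cap = 2 * (theta + 1.0)    # d/dtheta of (theta + 1)^2
        theta -= lr * (ratio * grad_align + (1 - ratio) * grad_cap)
    align_loss = (theta - 1.0) ** 2
    capability_loss = (theta + 1.0) ** 2
    return theta, align_loss, capability_loss


curve = {ratio: train_with_ratio(ratio) for ratio in (0.5, 0.7, 0.9, 1.0)}
for ratio, (theta, align_loss, capability_loss) in curve.items():
    print(
        f"ratio={ratio:.1f}  theta={theta:+.2f}  "
        f"align_loss={align_loss:.2f}  capability_loss={capability_loss:.2f}"
    )
```

For this toy problem the weighted loss is minimized at `theta = 2 * ratio - 1`, so raising `ratio` moves the solution steadily away from the capability optimum. Sweeping `ratio` and recording both losses traces out the kind of trade-off curve the chapter describes.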

Replay Buffer:

def replay_buffer_training(model, buffer, new_data, optimizer,
                           batch_size=32, buffer_ratio=0.2):
    """
    Run one update on new alignment data mixed with replay from
    capability training. `sample` and `compute_alignment_loss` are
    assumed to be defined elsewhere; `buffer` must expose `sample(n)`.
    """
    # Split the batch so the two parts always add up to batch_size
    n_replay = int(batch_size * buffer_ratio)
    n_new = batch_size - n_replay

    # Sample from replay buffer (capability-preserving examples)
    combined_batch = list(buffer.sample(n_replay))

    # Add new alignment data
    combined_batch.extend(sample(new_data, n_new))

    optimizer.zero_grad()
    loss = compute_alignment_loss(model, combined_batch)
    loss.backward()
    optimizer.step()
    return loss
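
The training function relies on a buffer object with a `sample` method but does not define one. One possible shape for it, sketched with the standard library only, is a fixed-size buffer filled by reservoir sampling so that every capability example seen so far has the same chance of being kept. The class, the `build_mixed_batch` helper, and the example tuples are all hypothetical.

```python
import random


class ReplayBuffer:
    """Fixed-size store of capability examples, filled by reservoir sampling."""

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.items = []
        self.seen = 0
        self.rng = random.Random(seed)

    def add(self, example):
        self.seen += 1
        if len(self.items) < self.capacity:
            self.items.append(example)
        else:
            # Replace a stored item with probability capacity / seen, which
            # keeps every example seen so far equally likely to be retained.
            slot = self.rng.randrange(self.seen)
            if slot < self.capacity:
                self.items[slot] = example

    def sample(self, k):
        """Draw up to k distinct stored examples."""
        return self.rng.sample(self.items, min(k, len(self.items)))


def build_mixed_batch(buffer, new_data, batch_size=32, buffer_ratio=0.2, rng=random):
    """Mix replayed examples with new ones; the batch has exactly batch_size items."""
    replay = buffer.sample(int(batch_size * buffer_ratio))
    fresh = rng.sample(new_data, batch_size - len(replay))
    batch = replay + fresh
    rng.shuffle(batch)
    return batch


buffer = ReplayBuffer(capacity=100)
for i in range(1000):
    buffer.add(("capability", i))

new_data = [("alignment", i) for i in range(500)]
batch = build_mixed_batch(buffer, new_data, rng=random.Random(1))
n_replayed = sum(1 for kind, _ in batch if kind == "capability")
print(len(batch), n_replayed)
```

Computing the number of new examples as `batch_size - len(replay)` rather than rounding both shares separately keeps the batch at its full size even when `batch_size * buffer_ratio` is not a whole number.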

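Elastic Weight Consolidation (EWC):

import torch

def ewc_penalty(model, fisher_matrix, optimal_params, lambda_ewc=1000):
    """
    Penalize changes to important parameters (identified by Fisher matrix).

    `fisher_matrix` and `optimal_params` map parameter names to tensors
    shaped like the parameter: the diagonal Fisher estimate and the
    pre-alignment weights. Add the result to the alignment loss.
    """
    penalty = 0.0
    for name, param in model.named_parameters():
        if name in fisher_matrix:
            # Squared drift from the old weights, scaled by importance
            penalty += lambda_ewc * torch.sum(
                fisher_matrix[name] * (param - optimal_params[name]) ** 2
            )
    return penalty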
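
The penalty takes the Fisher values as given. As a rough sketch of where they can come from, the example below estimates a diagonal empirical Fisher (the mean squared gradient of the log-likelihood on capability data) for a two-weight logistic regression, using only the standard library. The model, the data, and the function names are assumptions made for illustration; a real setup would compute the same quantity per parameter tensor.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def diagonal_fisher(weights, examples):
    """Empirical diagonal Fisher: mean squared gradient of the log-likelihood."""
    fisher = [0.0] * len(weights)
    for features, label in examples:
        p = sigmoid(sum(w * x for w, x in zip(weights, features)))
        # For logistic regression, d/dw_j log p(label | features) = (label - p) * x_j
        for j, x in enumerate(features):
            fisher[j] += ((label - p) * x) ** 2
    return [f / len(examples) for f in fisher]


def ewc_penalty(weights, fisher, anchor, lambda_ewc=1000.0):
    """lambda * sum_j F_j * (w_j - anchor_j)^2, the same form as the penalty above."""
    return lambda_ewc * sum(
        f * (w - a) ** 2 for f, w, a in zip(fisher, weights, anchor)
    )


# Toy capability data in which the second feature is always zero,
# so the second weight never affects the predictions.
examples = [([1.0, 0.0], 1), ([2.0, 0.0], 1), ([-1.0, 0.0], 0), ([-2.0, 0.0], 0)]
anchor = [0.5, 0.5]  # weights after capability training (hypothetical)
fisher = diagonal_fisher(anchor, examples)

moved_first = ewc_penalty([0.6, 0.5], fisher, anchor)
moved_second = ewc_penalty([0.5, 0.6], fisher, anchor)
print(fisher, moved_first, moved_second)
```

Moving the first weight is penalized because the capability data depends on it, while moving the second weight by the same amount costs nothing. That asymmetry is what lets EWC leave unimportant weights free for alignment updates.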
EXERCISE

Run alignment training for 5 epochs and measure forgetting on 3 capability tasks (math, coding, factual recall). Plot the trade-off curve between alignment improvement and capability degradation.
