17. Training Stability
Chapter 17 of 24 · 25 min
Training instability is the primary failure mode for custom architectures. Gradient explosions, loss spikes, and NaN values can destroy weeks of compute investment.
Gradient Clipping Implementation
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_
class GradientController:
    """
    Manages gradient dynamics for stable training.

    Clips the global gradient norm on every call and counts how often
    clipping fires within consecutive windows of ``clip_cooldown`` steps.
    Custom architectures often need stricter clipping than standard ones.

    Args:
        max_norm: Maximum allowed global L2 norm of the gradients.
        clip_cooldown: Length of the clip-frequency window, in steps.
        warn_threshold: Print a warning when more than this many clips
            occur within one window.
    """

    def __init__(self, max_norm=1.0, clip_cooldown=100, warn_threshold=10):
        self.max_norm = max_norm
        self.clip_count = 0  # clips seen in the current window
        self.clip_cooldown = clip_cooldown
        self.warn_threshold = warn_threshold
        self.last_clip_time = 0  # step at which the current window began

    def clip_if_needed(self, model, step):
        """Clip gradients in place and track clipping frequency.

        Args:
            model: Module whose parameter gradients are clipped.
            step: Current training step; expected to increase over calls.

        Returns:
            The global gradient norm measured *before* clipping, as returned
            by ``clip_grad_norm_``. A non-finite value means the gradients
            should not be trusted (clipping by an inf/NaN norm yields zeros
            or NaNs), so callers should skip the optimizer step.
        """
        total_norm = clip_grad_norm_(model.parameters(), self.max_norm)
        if total_norm > self.max_norm:
            self.clip_count += 1
        # Evaluate the window on every call, not only on clipped steps:
        # otherwise sparse clips accumulate over arbitrarily long spans and
        # the warning over-reports the clipping rate.
        window = step - self.last_clip_time
        if window >= self.clip_cooldown:
            if self.clip_count > self.warn_threshold:
                print(f"WARNING: {self.clip_count} clips in last {window} steps")
            self.clip_count = 0
            self.last_clip_time = step
        return total_norm
# Usage in training loop
controller = GradientController(max_norm=0.5)  # Stricter for custom arch
for step, batch in enumerate(dataloader):
    outputs = model(batch)
    loss = compute_loss(outputs, batch)
    optimizer.zero_grad()
    loss.backward()
    # Clip and reuse the returned norm: clip_grad_norm_ (inside the
    # controller) reports the global L2 norm measured *before* clipping.
    # Summing per-parameter norms instead would overstate the true norm.
    total_norm = controller.clip_if_needed(model, step)
    if not torch.isfinite(total_norm):
        # Stepping on inf/NaN gradients would corrupt the weights.
        print(f"Non-finite gradient norm at step {step}; skipping update")
        continue
    if total_norm > 100:
        print(f"Gradient explosion at step {step}: norm={total_norm:.2f}")
    optimizer.step()
Learning Rate Warmup for Custom Architectures
def get_lr_with_warmup(step, warmup_steps, base_lr, min_lr=1e-6, schedule='linear',
                       total_steps=None):
    """
    Learning rate schedule with linear warmup followed by a decay schedule.

    Custom architectures often need longer warmup.

    Args:
        step: Current training step.
        warmup_steps: Steps over which the LR ramps linearly from 0 to base_lr.
        base_lr: Peak learning rate reached at the end of warmup.
        min_lr: Final learning rate for the 'linear' and 'cosine' schedules.
        schedule: Post-warmup schedule: 'linear', 'cosine' or 'constant'.
        total_steps: Step at which decay reaches min_lr. When None, no decay
            horizon is known, so the LR is held at base_lr after warmup.

    Returns:
        The learning rate for ``step`` as a float.

    Raises:
        ValueError: If ``schedule`` is not a supported name.
    """
    # Validate up front so a typo fails at step 0, not only after warmup.
    if schedule not in ('linear', 'cosine', 'constant'):
        raise ValueError(f"Unknown schedule: {schedule}")
    if step < warmup_steps:
        # Linear warmup
        return base_lr * step / warmup_steps
    if schedule == 'constant' or total_steps is None:
        return base_lr
    decay_steps = total_steps - warmup_steps
    # Clamp so steps past total_steps stay at min_lr instead of overshooting
    # (linear would drop below min_lr; cosine would climb back up).
    if decay_steps > 0:
        progress = min(1.0, (step - warmup_steps) / decay_steps)
    else:
        progress = 1.0
    if schedule == 'linear':
        return base_lr - (base_lr - min_lr) * progress
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))
class CustomArchitectureTrainer:
    """Minimal training-step wrapper with LR warmup and gradient clipping.

    Args:
        model: Module whose forward output is indexable by ``'loss'``.
        optimizer: Optimizer over the model's parameters; the LR of every
            param group is overwritten on each step.
        warmup_steps: Number of linear warmup steps.
        base_lr: Peak learning rate reached after warmup.
        max_grad_norm: Global gradient-norm clipping threshold.
    """

    def __init__(self, model, optimizer, warmup_steps=2000, base_lr=1e-4, max_grad_norm=1.0):
        self.model = model
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.base_lr = base_lr
        self.max_grad_norm = max_grad_norm
        # Must not be named ``step``: an instance attribute with that name
        # shadows the step() method and makes trainer.step(batch) fail.
        self.global_step = 0

    def step(self, batch):
        """Run one optimization step on ``batch``.

        Returns:
            dict with the scalar ``'loss'`` and the ``'lr'`` used this step.
        """
        self.global_step += 1
        lr = get_lr_with_warmup(self.global_step, self.warmup_steps, self.base_lr)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        outputs = self.model(batch)
        loss = outputs['loss']
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.optimizer.step()
        self.optimizer.zero_grad()
        return {'loss': loss.item(), 'lr': lr}
Mixed Precision Training
from torch.cuda.amp import autocast, GradScaler
class MixedPrecisionTrainer:
    """
    Training with bf16 mixed precision for custom architectures.

    bf16 has the same exponent range as fp32 but fewer mantissa bits, so it
    trades precision (not range) for memory; large custom architectures
    especially benefit from the memory savings. The GradScaler is kept as a
    safety net: when enabled, ``scaler.step`` skips the optimizer update if
    the unscaled gradients contain inf/NaN.

    Args:
        model: Module whose forward output is indexable by ``'loss'``.
        optimizer: Optimizer over the model's parameters; the LR of every
            param group is overwritten on each step.
        base_lr: Peak learning rate reached after warmup.
        warmup_steps: Number of linear warmup steps.
        max_grad_norm: Global gradient-norm clipping threshold.
    """

    def __init__(self, model, optimizer, base_lr=1e-4, warmup_steps=2000, max_grad_norm=1.0):
        self.model = model
        self.optimizer = optimizer
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.max_grad_norm = max_grad_norm
        self.scaler = GradScaler()
        # Must not be named ``step``: an instance attribute with that name
        # shadows the step() method and makes trainer.step(batch) fail.
        self.global_step = 0

    def step(self, batch):
        """Run one mixed-precision optimization step and return the loss value.

        Raises:
            ValueError: If the loss is NaN or infinite. A debugging checkpoint
                named ``nan_checkpoint_step{N}.pt`` is saved first.
        """
        self.global_step += 1
        lr = get_lr_with_warmup(self.global_step, self.warmup_steps, self.base_lr)
        for pg in self.optimizer.param_groups:
            pg['lr'] = lr
        # Forward with autocast
        with autocast(dtype=torch.bfloat16):
            outputs = self.model(batch)
            loss = outputs['loss']
        # Catch inf as well as NaN: torch.isnan alone lets an overflowed loss through.
        if not torch.isfinite(loss).all():
            # Save checkpoint for debugging
            torch.save({
                'step': self.global_step,
                'model_state': self.model.state_dict(),
                'optimizer_state': self.optimizer.state_dict(),
                'batch': batch,
            }, f'nan_checkpoint_step{self.global_step}.pt')
            raise ValueError(f"Non-finite loss ({loss.item()}) at step {self.global_step}")
        # Backward with scaled loss
        self.scaler.scale(loss).backward()
        # Unscale first so clipping sees the true gradient magnitudes.
        self.scaler.unscale_(self.optimizer)
        clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()
        return loss.item()
Initialization Strategies
class StableMLP(nn.Module):
    """
    MLP block (up-projection, SiLU, down-projection) with careful
    initialization for training stability.

    Args:
        d_model: Input and output width.
        d_ff: Hidden width.
        init_method: ``'kaiming'`` (zero-initialized output projection),
            ``'xavier'`` or ``'small'`` (N(0, 0.02^2) weights). Biases are
            zero-initialized for every method.

    Raises:
        ValueError: If ``init_method`` is not a supported name.
    """

    def __init__(self, d_model, d_ff, init_method='kaiming'):
        super().__init__()
        self.up = nn.Linear(d_model, d_ff)
        self.down = nn.Linear(d_ff, d_model)
        # Careful initialization prevents gradient issues
        if init_method == 'kaiming':
            # calculate_gain has no 'silu' entry; ReLU's gain is used as an
            # approximation for the SiLU activation in forward().
            nn.init.kaiming_normal_(self.up.weight, nonlinearity='relu')
            # Zero output projection: the block starts out contributing nothing.
            nn.init.zeros_(self.down.weight)
        elif init_method == 'xavier':
            nn.init.xavier_normal_(self.up.weight)
            nn.init.xavier_normal_(self.down.weight)
        elif init_method == 'small':
            nn.init.normal_(self.up.weight, std=0.02)
            nn.init.normal_(self.down.weight, std=0.02)
        else:
            # Fail loudly instead of silently keeping PyTorch's default init.
            raise ValueError(f"Unknown init_method: {init_method}")
        nn.init.zeros_(self.up.bias)
        nn.init.zeros_(self.down.bias)

    def forward(self, x):
        return self.down(F.silu(self.up(x)))
# Custom attention initialization
class StableAttention(nn.Module):
    """Multi-head self-attention (no masking) with stability-oriented init.

    The fused QKV projection uses small N(0, 0.02^2) weights to keep initial
    attention logits small, and the output projection is zero-initialized,
    so the module outputs zeros until ``proj`` is trained.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        # Scale initial weights to prevent attention saturation
        nn.init.normal_(self.qkv.weight, std=0.02)
        nn.init.zeros_(self.proj.weight)

    def forward(self, x):
        # assumes x is (batch, seq_len, d_model); the unpacking below requires rank 3
        batch, seq_len, d_model = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        def split_heads(t):
            # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
            return t.reshape(batch, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        # Scaled dot-product attention; 1/sqrt(head_dim) keeps logits O(1).
        scores = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        weights = scores.softmax(dim=-1)
        out = (weights @ v).transpose(1, 2).reshape(batch, seq_len, d_model)
        return self.proj(out)
Failure Mode: Embedding Norm Explosion
# BUG: Unbounded embedding norms can push attention logits into overflow (inf/NaN)
class UnboundedEmbedding(nn.Module):
    """Plain token embedding left at PyTorch's default initialization.

    Intentionally unbounded: this class demonstrates the failure mode and
    is not meant to be used as-is.
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        # No explicit init: nn.Embedding draws its weights from N(0, 1).
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    def forward(self, tokens):
        vectors = self.embedding(tokens)
        return vectors
# Under N(0, 1) init every row starts with norm ~sqrt(d_model), e.g. ~64 for d_model=4096.
# Rare token IDs receive few gradient updates, so they tend to keep that large norm.
# Stacks of layers can amplify such outliers quickly.
# FIX: Bound the embedding norm
class BoundedEmbedding(nn.Module):
    """Token embedding whose returned vectors have L2 norm at most ``max_norm``.

    ``nn.Embedding(max_norm=...)`` renormalizes, at lookup time and in place,
    every accessed row whose norm exceeds ``max_norm``; rows already inside
    the bound are returned unchanged. Because this mutates
    ``self.embedding.weight`` during forward, clone the weight before using
    it in other differentiable ops (e.g. weight tying).

    Args:
        vocab_size: Number of rows in the embedding table.
        d_model: Embedding width.
        max_norm: Upper bound on the L2 norm of each returned vector.
    """

    def __init__(self, vocab_size, d_model, max_norm=1.0):
        super().__init__()
        self.max_norm = max_norm
        self.embedding = nn.Embedding(vocab_size, d_model, max_norm=max_norm)
        # Small init: initial row norms are roughly 0.02 * sqrt(d_model).
        nn.init.normal_(self.embedding.weight, std=0.02)

    def forward(self, tokens):
        # The lookup alone enforces norm <= max_norm. F.normalize would
        # instead force every vector to norm == max_norm, discarding
        # magnitude information rather than bounding it.
        return self.embedding(tokens)
EXERCISE
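Implement a training loop that simulates a gradient explosion by multiplying every gradient by 10 (after `loss.backward()`) at step 500. Verify that your stability mechanisms detect the spike and recover without losing earlier progress — for example by clipping, by skipping the update, or by restoring the last good checkpoint.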