22. Training Infrastructure
Chapter 22 of 24 · 30 min
Custom architectures demand custom infrastructure. Standard training pipelines often fail to handle novel components efficiently.
Distributed Training Setup
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
def setup_distributed():
    """Initialize the distributed training environment for this process.

    Reads the ``LOCAL_RANK`` environment variable, makes that CUDA device
    current and creates the default NCCL process group if one does not
    exist yet, so the function is safe to call more than once.

    Returns:
        int: The local rank of this process, which is also the index of
        the CUDA device it was bound to.

    Raises:
        KeyError: If ``LOCAL_RANK`` is not set in the environment.
    """
    local_rank = int(os.environ['LOCAL_RANK'])
    # Select the device before the NCCL group is created so the
    # communicator is set up against this process's own GPU.
    torch.cuda.set_device(local_rank)
    # init_process_group raises when called twice; skip it if a group is
    # already active (e.g. a second trainer built in the same process).
    if not dist.is_initialized():
        dist.init_process_group(backend='nccl')
    return local_rank
def create_fsdp_model(model, local_rank):
    """
    Wrap model with FSDP for distributed training.
    Essential for models that don't fit on single GPU.

    Args:
        model: The unwrapped ``torch.nn.Module`` to shard.
        local_rank: Index of the CUDA device owned by this process.

    Returns:
        The FSDP-wrapped model using the FULL_SHARD strategy.
    """
    from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision

    # Define auto-wrap policy for transformer layers
    def auto_wrap_policy(module, recurse, **kwargs):
        if recurse:
            # Traversal query: always keep descending into children.
            return True
        # Wrap query: shard modules that opt in via `_fsdp_wrap` or whose
        # class name marks them as a transformer block.
        return (
            hasattr(module, '_fsdp_wrap')
            or 'TransformerBlock' in module.__class__.__name__
        )

    # `mixed_precision` takes a MixedPrecision config, not a bool.
    # Parameters and buffers are kept in bfloat16 while gradient reduction
    # stays in float32.
    precision = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.bfloat16,
    )

    # The model is deliberately NOT moved to the GPU first: `device_id`
    # lets FSDP place the modules on the device itself during wrapping,
    # instead of materialising the entire unsharded model on one GPU.
    fsdp_model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=auto_wrap_policy,
        device_id=local_rank,
        mixed_precision=precision,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        forward_prefetch=True,
    )
    return fsdp_model
class CustomArchitectureTrainer:
    """
    Training infrastructure for custom architectures.
    Handles distributed training, checkpointing, and recovery.
    """

    def __init__(self, model_config: dict, training_config: dict):
        """Set up distributed state, model, optimizer and resume if possible.

        Args:
            model_config: Must provide 'd_model', 'n_layers' and 'n_heads';
                'vocab_size' is optional.
            training_config: Must provide 'checkpoint_dir'; 'save_interval'
                is optional and defaults to 1000.
        """
        self.local_rank = setup_distributed()
        self.world_size = dist.get_world_size()
        self.model = self.build_model(model_config)
        self.fsdp_model = create_fsdp_model(self.model, self.local_rank)
        self.optimizer = self.create_optimizer()
        # NOTE(review): a plain GradScaler is used here; with fp16 under
        # FSDP a sharded-aware scaler may be required - confirm.
        self.scaler = torch.cuda.amp.GradScaler()
        self.checkpoint_dir = training_config['checkpoint_dir']
        self.save_interval = training_config.get('save_interval', 1000)
        # Counters must exist even when there is no checkpoint to resume.
        self.step = 0
        self.global_step = 0
        # Resume from checkpoint if exists
        self.resume_if_available()

    def build_model(self, config):
        """Build custom architecture from config"""
        from your_module import CustomTransformer
        model = CustomTransformer(
            d_model=config['d_model'],
            n_layers=config['n_layers'],
            n_heads=config['n_heads'],
            vocab_size=config.get('vocab_size', 32000),
        )
        return model

    def create_optimizer(self):
        """Create a single-group AdamW optimizer over the FSDP parameters."""
        return torch.optim.AdamW(
            self.fsdp_model.parameters(),
            lr=1e-4,
            betas=(0.9, 0.95),
            eps=1e-8,
            weight_decay=0.1,
        )

    @staticmethod
    def _checkpoint_step(path):
        """Return the step number in 'checkpoint_<step>.pt', or -1 if absent."""
        stem = os.path.splitext(os.path.basename(path))[0]
        suffix = stem.rpartition('_')[2]
        return int(suffix) if suffix.isdigit() else -1

    def resume_if_available(self):
        """Resume training from the checkpoint with the highest step number.

        Does nothing when the checkpoint directory holds no
        'checkpoint_*.pt' files.
        """
        import glob
        checkpoints = glob.glob(f"{self.checkpoint_dir}/checkpoint_*.pt")
        if not checkpoints:
            return
        # Compare numerically: a plain string sort would rank
        # 'checkpoint_999.pt' after 'checkpoint_1000.pt'.
        latest = max(checkpoints, key=self._checkpoint_step)
        if self.local_rank == 0:
            print(f"Resuming from {latest}")
        # Deserialize on the CPU; load_state_dict copies the values into
        # the existing (already placed) parameters and optimizer state.
        checkpoint = torch.load(latest, map_location='cpu')
        self.fsdp_model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scaler.load_state_dict(checkpoint['scaler'])
        self.step = checkpoint['step']
        # CheckpointManager.save_checkpoint does not write 'global_step',
        # so fall back to the saved step when the key is missing.
        self.global_step = checkpoint.get('global_step', self.step)
Training Loop with Custom Components
class TrainingLoop:
"""
Production training loop with custom architecture support.
"""
def __init__(self, model, config):
self.model = model
self.config = config
self.step = 0
# Custom components
self.grad_monitor = GradientMonitor()
self.loss_tracker = LossTracker()
self.tensorboard_writer = SummaryWriter(config['log_dir'])
def train_step(self, batch):
"""Single training step"""
input_ids = batch['input_ids']
attention_mask = batch['attention_mask']
labels = batch['labels']
# Mixed precision forward pass
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
loss = outputs['loss']
# Check for NaN
if torch.isnan(loss):
self.handle_nan_loss(batch, loss)
return None
# Backward pass with gradient scaling
self.scaler.scale(loss).backward()
# Monitor gradients
grad_norm = self.grad_monitor.check(self.model)
if grad_norm > self.config.get('max_grad_norm', 1.0):
self.scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.config['max_grad_norm']
)
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
# Log metrics
self.log_metrics(loss, grad_norm)
self.step += 1
return {'loss': loss.item(), 'grad_norm': grad_norm}
def handle_nan_loss(self, batch, loss):
"""Handle NaN loss with diagnostic and recovery"""
if self.local_rank == 0:
print(f"NaN detected at step {self.step}")
# Save diagnostic checkpoint
torch.save({
'step': self.step,
'loss': loss.item(),
'batch': {k: v.cpu() for k, v in batch.items()},
'model_state': self.model.state_dict(),
}, f"{self.checkpoint_dir}/nan_diagnostic_{self.step}.pt")
# Analyze gradients
self.grad_monitor.analyze(self.model)
def log_metrics(self, loss, grad_norm):
"""Log training metrics"""
if self.local_rank == 0:
self.tensorboard_writer.add_scalar('loss', loss.item(), self.step)
self.tensorboard_writer.add_scalar('grad_norm', grad_norm, self.step)
self.tensorboard_writer.add_scalar('lr', self.get_lr(), self.step)
def get_lr(self):
for param_group in self.optimizer.param_groups:
return param_group['lr']
class GradientMonitor:
    """Track the global gradient norm and print per-parameter statistics."""

    def __init__(self):
        # Norms above this value trigger a printed warning in check().
        self.alert_threshold = 10.0
        # One global norm is appended per check() call.
        self.history = []

    def check(self, model):
        """Compute, record and return the global L2 norm of all gradients.

        Parameters whose ``grad`` is None are skipped; a warning is printed
        when the norm exceeds ``alert_threshold``.
        """
        squared_norms = [
            p.grad.data.norm(2).item() ** 2
            for p in model.parameters()
            if p.grad is not None
        ]
        total_norm = sum(squared_norms, 0.0) ** 0.5
        self.history.append(total_norm)
        if total_norm > self.alert_threshold:
            print(f"WARNING: Large gradient norm: {total_norm:.2f}")
        return total_norm

    def analyze(self, model):
        """Print mean, std and max-abs of every parameter that has a gradient."""
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            grad = param.grad.data
            print(f"{name}: mean={grad.mean():.6f}, std={grad.std():.6f}, "
                  f"max={grad.abs().max():.4f}")
Checkpoint Management
import os
import glob
class CheckpointManager:
    """
    Save, prune and restore training checkpoints.

    Saving is synchronous. Files are named 'checkpoint_<step>.pt' and only
    the ``max_checkpoints`` highest-step files are kept.
    """

    def __init__(self, save_dir, max_checkpoints=5):
        """Create ``save_dir`` if needed and remember the retention limit."""
        self.save_dir = save_dir
        self.max_checkpoints = max_checkpoints
        os.makedirs(save_dir, exist_ok=True)

    @staticmethod
    def _rank():
        """Return this process's global rank, or 0 outside distributed runs."""
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0

    @staticmethod
    def _is_fsdp(model):
        """Tell whether ``model`` is an FSDP wrapper in an active process group."""
        if not (dist.is_available() and dist.is_initialized()):
            return False
        from torch.distributed.fsdp import FullyShardedDataParallel
        return isinstance(model, FullyShardedDataParallel)

    @staticmethod
    def _step_of(path):
        """Return the step number in 'checkpoint_<step>.pt', or None."""
        stem = os.path.splitext(os.path.basename(path))[0]
        suffix = stem.rpartition('_')[2]
        return int(suffix) if suffix.isdigit() else None

    def _collect_state(self, model, optimizer):
        """Return (model_state, optimizer_state) ready to be serialized.

        For FSDP this gathers full, CPU-offloaded state dicts and must be
        called on every rank; only rank 0 receives the populated dicts.
        """
        if self._is_fsdp(model):
            from torch.distributed.fsdp import (
                FullOptimStateDictConfig,
                FullStateDictConfig,
                FullyShardedDataParallel,
                StateDictType,
            )
            with FullyShardedDataParallel.state_dict_type(
                model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
                FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
            ):
                return (
                    model.state_dict(),
                    FullyShardedDataParallel.optim_state_dict(model, optimizer),
                )
        # DDP-style wrappers expose the real network as `.module`.
        target = model.module if hasattr(model, 'module') else model
        return target.state_dict(), optimizer.state_dict()

    def save_checkpoint(self, model, optimizer, scaler, step, metadata=None):
        """Save a checkpoint from rank 0 and prune old ones.

        Must be called on every rank: gathering FSDP state is a collective
        operation even though only rank 0 writes the file.

        Returns:
            The checkpoint path on rank 0, ``None`` on every other rank.
        """
        model_state, optimizer_state = self._collect_state(model, optimizer)
        if self._rank() != 0:  # Only save from rank 0
            return None
        checkpoint_path = f"{self.save_dir}/checkpoint_{step}.pt"
        checkpoint = {
            'step': step,
            'metadata': metadata or {},
            'model': model_state,
            'optimizer': optimizer_state,
            'scaler': scaler.state_dict(),
        }
        # Write to a temporary name and rename, so an interrupted save
        # never leaves a truncated file under the final checkpoint name.
        tmp_path = f"{checkpoint_path}.tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, checkpoint_path)
        # Clean up old checkpoints
        self._cleanup_old_checkpoints()
        return checkpoint_path

    def _cleanup_old_checkpoints(self):
        """Remove old checkpoints, keeping the max_checkpoints highest steps.

        Ordering uses the step number in the file name rather than the
        modification time, which can tie or change when files are copied.
        Files without a numeric step are never removed.
        """
        checkpoints = [
            path
            for path in glob.glob(f"{self.save_dir}/checkpoint_*.pt")
            if self._step_of(path) is not None
        ]
        if len(checkpoints) > self.max_checkpoints:
            checkpoints.sort(key=self._step_of)
            for old in checkpoints[:-self.max_checkpoints]:
                os.remove(old)

    def load_checkpoint(self, path, model, optimizer=None, scaler=None):
        """Load a checkpoint into model (and optionally optimizer / scaler).

        Returns:
            The training step stored in the checkpoint.
        """
        # Deserialize on the CPU; load_state_dict copies the values into
        # the existing parameters wherever they live.
        checkpoint = torch.load(path, map_location='cpu')
        load_optimizer = optimizer is not None and 'optimizer' in checkpoint
        if self._is_fsdp(model):
            from torch.distributed.fsdp import (
                FullyShardedDataParallel,
                StateDictType,
            )
            with FullyShardedDataParallel.state_dict_type(
                model, StateDictType.FULL_STATE_DICT
            ):
                model.load_state_dict(checkpoint['model'])
                if load_optimizer:
                    # Convert the full optimizer state back to this rank's
                    # sharded layout before loading it.
                    optimizer.load_state_dict(
                        FullyShardedDataParallel.optim_state_dict_to_load(
                            model=model,
                            optim=optimizer,
                            optim_state_dict=checkpoint['optimizer'],
                        )
                    )
        else:
            target = model.module if hasattr(model, 'module') else model
            target.load_state_dict(checkpoint['model'])
            if load_optimizer:
                optimizer.load_state_dict(checkpoint['optimizer'])
        if scaler is not None and 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])
        return checkpoint['step']
Failure Mode: Memory Fragmentation
# BUG: Allocating new GPU tensors on every iteration without releasing them can exhaust GPU memory (OOM)
class MemoryLeakTrainer:
    """Anti-pattern example: allocates a new CUDA tensor on every iteration."""

    def train_step(self, batch):
        """Allocate 100 fresh 1000x1000 CUDA tensors and never release them.

        ``batch`` is not used by this example.
        """
        # A brand-new tensor is created on each pass and nothing is ever
        # explicitly freed.
        for _ in range(100):
            intermediate = torch.randn(1000, 1000, device='cuda')
            # ... compute ...
            # No cleanup of `intermediate` here; the example's claim is
            # that memory builds up until an OOM after 100 steps.
            # NOTE(review): each pass rebinds the same name, so whether
            # memory really accumulates depends on what the elided compute
            # keeps alive - confirm.
# FIX: Proper memory management
class MemorySafeTrainer:
    """Example of two ways to keep per-step temporaries from piling up."""

    def train_step(self, batch):
        """Run the temporary-tensor workload with explicit cleanup.

        ``batch`` is not used by this example.
        """
        intermediate = None
        try:
            for i in range(100):
                intermediate = torch.randn(1000, 1000, device='cuda')
                # ... compute ...
        finally:
            # Drop the last reference even if the compute raised, then hand
            # cached blocks back to the driver.
            del intermediate
            torch.cuda.empty_cache()
        # Or use factory method to reuse buffer
        buffer = torch.empty(1000, 1000, device='cuda')
        for i in range(100):
            # Refill in place with standard-normal samples, matching what
            # torch.randn produced above; random_() would instead fill the
            # buffer with uniformly distributed integers.
            buffer.normal_()  # Reuse existing allocation
EXERCISE
Implement a training loop that handles custom attention patterns with gradient checkpointing. Benchmark the memory savings and throughput impact of different gradient checkpointing strategies.