RUNLOCALAI v38

Independently operated catalog for local-AI hardware and software. Hand-written verdicts. Source-cited claims. Reproducible commands when we have them.

OP·Fredoline Eruo
DISCLOSURE

Some links on this site are affiliate links (Amazon Associates and other first-class retailers). When you buy through them, we earn a small commission at no extra cost to you. Affiliate links do not influence our verdicts: there are cards we rate highly that we don't have affiliate relationships with, and cards that sell well that we refuse to recommend.

© 2026 runlocalai.co · Independently operated
Custom LLM Architecture Design

22. Training Infrastructure

Chapter 22 of 24 · 30 min
KEY INSIGHT

Custom architectures require custom infrastructure. FSDP handles parameter sharding, but custom attention patterns may need specialized kernels. Invest in gradient monitoring, checkpoint management, and training diagnostics from day one: a 100B-parameter training run can take weeks of compute and cost millions of dollars.

Standard training pipelines are tuned for standard transformer blocks, and they often handle novel components inefficiently. The setup below covers sharding, the training loop, checkpointing, and one common memory failure mode.
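The "weeks of compute" figure follows from a common rule of thumb: training a dense transformer costs roughly 6 × parameters × tokens FLOPs (forward plus backward). The sketch below does that arithmetic; the token count, GPU count, per-GPU peak throughput, and utilization (MFU) are illustrative assumptions, not measurements of any real run.

```python
def training_days(n_params: float, n_tokens: float, n_gpus: int,
                  peak_flops_per_gpu: float, mfu: float) -> float:
    """Estimated wall-clock days using the ~6 * N * D training-FLOPs rule of thumb.

    mfu is the assumed fraction of peak throughput actually achieved (0..1).
    """
    total_flops = 6.0 * n_params * n_tokens             # forward + backward
    flops_per_second = n_gpus * peak_flops_per_gpu * mfu
    return total_flops / flops_per_second / 86_400      # 86,400 seconds per day


if __name__ == "__main__":
    # Hypothetical run: 100B parameters, 2T tokens, 1,024 GPUs, an assumed
    # 1e15 FLOP/s peak per GPU, and 40% utilization.
    days = training_days(100e9, 2e12, 1024, 1e15, 0.40)
    print(f"~{days:.0f} days")  # prints "~34 days"
```

Under these assumptions the run takes about five weeks of wall-clock time on a thousand GPUs, which is why a crash without a recent checkpoint is so expensive.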

Distributed Training Setup

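import os
import re
import glob
from typing import Dict

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    BackwardPrefetch,
    MixedPrecision,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

def setup_distributed():
    """Initialize distributed training environment.
    Expects a torchrun launch, which sets LOCAL_RANK for each process."""
    dist.init_process_group(backend='nccl')

    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)

    return local_rank

def create_fsdp_model(model, local_rank):
    """
    Wrap model with FSDP for distributed training.
    Essential for models that don't fit on a single GPU.
    """
    # Custom auto-wrap policy. FSDP calls it with recurse=True while walking the
    # module tree (return True to keep descending) and with recurse=False to ask
    # whether this module should become its own FSDP unit.
    def auto_wrap_policy(module, recurse, **kwargs):
        if recurse:
            return True
        return hasattr(module, '_fsdp_wrap') or 'TransformerBlock' in type(module).__name__

    # mixed_precision expects a MixedPrecision policy, not a bool: compute and
    # communicate in bf16 while the sharded master parameters stay in fp32
    bf16_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )

    # No model.to(local_rank) here: the full model may not fit on one GPU.
    # With device_id set, FSDP moves parameters to the GPU during initialization.
    fsdp_model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=auto_wrap_policy,
        device_id=local_rank,
        mixed_precision=bf16_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        forward_prefetch=True,
    )

    return fsdp_model

class CustomArchitectureTrainer:
    """
    Training infrastructure for custom architectures.
    Handles distributed training, checkpointing, and recovery.
    """
    def __init__(self, model_config: Dict, training_config: Dict):
        self.local_rank = setup_distributed()
        self.world_size = dist.get_world_size()

        self.model = self.build_model(model_config)
        self.fsdp_model = create_fsdp_model(self.model, self.local_rank)

        # Build the optimizer AFTER wrapping so it holds FSDP's sharded parameters
        self.optimizer = self.create_optimizer()
        # bf16 needs no loss scaling, so the scaler is a pass-through; enable it
        # for fp16. ShardedGradScaler is the FSDP-aware version of GradScaler.
        self.scaler = ShardedGradScaler(enabled=False)

        self.checkpoint_dir = training_config['checkpoint_dir']
        self.save_interval = training_config.get('save_interval', 1000)

        # Counters start at zero unless a checkpoint overrides them
        self.step = 0
        self.global_step = 0
        self.resume_if_available()

    def build_model(self, config):
        """Build custom architecture from config"""
        from your_module import CustomTransformer  # placeholder: your own model code

        model = CustomTransformer(
            d_model=config['d_model'],
            n_layers=config['n_layers'],
            n_heads=config['n_heads'],
            vocab_size=config.get('vocab_size', 32000),
        )
        return model

    def create_optimizer(self):
        """Create AdamW (one parameter group; no layer-wise LR decay here)"""
        return torch.optim.AdamW(
            self.fsdp_model.parameters(),
            lr=1e-4,
            betas=(0.9, 0.95),
            eps=1e-8,
            weight_decay=0.1
        )

    def resume_if_available(self):
        """Resume training from the latest checkpoint, if any"""
        checkpoints = glob.glob(os.path.join(self.checkpoint_dir, 'checkpoint_*.pt'))
        if not checkpoints:
            return

        # Pick the highest step number; a plain string sort would rank
        # checkpoint_10000.pt below checkpoint_2000.pt
        def step_of(path):
            match = re.search(r'checkpoint_(\d+)\.pt$', path)
            return int(match.group(1)) if match else -1

        latest = max(checkpoints, key=step_of)

        if dist.get_rank() == 0:
            print(f"Resuming from {latest}")

        # Every rank loads the full (unsharded) checkpoint into CPU memory...
        checkpoint = torch.load(latest, map_location='cpu')

        # ...and FSDP keeps only this rank's shard of parameters and optimizer state
        with FSDP.state_dict_type(self.fsdp_model, StateDictType.FULL_STATE_DICT):
            self.fsdp_model.load_state_dict(checkpoint['model'])
            optim_state = FSDP.optim_state_dict_to_load(
                model=self.fsdp_model,
                optim=self.optimizer,
                optim_state_dict=checkpoint['optimizer'],
            )
            self.optimizer.load_state_dict(optim_state)
        self.scaler.load_state_dict(checkpoint['scaler'])

        self.step = checkpoint['step']
        self.global_step = checkpoint.get('global_step', self.step)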
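Why FULL_SHARD matters at this scale is clearest from per-parameter byte counts. A back-of-envelope sketch, assuming FSDP mixed precision keeps the sharded master parameters, their gradients, and AdamW's two moment buffers in fp32 (about 16 bytes per parameter), and deliberately ignoring activations, temporary unsharded bf16 weights, and communication buffers:

```python
# Assumed model-state bytes per parameter, all sharded across ranks:
BYTES_PER_PARAM = 4 + 4 + 8  # fp32 param + fp32 grad + AdamW exp_avg/exp_avg_sq


def model_state_gb_per_gpu(n_params: float, world_size: int) -> float:
    """Parameter + gradient + optimizer memory per GPU under FULL_SHARD, in GB.

    Excludes activations, temporary unsharded bf16 weights, and comm buffers.
    """
    return n_params * BYTES_PER_PARAM / world_size / 1e9


if __name__ == "__main__":
    n = 100e9  # 100B parameters
    print("one GPU (no sharding):", model_state_gb_per_gpu(n, 1), "GB")
    print("FULL_SHARD over 64 GPUs:", model_state_gb_per_gpu(n, 64), "GB per GPU")
```

Without sharding the model state alone is about 1.6 TB; spread over 64 GPUs it drops to about 25 GB each, leaving the rest of each GPU's memory for activations.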

Training Loop with Custom Components

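import math

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.tensorboard import SummaryWriter

class TrainingLoop:
    """
    Production training loop with custom architecture support.
    """
    def __init__(self, model, optimizer, scaler, config):
        self.model = model
        self.optimizer = optimizer
        self.scaler = scaler
        self.config = config
        self.step = 0
        self.rank = dist.get_rank()
        self.checkpoint_dir = config['checkpoint_dir']

        # Custom components
        self.grad_monitor = GradientMonitor()
        # Only rank 0 writes TensorBoard logs
        self.tensorboard_writer = SummaryWriter(config['log_dir']) if self.rank == 0 else None

    def train_step(self, batch):
        """Single training step"""
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        # Mixed precision forward pass
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            loss = outputs['loss']

        # Check for NaN/Inf. All ranks must agree on skipping: if one rank skipped
        # backward while the others ran it, FSDP's collectives would hang.
        bad = (~torch.isfinite(loss.detach())).float()
        dist.all_reduce(bad, op=dist.ReduceOp.MAX)
        if bad.item() > 0:
            self.handle_nan_loss(batch, loss)
            return None

        # Backward pass (scaling is a no-op when the scaler is disabled for bf16)
        self.scaler.scale(loss).backward()

        # Unscale BEFORE measuring and clipping, or the norm includes the loss scale
        self.scaler.unscale_(self.optimizer)
        max_norm = self.config.get('max_grad_norm', 1.0)
        if isinstance(self.model, FSDP):
            # FSDP's own method computes the global norm across all shards
            total_norm = self.model.clip_grad_norm_(max_norm)
        else:
            total_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm)
        # Clipping only rescales when the norm exceeds max_norm; the returned norm is pre-clip
        grad_norm = self.grad_monitor.check(total_norm.item())

        if not math.isfinite(grad_norm) and not self.scaler.is_enabled():
            # Finite loss but non-finite gradients, and no loss scaler to skip the
            # update for us: inspect per-layer stats and skip this step
            self.grad_monitor.analyze(self.model)
            self.optimizer.zero_grad(set_to_none=True)
            return None

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)

        # Log metrics
        self.log_metrics(loss, grad_norm)

        self.step += 1
        return {'loss': loss.item(), 'grad_norm': grad_norm}

    def handle_nan_loss(self, batch, loss):
        """Save the offending batch for offline diagnosis; the step is skipped"""
        # Each rank that actually produced a non-finite loss saves its own batch.
        # model.state_dict() is not saved here: under FSDP it is a collective that
        # every rank must enter, so calling it on one rank would hang the job.
        if not torch.isfinite(loss.detach()):
            print(f"[rank {self.rank}] Non-finite loss at step {self.step}")
            torch.save({
                'step': self.step,
                'loss': loss.item(),
                'batch': {k: v.cpu() for k, v in batch.items()},
            }, f"{self.checkpoint_dir}/nan_diagnostic_{self.step}_rank{self.rank}.pt")

    def log_metrics(self, loss, grad_norm):
        """Log training metrics (rank 0 only)"""
        if self.tensorboard_writer is not None:
            self.tensorboard_writer.add_scalar('loss', loss.item(), self.step)
            self.tensorboard_writer.add_scalar('grad_norm', grad_norm, self.step)
            self.tensorboard_writer.add_scalar('lr', self.get_lr(), self.step)

    def get_lr(self):
        # Single parameter group, so the first group's LR is the LR
        return self.optimizer.param_groups[0]['lr']


class GradientMonitor:
    """Monitor gradient health for custom architectures"""

    def __init__(self, alert_threshold=10.0):
        self.history = []
        self.alert_threshold = alert_threshold

    def check(self, total_norm):
        """Record a global gradient norm and warn when it is large.

        The norm comes from clip_grad_norm_ rather than being recomputed here:
        summing over model.parameters() would see only this rank's FSDP shard
        and force a GPU sync (.item()) for every parameter.
        """
        self.history.append(total_norm)

        if total_norm > self.alert_threshold:
            print(f"WARNING: Large gradient norm: {total_norm:.2f}")

        return total_norm

    def analyze(self, model):
        """Print per-parameter gradient statistics (local shards under FSDP)"""
        for name, param in model.named_parameters():
            if param.grad is not None:
                grad = param.grad.detach().float()
                print(f"{name}: mean={grad.mean().item():.6f}, std={grad.std().item():.6f}, "
                      f"max={grad.abs().max().item():.4f}")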
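`GradientMonitor` keeps a `history` list but only compares each norm against a fixed `alert_threshold`, and a healthy gradient norm can sit well above or below that value depending on the model and training phase. One common alternative is to flag norms that jump far above the recent median. A minimal, torch-free sketch; `SpikeDetector` is a hypothetical helper, and the window, factor, and warmup defaults are assumptions to tune per run:

```python
import math
from collections import deque
from statistics import median


class SpikeDetector:
    """Flag a gradient norm above `factor` x the median of recent finite norms."""

    def __init__(self, window: int = 100, factor: float = 5.0, warmup: int = 10):
        self.recent = deque(maxlen=window)  # rolling baseline of accepted norms
        self.factor = factor
        self.warmup = warmup  # don't judge until this many norms are recorded

    def update(self, norm: float) -> bool:
        if not math.isfinite(norm):  # NaN or inf is always a spike
            return True
        is_spike = (
            len(self.recent) >= self.warmup
            and norm > self.factor * median(self.recent)
        )
        if not is_spike:
            # Keep spikes out of the baseline so one outlier doesn't inflate it
            self.recent.append(norm)
        return is_spike


if __name__ == "__main__":
    detector = SpikeDetector(window=50, factor=5.0, warmup=5)
    norms = [1.0, 1.1, 0.9, 1.0, 1.2, 1.0, 9.0, 1.1]
    print([detector.update(x) for x in norms])
    # prints [False, False, False, False, False, False, True, False]
```

A detector like this could be fed the same values passed to `GradientMonitor.check`, with a spike triggering `analyze` or a diagnostic checkpoint.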

Checkpoint Management

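import os
import glob

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    StateDictType,
)

class CheckpointManager:
    """
    Manage model checkpoints (synchronous, full-state-dict saves).
    Critical for long training runs with custom architectures.
    """
    def __init__(self, save_dir, max_checkpoints=5):
        self.save_dir = save_dir
        self.max_checkpoints = max_checkpoints
        os.makedirs(save_dir, exist_ok=True)

    def save_checkpoint(self, model, optimizer, scaler, step, metadata=None):
        """Gather full state dicts and write them from rank 0.

        Call this on EVERY rank: under FSDP, gathering the state dict is a
        collective, so a rank-0-only call would hang the job.
        """
        if isinstance(model, FSDP):
            with FSDP.state_dict_type(
                model,
                StateDictType.FULL_STATE_DICT,
                FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
                FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
            ):
                model_state = model.state_dict()
                optim_state = FSDP.optim_state_dict(model, optimizer)
        else:
            # DDP or plain module: every rank already holds the full state
            model_state = getattr(model, 'module', model).state_dict()
            optim_state = optimizer.state_dict()

        if dist.get_rank() != 0:  # Only rank 0 writes the file
            return None

        checkpoint_path = os.path.join(self.save_dir, f"checkpoint_{step}.pt")
        torch.save({
            'step': step,
            'metadata': metadata or {},
            'model': model_state,
            'optimizer': optim_state,
            'scaler': scaler.state_dict(),
        }, checkpoint_path)

        # Clean up old checkpoints
        self._cleanup_old_checkpoints()

        return checkpoint_path

    def _cleanup_old_checkpoints(self):
        """Remove old checkpoints, keeping max_checkpoints most recent"""
        checkpoints = glob.glob(f"{self.save_dir}/checkpoint_*.pt")
        if len(checkpoints) > self.max_checkpoints:
            checkpoints.sort(key=os.path.getmtime)
            for old in checkpoints[:-self.max_checkpoints]:
                os.remove(old)

    def load_checkpoint(self, path, model, optimizer=None, scaler=None):
        """Load a full-state-dict checkpoint (call on every rank)"""
        # Each rank reads the whole file into CPU memory; for very large models,
        # sharded checkpoints (torch.distributed.checkpoint) avoid that copy.
        checkpoint = torch.load(path, map_location='cpu')

        if isinstance(model, FSDP):
            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
                model.load_state_dict(checkpoint['model'])
                if optimizer is not None and 'optimizer' in checkpoint:
                    # Re-shard the full optimizer state to match this rank's parameters
                    optim_state = FSDP.optim_state_dict_to_load(
                        model=model,
                        optim=optimizer,
                        optim_state_dict=checkpoint['optimizer'],
                    )
                    optimizer.load_state_dict(optim_state)
        else:
            getattr(model, 'module', model).load_state_dict(checkpoint['model'])
            if optimizer is not None and 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])

        if scaler is not None and 'scaler' in checkpoint:
            scaler.load_state_dict(checkpoint['scaler'])

        return checkpoint['step']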
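Two details matter when a multi-day job is killed mid-save. A partially written file should never look like the newest checkpoint, and "newest" is safer to decide by step number than by string order (`checkpoint_10000.pt` sorts before `checkpoint_2000.pt` as text) or by modification time (which copying or restoring files can change). A torch-free sketch of both ideas follows; `write_atomic`, `checkpoint_steps`, and `prune` are hypothetical helpers, and the bytes payload stands in for whatever `torch.save` would write.

```python
import os
import re
import tempfile
from typing import List, Tuple

CKPT_RE = re.compile(r"^checkpoint_(\d+)\.pt$")


def write_atomic(path: str, data: bytes) -> None:
    """Write to a temp file in the same directory, then rename over `path`.
    os.replace is atomic on POSIX when both paths share a filesystem."""
    directory = os.path.dirname(path) or "."
    fd, tmp = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(data)
            f.flush()
            os.fsync(f.fileno())  # make sure the bytes are on disk before renaming
        os.replace(tmp, path)
    except BaseException:
        os.unlink(tmp)
        raise


def checkpoint_steps(directory: str) -> List[Tuple[int, str]]:
    """(step, path) pairs for checkpoint_<step>.pt files, sorted by step."""
    found = []
    for name in os.listdir(directory):
        match = CKPT_RE.match(name)
        if match:  # leftover *.tmp files never match
            found.append((int(match.group(1)), os.path.join(directory, name)))
    return sorted(found)


def prune(directory: str, keep: int) -> None:
    """Delete all but the `keep` highest-step checkpoints."""
    steps = checkpoint_steps(directory)
    for _, path in steps[:max(len(steps) - keep, 0)]:
        os.remove(path)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        for step in (2000, 10000, 3000):
            write_atomic(os.path.join(d, f"checkpoint_{step}.pt"), b"payload")
        prune(d, keep=2)
        print([s for s, _ in checkpoint_steps(d)])  # prints [3000, 10000]
```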

Failure Mode: Memory Fragmentation

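import torch

# BUG: a fresh temporary of a different size every iteration. Nothing "leaks":
# Python frees each old tensor once `intermediate` is rebound. But PyTorch's
# caching allocator keeps freed blocks reserved and splits them for new requests,
# so mixed sizes fragment the cache. A later large allocation can then OOM even
# though reserved-but-unallocated memory looks sufficient.
class FragmentingTrainer:
    def train_step(self, batch):
        for i in range(100):
            intermediate = torch.randn(1000 + i, 1000, device='cuda')
            # ... compute ...

# A genuine leak is a reference that outlives the step, for example summing
# loss tensors (which keeps their autograd history alive) instead of floats:
#     self.total_loss += loss          # BUG: history grows every step
#     self.total_loss += loss.item()   # FIX

# FIX: allocate once at the largest size needed and reuse the buffer
class MemorySafeTrainer:
    def __init__(self, max_rows=1100, cols=1000):
        self.buffer = torch.empty(max_rows, cols, device='cuda')

    def train_step(self, batch):
        for i in range(100):
            intermediate = self.buffer[:1000 + i]  # a view: no new allocation
            intermediate.normal_()  # in-place fill, same distribution as torch.randn
            # ... compute ...

# Calling torch.cuda.empty_cache() every step is not a fix: it frees no live
# tensors, only returns cached blocks to the driver, and slows later
# allocations. If fragmentation persists, recent PyTorch releases support
# PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True.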
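Fragmentation is easiest to see in a toy model. The sketch below is a deliberately simplified first-fit allocator over a fixed pool, not PyTorch's caching allocator (which uses size-class pools and block splitting); `FirstFitPool` is a hypothetical class that only shows how free memory can be plentiful in total yet unusable because no single gap is large enough.

```python
class FirstFitPool:
    """Toy allocator: fixed-size pool, first-fit placement, no compaction."""

    def __init__(self, size: int):
        self.size = size
        self.blocks = {}  # handle -> (offset, length) of each live block
        self._next = 0

    def _gaps(self):
        """Free (offset, length) gaps between live blocks, in address order."""
        gaps, cursor = [], 0
        for off, length in sorted(self.blocks.values()):
            if off > cursor:
                gaps.append((cursor, off - cursor))
            cursor = off + length
        if cursor < self.size:
            gaps.append((cursor, self.size - cursor))
        return gaps

    def alloc(self, length: int):
        """Return a handle, or None if no single gap fits (a toy 'OOM')."""
        for off, gap in self._gaps():
            if gap >= length:
                handle = self._next
                self._next += 1
                self.blocks[handle] = (off, length)
                return handle
        return None

    def free(self, handle: int) -> None:
        del self.blocks[handle]

    def free_total(self) -> int:
        return self.size - sum(length for _, length in self.blocks.values())

    def largest_gap(self) -> int:
        return max((gap for _, gap in self._gaps()), default=0)


if __name__ == "__main__":
    pool = FirstFitPool(100)
    handles = [pool.alloc(10) for _ in range(10)]  # fill the pool completely
    for h in handles[::2]:                          # free every other block
        pool.free(h)
    print(pool.free_total(), pool.largest_gap())    # prints 50 10
    print(pool.alloc(20))                           # prints None
```

Half the pool is free, yet a request for a fifth of it fails. Reusing one preallocated buffer, as in the fix above, avoids this because the request pattern never changes and no scattered gaps form.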
EXERCISE

Implement a training loop that handles custom attention patterns with gradient checkpointing. Benchmark the memory savings and throughput impact of different gradient checkpointing strategies.
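Before benchmarking, it can help to predict the trade-off. The sketch below is a toy cost model, an assumption rather than a measurement: each layer's input costs 1 unit of activation memory and its intermediate activations (attention, MLP) cost `internal` more units; checkpointing in segments of `segment` layers keeps only segment inputs, rebuilds one segment's full activations at a time during backward, and reruns each checkpointed forward once. The function names and the `internal` ratio are hypothetical.

```python
import math
from typing import Optional


def activation_units(n_layers: int, segment: Optional[int], internal: float) -> float:
    """Peak stored activations under the toy model, in units of one layer input.

    segment=None means no checkpointing: every layer's activations are kept.
    """
    per_layer = 1 + internal
    if segment is None:
        return n_layers * per_layer
    # Segment inputs are kept; one segment's full activations exist at a time
    return math.ceil(n_layers / segment) + segment * per_layer


def extra_forward_passes(segment: Optional[int]) -> float:
    """Each checkpointed segment's forward runs once more during backward."""
    return 0.0 if segment is None else 1.0


def best_segment(n_layers: int, internal: float) -> int:
    """Segment length that minimizes peak activations under the toy model."""
    return min(range(1, n_layers + 1),
               key=lambda k: activation_units(n_layers, k, internal))


if __name__ == "__main__":
    n, internal = 16, 3  # hypothetical 16-layer model
    print("no checkpointing:", activation_units(n, None, internal))
    for k in (1, 2, 4, 8):
        print(f"segment={k}:", activation_units(n, k, internal))
    print("best segment:", best_segment(n, internal))
```

With `internal=0` the model reduces to the classic chain result, where the best segment length is near √n; larger intermediate activations push the optimum toward checkpointing every layer or two. Your measured memory and throughput numbers are the real answer; the model only tells you which configurations are worth benchmarking first.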
