19. Architecture Prototyping

Chapter 19 of 24 · 25 min

Prototyping custom architectures before committing to full-scale training prevents catastrophic waste. A well-designed prototype reveals failure modes early.

Prototype Framework

import math
import random
import time
from typing import Dict, List, Optional

import torch
import torch.nn as nn

class ArchitecturePrototype:
    """
    Framework for rapid architecture prototyping.
    Test architectural ideas at small scale before committing.

    Args:
        seq_len: Length of the random token sequences fed to the model.
        batch_size: Number of sequences in each synthetic batch.
        vocab_size: Exclusive upper bound for the random token ids and
            labels. It must not exceed the vocabulary of the model under test.
    """
    def __init__(self, seq_len=512, batch_size=4, vocab_size=32000):
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.results = {}

    def create_small_variant(self, config: Dict) -> nn.Module:
        """
        Create scaled-down version of architecture for testing.
        Scale by reducing layers and hidden size (each divided by 4).

        Args:
            config: Full-scale config with 'd_model', 'n_layers' and
                'n_heads', plus either 'd_ff' (absolute width) or
                'd_ff_ratio' (multiple of d_model). If neither is given,
                a ratio of 4 is used.

        Returns:
            The model built by build_architecture from the reduced config.
        """
        d_model = max(1, config['d_model'] // 4)
        n_heads = max(1, config['n_heads'] // 4)
        # Multi-head attention needs d_model to split evenly across heads;
        # fall back to the nearest smaller head count that divides it.
        while d_model % n_heads:
            n_heads -= 1

        if 'd_ff' in config:
            d_ff = max(1, config['d_ff'] // 4)
        else:
            d_ff = int(d_model * config.get('d_ff_ratio', 4))

        small_config = {
            'd_model': d_model,
            'n_layers': max(2, config['n_layers'] // 4),
            'n_heads': n_heads,
            'd_ff': d_ff,
            # Lets the built model size its vocabulary to match the
            # synthetic batches generated by this harness.
            'vocab_size': self.vocab_size,
        }
        return self.build_architecture(small_config)

    def build_architecture(self, config: Dict) -> nn.Module:
        """Build architecture from config"""
        return PrototypeTransformer(config)

    @staticmethod
    def _timed_forward(model, x, device):
        """Run one forward pass; return (elapsed milliseconds, peak bytes)."""
        if device.type != 'cuda':
            t0 = time.perf_counter()
            model(x)
            # No allocator statistics are available off-GPU.
            return (time.perf_counter() - t0) * 1000.0, 0

        # Make the model's GPU current so the events land on its stream.
        with torch.cuda.device(device):
            torch.cuda.reset_peak_memory_stats(device)
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            model(x)
            end.record()

            # elapsed_time is only valid once both events have completed.
            torch.cuda.synchronize(device)
            return start.elapsed_time(end), torch.cuda.max_memory_allocated(device)

    def run_forward_pass(self, model: nn.Module, iterations=50) -> Dict:
        """
        Benchmark the forward pass on a fixed random batch.

        Runs under torch.no_grad(), so the figures describe a forward pass
        without autograd bookkeeping. CUDA events are used when the model is
        on a GPU; otherwise wall-clock time is measured.

        Args:
            model: Module callable as model(token_ids).
            iterations: Number of timed passes (after 10 warmup passes).

        Returns:
            Dict with 'mean_time_ms', 'p50_time_ms' and 'peak_memory_gb'.
            'peak_memory_gb' is 0.0 when the model is not on a CUDA device.

        Raises:
            ValueError: If iterations is less than 1.
        """
        if iterations < 1:
            raise ValueError("iterations must be >= 1")

        device = next(model.parameters()).device
        x = torch.randint(
            0, self.vocab_size, (self.batch_size, self.seq_len), device=device
        )

        times = []
        memory_usages = []

        with torch.no_grad():
            # Warmup
            for _ in range(10):
                model(x)

            # Benchmark
            for _ in range(iterations):
                elapsed_ms, peak_bytes = self._timed_forward(model, x, device)
                times.append(elapsed_ms)
                memory_usages.append(peak_bytes)

        return {
            'mean_time_ms': sum(times) / len(times),
            'p50_time_ms': sorted(times)[len(times) // 2],
            'peak_memory_gb': max(memory_usages) / 1e9
        }

    def run_training_step(self, model: nn.Module, iterations=10) -> Dict:
        """
        Run `iterations` AdamW steps on one fixed random batch.

        Args:
            model: Module callable as model(token_ids, labels=labels) that
                returns a mapping with a scalar 'loss'.
            iterations: Number of optimizer steps to attempt.

        Returns:
            On success: 'status' == 'success', the per-step 'losses',
            'final_loss', and 'max_grad' / 'mean_grad' of the global
            gradient L2 norm. If the loss becomes non-finite:
            {'status': 'NaN' or 'Inf', 'step': failing step index}.

        Raises:
            ValueError: If iterations is less than 1.
        """
        if iterations < 1:
            raise ValueError("iterations must be >= 1")

        device = next(model.parameters()).device
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

        losses = []
        grads = []

        batch_shape = (self.batch_size, self.seq_len)
        x = torch.randint(0, self.vocab_size, batch_shape, device=device)
        y = torch.randint(0, self.vocab_size, batch_shape, device=device)

        for i in range(iterations):
            optimizer.zero_grad()
            output = model(x, labels=y)
            loss = output['loss']

            # Stop before backward: a non-finite loss would poison the weights.
            if not torch.isfinite(loss):
                status = 'NaN' if torch.isnan(loss) else 'Inf'
                return {'status': status, 'step': i}

            loss.backward()

            # Global L2 norm over all gradients. Summing per-tensor norms
            # would grow with the number of parameter tensors and make
            # configs with different depths incomparable.
            norms = [p.grad.norm() for p in model.parameters() if p.grad is not None]
            total_grad = torch.stack(norms).norm().item() if norms else 0.0
            grads.append(total_grad)

            optimizer.step()
            losses.append(loss.item())

        return {
            'status': 'success',
            'losses': losses,
            'final_loss': losses[-1],
            'max_grad': max(grads),
            'mean_grad': sum(grads) / len(grads)
        }


class PrototypeTransformer(nn.Module):
    """
    Simplified transformer for rapid prototyping.

    Config keys:
        d_model: Hidden size.
        n_layers: Number of transformer blocks.
        n_heads: Attention heads per block.
        d_ff: Feed-forward width (optional, defaults to 4 * d_model).
        vocab_size: Embedding rows and output logits (optional, defaults
            to 32000).

    NOTE: no positional encoding or attention mask is applied.
    """
    def __init__(self, config: Dict):
        super().__init__()
        d_model = config['d_model']
        n_layers = config['n_layers']
        n_heads = config['n_heads']
        vocab_size = config.get('vocab_size', 32000)

        # Token ids must be mapped to d_model-sized vectors before the
        # blocks, which operate on float hidden states.
        self.embed = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            PrototypeTransformerBlock(d_model, n_heads, config.get('d_ff', d_model * 4))
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x, labels=None):
        """
        Compute logits and, when labels are given, the cross-entropy loss.

        Args:
            x: Integer token ids, or already-embedded floating-point hidden
                states whose last dimension is d_model.
            labels: Optional target ids with one entry per position;
                entries equal to -100 are ignored by the loss.

        Returns:
            Dict with 'logits' and 'loss' ('loss' is None without labels).
        """
        # Integer input is treated as token ids; float input is assumed to
        # be embedded already and passes through untouched.
        if not torch.is_floating_point(x):
            x = self.embed(x)

        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Flatten all leading dimensions so every position is one sample.
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)),
                labels.view(-1),
                ignore_index=-100
            )

        return {'logits': logits, 'loss': loss}


class PrototypeTransformerBlock(nn.Module):
    """
    Pre-norm transformer block: self-attention followed by a GELU
    feed-forward network, each wrapped in a residual connection.

    Args:
        d_model: Hidden size of the input and output.
        n_heads: Number of attention heads.
        d_ff: Inner width of the feed-forward network.
    """
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the block to x and return a tensor of the same shape."""
        # Pre-norm architecture. Normalize once and reuse the same tensor as
        # query, key and value, instead of running norm1 three times.
        h = self.norm1(x)
        # Only the attention output is used, so skip computing the weights.
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.ffn(self.norm2(x))
        return x

Automated Hyperparameter Search

def search_architecture_space(search_space: Dict, budget=10) -> List[Dict]:
    """
    Automated search over architecture hyperparameter space.
    Useful for finding good configurations before full training.

    Args:
        search_space: Maps 'd_model', 'n_layers', 'n_heads' and
            'd_ff_ratio' to non-empty sequences of candidate values.
        budget: Number of random configurations to evaluate.

    Returns:
        One result dict per evaluated configuration ('config', 'training',
        'forward', 'score'), sorted by score, best first.

    Raises:
        KeyError: If search_space lacks one of the four keys.
        IndexError: If a candidate sequence is empty.
    """
    # Benchmark on the GPU when one exists; the model is created on CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    results = []

    for _ in range(budget):
        # Random sample from search space
        config = {
            'd_model': random.choice(search_space['d_model']),
            'n_layers': random.choice(search_space['n_layers']),
            'n_heads': random.choice(search_space['n_heads']),
            'd_ff_ratio': random.choice(search_space['d_ff_ratio']),
        }
        # The prototype builder reads an absolute feed-forward width.
        config['d_ff'] = config['d_model'] * config['d_ff_ratio']

        prototype = ArchitecturePrototype()
        model = prototype.create_small_variant(config).to(device)

        # Test training stability
        training_result = prototype.run_training_step(model)
        forward_result = prototype.run_forward_pass(model)

        results.append({
            'config': config,
            'training': training_result,
            'forward': forward_result,
            'score': score_config(training_result, forward_result)
        })

    return sorted(results, key=lambda r: r['score'], reverse=True)


def score_config(training_result, forward_result):
    """
    Combine stability and efficiency measurements into a single score.
    Higher is better.

    Args:
        training_result: Dict with 'status' and, on success, 'final_loss'
            and 'max_grad'.
        forward_result: Dict with 'peak_memory_gb'.

    Returns:
        -1000 when training did not succeed; otherwise the negated final
        loss, minus 0.1 per GB of peak memory, plus 100 if the largest
        gradient stayed below 10 (or minus 10 if it did not).
    """
    # Any non-success run (e.g. NaN) gets a flat penalty.
    if training_result['status'] != 'success':
        return -1000

    loss_penalty = training_result['final_loss']  # Lower is better
    memory_penalty = forward_result['peak_memory_gb'] * 0.1
    stability_bonus = 100 if training_result['max_grad'] < 10 else -10

    return 0 - loss_penalty - memory_penalty + stability_bonus


# Example search space: each key lists the candidate values that the
# random search picks from independently.
search_space = {
    # Hidden sizes
    'd_model': [256, 512, 768, 1024],
    # Layer counts
    'n_layers': [2, 4, 8, 12],
    # Attention head counts
    'n_heads': [4, 8, 16],
    # Presumably the feed-forward width as a multiple of d_model
    'd_ff_ratio': [2, 4, 8],
}

Failure Mode: Prototype Not Representative

# BUG: Testing at small scale doesn't predict large-scale behavior
class UnrepresentativePrototype:
    """
    Anti-pattern: a prototype so small that it hides problems which only
    show up at scale.

    Typical blind spots:
    - Batch norm statistics unreliable with small batches
    - Gradient accumulation effects not captured
    - Memory patterns different at scale
    """
    def test_architecture(self):
        """Choose deliberately under-sized settings; nothing is run."""
        # A 128-token sequence cannot surface problems that appear at 16K context
        seq_len = 128  # Too short

        # A batch of 2 hides memory pressure
        batch_size = 2  # Too small

        # 100 steps end before long-term instability can develop
        steps = 100  # Too few

        # Everything looks healthy here, yet full-scale training fails
        return None

# FIX: Design representative prototype
class RepresentativePrototype:
    """Prototype settings chosen to resemble the production run."""
    def test_architecture(self):
        """Choose production-like settings; nothing is run."""
        # Use the longest sequence length expected in production
        seq_len = 2048  # Representative of production

        # Use a realistic batch size (possibly simulated via grad accumulation)
        effective_batch = 32  # Equivalent batch, tested via accumulation

        # Run at least 1000 steps so instability has time to appear
        steps = 1000  # Catch loss spikes

        # Also test with a realistic data distribution
        return None
EXERCISE

Design a prototype that tests sliding window attention stability at sequence lengths of 512, 2048, and 8192. Measure gradient norms at each length and identify the sequence length where instability first appears.