24. Custom Architecture Project

Chapter 24 of 24 · 35 min

This final chapter brings the previous chapters together in a complete architecture design project: build a custom transformer from scratch and document each design decision along the way.

Project Specification

"""
Custom Architecture Project: LongContext Transformer

Goal: Design an architecture optimized for 100K+ token contexts
Target use case: Document understanding, long-range reasoning

Design Requirements:
1. Memory efficiency: Must fit a 100K-token context on a single 80GB GPU
2. Quality: Comparable to full attention on standard benchmarks
3. Speed: Less than 2x slowdown per generated token at 100K context vs 4K context
4. Training stability: Must train to completion without NaN losses

Design Decisions (to be made):
- Attention pattern: sliding + global + dilated?
- KV head configuration: how many for memory/quality balance?
- Depth vs width: given 7B parameter budget
- RoPE vs ALiBi: which positional encoding for long context?
"""

import time
from dataclasses import dataclass
from typing import Optional, List, Tuple

import torch
import torch.nn as nn

@dataclass
class LongContextConfig:
    """Configuration for LongContext Transformer"""
    # Model dimensions
    d_model: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: int = 8
    d_ff: int = 11008
    
    # Context configuration
    max_seq_len: int = 102400
    sliding_window_size: int = 4096
    global_window_size: int = 128  # Global-token stride: one global token every N positions
    
    # Vocabulary
    vocab_size: int = 32000
    
    # Positional encoding
    rope_theta: float = 10000.0
    rope_scaling: Optional[List[Tuple[int, float]]] = None  # [(layer, factor), ...]
    
    # Initialization
    init_method: str = 'kaiming'
    init_std: float = 0.02
    
    # Validation
    def validate(self):
        assert self.d_model % self.n_heads == 0
        assert self.n_heads % self.n_kv_heads == 0
        assert self.sliding_window_size <= self.max_seq_len
        return True
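The specification leaves the attention pattern open. To compare the candidates, here is a minimal sketch that builds each pattern as a dense boolean mask (True means query i may attend to key j). It assumes causal attention, a sliding window that includes the query itself, global tokens every global_stride positions (the role of config.global_window_size) that see and are seen by all earlier positions, and a dilated pattern that keeps every dilation-th key up to window keys back. The helper names are hypothetical; dense masks cost O(T^2) memory, so this is for inspecting patterns at small T, not for 100K-token runs.

```python
import torch

def causal_mask(T: int) -> torch.Tensor:
    # mask[i, j] is True when query i may attend to key j (j <= i)
    return torch.ones(T, T, dtype=torch.bool).tril()

def sliding_window_mask(T: int, window: int, global_stride: int = 0) -> torch.Tensor:
    i = torch.arange(T).unsqueeze(1)  # query positions as a column
    j = torch.arange(T).unsqueeze(0)  # key positions as a row
    mask = (j <= i) & (i - j < window)  # causal and within the last `window` keys
    if global_stride > 0:
        is_global = torch.arange(T) % global_stride == 0
        # Every query sees earlier global keys; global queries see all earlier keys
        mask |= (j <= i) & (is_global.unsqueeze(0) | is_global.unsqueeze(1))
    return mask

def dilated_mask(T: int, window: int, dilation: int) -> torch.Tensor:
    i = torch.arange(T).unsqueeze(1)
    j = torch.arange(T).unsqueeze(0)
    dist = i - j
    # Causal keys at distances 0, d, 2d, ..., (window - 1) * d
    return (dist >= 0) & (dist % dilation == 0) & (dist < window * dilation)

# Small example: row 5 of a 3-token sliding window over 6 positions
print(sliding_window_mask(6, 3)[5].int().tolist())
```

Under this assumed definition of dilation, window=1024 with dilation=4 reaches about 4K positions back, the same distance as the 4,096-token sliding window but with a quarter of the keys; how far the dilated layers actually reach depends on how DilatedAttention is defined.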

Implementation: LongContext Transformer

class LongContextTransformer(nn.Module):
    """
    Transformer architecture optimized for long contexts.
    
    Key features:
    - Sliding window attention for local patterns
    - Global attention for key positions
    - Dilated attention layers for long-range dependencies
    - GQA for memory efficiency
    - Scaled RoPE for context extension
    
    Architecture:
    - Layers 0-15: Sliding window (4096) + periodic global tokens
    - Layers 16-27: Dilated attention (window=1024, dilation=4)
    - Layers 28-31: Full attention for final aggregation
    """
    
    def __init__(self, config: LongContextConfig):
        super().__init__()
        self.config = config
        
        # Token embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        
        # Transformer layers with custom attention
        self.layers = nn.ModuleList([
            self._create_layer(i, config)
            for i in range(config.n_layers)
        ])
        
        # Output projection
        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        
        # Tie input and output embeddings (saves vocab_size x d_model parameters)
        self.lm_head.weight = self.embed_tokens.weight
        
        # Initialize
        self.apply(self._init_weights)
    
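    def _create_layer(self, layer_idx: int, config: LongContextConfig):
        """Create the attention mechanism for a layer.

        The split assumes the default 32 layers: 0-15 sliding window,
        16-27 dilated, 28-31 full attention. SlidingWindowAttention,
        DilatedAttention, GroupedQueryAttention and SwiGLUFeedForward are not
        defined in this chapter; they are assumed to be the modules built in
        earlier chapters.
        """
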
        if layer_idx < 16:
            # Sliding window layers
            return TransformerLayer(
                attention=SlidingWindowAttention(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                    n_kv_heads=config.n_kv_heads,
                    window_size=config.sliding_window_size
                ),
                feed_forward=SwiGLUFeedForward(config.d_model, config.d_ff),
                layer_idx=layer_idx
            )
        elif layer_idx < 28:
            # Dilated attention layers
            dilation = 4
            return TransformerLayer(
                attention=DilatedAttention(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                    n_kv_heads=config.n_kv_heads,
                    window_size=1024,
                    dilation=dilation
                ),
                feed_forward=SwiGLUFeedForward(config.d_model, config.d_ff),
                layer_idx=layer_idx
            )
        else:
            # Full attention for final aggregation
            return TransformerLayer(
                attention=GroupedQueryAttention(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                    n_kv_heads=config.n_kv_heads
                ),
                feed_forward=SwiGLUFeedForward(config.d_model, config.d_ff),
                layer_idx=layer_idx
            )
    
    def _init_weights(self, module):
        """Initialize weights according to config"""
        if module is self.lm_head:
            # lm_head shares its weight with embed_tokens, which is initialized
            # in the Embedding branch; skipping it keeps that init intact
            return

        if isinstance(module, nn.Linear):
            if self.config.init_method == 'kaiming':
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
            elif self.config.init_method == 'xavier':
                nn.init.xavier_normal_(module.weight)
            elif self.config.init_method == 'normal':
                nn.init.normal_(module.weight, std=self.config.init_std)
            else:
                raise ValueError(f"Unknown init_method: {self.config.init_method}")

            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=self.config.init_std)
    
    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                kv_cache: Optional[List] = None):
        """
        Forward pass.

        Args:
            input_ids: (batch, seq_len) token IDs
            attention_mask: (batch, seq_len) attention mask
            position_ids: (batch, seq_len) position indices; pass them
                explicitly when decoding with a KV cache
            kv_cache: List of (k, v) tuples per layer from a previous call

        Returns:
            dict with 'logits' (batch, seq_len, vocab_size) and 'kv_cache'
            (updated per-layer caches); the loss is computed by the trainer
        """
        B, T = input_ids.shape

        # Embed tokens
        h = self.embed_tokens(input_ids)

        # Create position IDs if not provided (only correct without a cache)
        if position_ids is None:
            position_ids = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, -1)

        # RoPE rotates queries and keys inside each attention module, so the
        # positions are passed down instead of being applied to hidden states
        new_kv_cache = []
        for layer_idx, layer in enumerate(self.layers):
            cache = kv_cache[layer_idx] if kv_cache else None

            h, layer_cache = layer(h, mask=attention_mask,
                                   position_ids=position_ids, cache=cache)
            new_kv_cache.append(layer_cache)

        # Final norm and projection
        h = self.norm(h)
        logits = self.lm_head(h)

        return {'logits': logits, 'kv_cache': new_kv_cache}

    def _apply_rope(self, x, position_ids):
        """Rotate query/key tensors of shape (batch, heads, seq_len, head_dim)
        by position. RoPE is applied to q and k inside attention, never to the
        hidden states; the scaling variant is left to Exercise 1."""
        raise NotImplementedError("Exercise 1: implement RoPE with scaling")


class TransformerLayer(nn.Module):
    """Single transformer layer with pre-norm architecture"""

    def __init__(self, attention, feed_forward, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx

        # Pre-norm architecture (attention modules are assumed to expose .d_model)
        self.attention_norm = RMSNorm(attention.d_model)
        self.ffn_norm = RMSNorm(attention.d_model)

        self.attention = attention
        self.feed_forward = feed_forward

    def forward(self, x, mask=None, position_ids=None, cache=None):
        # Pre-norm attention. The attention modules are assumed to apply RoPE
        # to q/k and to return (output, updated (k, v) cache).
        attn_out, new_cache = self.attention(
            self.attention_norm(x), mask=mask, position_ids=position_ids, cache=cache
        )
        x = x + attn_out

        # Pre-norm for FFN
        x = x + self.feed_forward(self.ffn_norm(x))

        return x, new_cache


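class RMSNorm(nn.Module):
    """RMS normalization: rescales by the root mean square of the features,
    with no mean-centering or bias (cheaper than LayerNorm)"""

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # Compute the statistic in fp32 for stability under bf16 autocast
        x_fp32 = x.float()
        norm = x_fp32.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return (x_fp32 * norm).type_as(x) * self.weight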
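Exercise 1 asks for RoPE with scaling. A minimal sketch, assuming the rotate-half channel layout used by GPT-NeoX- and LLaMA-style implementations and choosing linear position interpolation as the scaling method: positions are divided by scale before computing angles, so with scale=25 a 102,400-token context maps onto the 0-4,096 position range. apply_rope is a hypothetical helper that the attention modules would call on q and k; a per-layer factor from config.rope_scaling would be passed as scale.

```python
import torch

def apply_rope(x: torch.Tensor, position_ids: torch.Tensor,
               theta: float = 10000.0, scale: float = 1.0) -> torch.Tensor:
    """Rotate x of shape (batch, heads, seq_len, head_dim) by position.

    scale > 1 applies linear position interpolation (positions / scale).
    """
    head_dim = x.shape[-1]
    # One frequency per channel pair: theta^(-2k / head_dim)
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=x.device).float() / head_dim))
    pos = position_ids.float() / scale                     # (batch, seq_len)
    angles = pos[..., None] * inv_freq                     # (batch, seq_len, head_dim/2)
    angles = torch.cat([angles, angles], dim=-1)[:, None]  # (batch, 1, seq_len, head_dim)
    cos, sin = angles.cos(), angles.sin()
    # Rotate-half layout: channel k is paired with channel k + head_dim/2
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2:]
    rotated = torch.cat([-x2, x1], dim=-1)
    return (x.float() * cos + rotated.float() * sin).type_as(x)

# Usage: q and k of shape (batch=1, heads=2, seq_len=4, head_dim=8)
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
pos = torch.arange(4).unsqueeze(0)
q_rot, k_rot = apply_rope(q, pos, scale=25.0), apply_rope(k, pos, scale=25.0)
```

Because each channel pair is rotated by an angle proportional to its position, the dot product between a rotated query and key depends only on their relative offset. Position interpolation keeps that property while holding angles inside the range seen in training, at the cost of smaller angle differences between neighbouring positions.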

Training Infrastructure

class LongContextTrainer:
    """
    Trainer for LongContext Transformer with custom infrastructure.
    """

    def __init__(self, config: LongContextConfig):
        self.config = config
        self.model = LongContextTransformer(config)

        # Distributed setup
        self.setup_distributed()

        # Optimizer
        self.optimizer = self.create_optimizer()

        # Gradient accumulation. bf16 has the same exponent range as fp32, so no
        # GradScaler is needed (loss scaling is only required for fp16).
        self.grad_accumulation = 4
        self.step = 0

        # Monitoring: per-metric history
        self.metrics = {'loss': []}

    def setup_distributed(self):
        """Initialize distributed training"""
        # Setup code from Chapter 22
        pass

    def create_optimizer(self):
        """Create AdamW, applying weight decay to matrices only (not norms)"""
        params = [p for p in self.model.parameters() if p.requires_grad]
        decay = [p for p in params if p.dim() >= 2]
        no_decay = [p for p in params if p.dim() < 2]
        return torch.optim.AdamW(
            [{'params': decay, 'weight_decay': 0.1},
             {'params': no_decay, 'weight_decay': 0.0}],
            lr=1e-4,
            betas=(0.9, 0.95),
        )

    def train_step(self, batch):
        """One micro-batch; the optimizer steps every grad_accumulation calls"""
        input_ids = batch['input_ids']
        labels = batch['labels']  # aligned with input_ids; shifted below

        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            logits = self.model(input_ids)['logits']

        # Next-token loss in fp32: position t predicts token t+1
        loss = nn.functional.cross_entropy(
            logits[:, :-1].float().reshape(-1, logits.size(-1)),
            labels[:, 1:].reshape(-1),
            ignore_index=-100,
        )
        (loss / self.grad_accumulation).backward()
        self.step += 1

        # Gradient accumulation step
        if self.step % self.grad_accumulation == 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)

        # Monitor metrics (unscaled loss)
        self.metrics['loss'].append(loss.item())

        return {'loss': loss.item()}
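The trainer uses a constant learning rate, while the checklist asks for a schedule. A common choice for LLM pretraining is linear warmup followed by cosine decay to a floor. A minimal sketch with torch.optim.lr_scheduler.LambdaLR; warmup_cosine is a hypothetical helper, and the step counts and min_ratio are illustrative values, not tuned for this model.

```python
import math
import torch

def warmup_cosine(warmup_steps: int, total_steps: int, min_ratio: float = 0.1):
    """Return a multiplier function for LambdaLR: linear warmup, then cosine decay."""
    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return (step + 1) / warmup_steps  # ramps up to 1.0
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # 1.0 down to 0.0
        return min_ratio + (1.0 - min_ratio) * cosine
    return lr_lambda

# Usage with a toy model standing in for LongContextTransformer
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_cosine(100, 1000))
```

With gradient accumulation, call scheduler.step() once per optimizer.step() (every grad_accumulation micro-batches), not once per train_step call.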

Validation and Benchmarking

class LongContextValidator:
    """
    Validation suite for LongContext Transformer.
    """
    
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
    
    def validate_architecture(self):
        """Validate architectural correctness"""
        tests = [
            self.test_output_shape,
            self.test_attention_scores,
            self.test_kv_cache,
            self.test_gradient_flow,
            self.test_numerical_stability,
        ]
        
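        results = []
        for test in tests:
            try:
                result = test()
                results.append({'test': test.__name__, 'passed': True, 'result': result})
            except Exception as e:  # failed assertions and crashes both count as failures
                results.append({'test': test.__name__, 'passed': False,
                                'error': f"{type(e).__name__}: {e}"})

        return results

    def test_output_shape(self):
        """Test output dimensions"""
        device = next(self.model.parameters()).device
        vocab = self.model.config.vocab_size
        x = torch.randint(0, vocab, (2, 1024), device=device)
        with torch.no_grad():
            output = self.model(x)

        assert output['logits'].shape == (2, 1024, vocab), \
            f"Expected shape (2, 1024, {vocab}), got {output['logits'].shape}"

    def test_attention_scores(self):
        """Test attention score properties (Exercise 2)"""
        raise NotImplementedError("Exercise 2")

    def test_gradient_flow(self):
        """Test gradient flow through all layers (Exercise 2)"""
        raise NotImplementedError("Exercise 2")

    def test_numerical_stability(self):
        """Test for NaN/Inf in activations and logits"""
        raise NotImplementedError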
    
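    def test_kv_cache(self):
        """Test KV cache correctness: cached decoding must match a full forward pass"""
        device = next(self.model.parameters()).device
        vocab = self.model.config.vocab_size
        x = torch.randint(0, vocab, (1, 100), device=device)

        with torch.no_grad():
            # Prefill the prompt once, then decode 50 tokens one at a time
            output = self.model(x)
            cache = output['kv_cache']
            for _ in range(50):
                next_token = output['logits'][:, -1:].argmax(dim=-1)
                pos = torch.full_like(next_token, x.shape[1])  # index of the new token
                x = torch.cat([x, next_token], dim=1)
                output = self.model(next_token, position_ids=pos, kv_cache=cache)
                cache = output['kv_cache']

            # Recompute the whole sequence without a cache
            full = self.model(x)['logits'][:, -1]

        # Loose tolerance allows for bf16 rounding differences
        assert torch.allclose(output['logits'][:, -1], full, atol=5e-2, rtol=5e-2), \
            "Cached and uncached logits diverge"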
    
    def benchmark_long_context(self):
        """Benchmark performance at various context lengths"""
        lengths = [4096, 16384, 65536, 102400]
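        results = []
        device = next(self.model.parameters()).device

        for length in lengths:
            x = torch.randint(0, self.model.config.vocab_size, (1, length), device=device)

            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()  # exclude previously queued GPU work from timing
            start = time.perf_counter()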
            
            with torch.no_grad():
                _ = self.model(x)
            
            torch.cuda.synchronize()
            end = time.perf_counter()
            
            memory = torch.cuda.max_memory_allocated() / 1e9
            
            results.append({
                'length': length,
                'time': end - start,
                'memory_gb': memory,
                'throughput': length / (end - start)
            })
        
        return results
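The memory requirement in the specification is dominated by the KV cache at long context. A back-of-envelope sketch, assuming bf16 (2-byte) caches, head_dim = d_model / n_heads = 128, that sliding-window layers keep only the last window keys, and (conservatively) that the dilated and full-attention layers keep every position. kv_cache_bytes is a hypothetical helper.

```python
from typing import Optional

def kv_cache_bytes(seq_len: int, n_layers: int, n_kv_heads: int, head_dim: int,
                   bytes_per_elem: int = 2, window: Optional[int] = None) -> int:
    """Bytes needed to cache K and V for one sequence across n_layers layers."""
    cached = seq_len if window is None else min(seq_len, window)
    # Factor 2: one tensor for keys, one for values
    return 2 * n_layers * n_kv_heads * head_dim * cached * bytes_per_elem

seq, head_dim = 102_400, 4096 // 32
mha = kv_cache_bytes(seq, 32, 32, head_dim)                  # 32 KV heads, full cache
gqa = kv_cache_bytes(seq, 32, 8, head_dim)                   # GQA with 8 KV heads
hybrid = (kv_cache_bytes(seq, 16, 8, head_dim, window=4096)  # layers 0-15
          + kv_cache_bytes(seq, 16, 8, head_dim))            # layers 16-31
for name, b in [("MHA", mha), ("GQA", gqa), ("GQA + sliding window", hybrid)]:
    print(f"{name:>22}: {b / 1e9:.1f} GB")
```

This gives roughly 53.7 GB, 13.4 GB and 7.0 GB per sequence. Adding the bf16 weights of a 7B-parameter model (about 14 GB), all three fit in 80GB at batch size 1 before counting activations, but the cache grows linearly with batch size, so the KV head count decides how many concurrent 100K-token sequences fit.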

Project Submission Checklist

"""
LongContext Transformer Project Checklist

Architecture:
[ ] Configuration schema with validation
[ ] Complete model implementation
[ ] All attention mechanisms implemented
[ ] Positional encoding (RoPE with scaling)
[ ] Weight initialization strategy
[ ] Parameter count documentation

Training:
[ ] Distributed training setup
[ ] Mixed precision training
[ ] Gradient clipping
[ ] Learning rate schedule
[ ] Checkpoint management
[ ] Recovery from failures

Validation:
[ ] Unit tests for all components
[ ] Integration tests
[ ] Benchmark suite
[ ] Memory profiling
[ ] Throughput profiling

Documentation:
[ ] Architecture decision records
[ ] Configuration documentation
[ ] API documentation
[ ] Test documentation
[ ] Performance characterization

Deployment:
[ ] Inference optimization
[ ] KV cache management
[ ] Quantization support
[ ] Production readiness
"""

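# Run validation
from dataclasses import asdict

def run_project_validation(config, tokenizer):
    config.validate()
    # Benchmarks read CUDA memory stats, so run the model on the GPU in bf16
    model = LongContextTransformer(config).to('cuda', dtype=torch.bfloat16).eval()
    validator = LongContextValidator(model, tokenizer)

    # Architecture tests
    arch_results = validator.validate_architecture()

    # Benchmark
    benchmark_results = validator.benchmark_long_context()

    # Report
    report = {
        'architecture_tests': arch_results,
        'benchmark': benchmark_results,
        'config': asdict(config),
        # Unique parameters only (the tied embedding weight is counted once)
        'param_count': sum(p.numel() for p in model.parameters())
    }

    return report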
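The parameter count can also be estimated before instantiating anything, which helps with the depth-vs-width decision. A minimal analytical sketch, assuming bias-free linear layers, a SwiGLU FFN with three d_model x d_ff matrices, two RMSNorm weight vectors per layer plus a final one, and tied input/output embeddings. estimate_params is a hypothetical helper, not part of the config above, and the attention classes from earlier chapters may carry extra parameters.

```python
def estimate_params(d_model=4096, n_layers=32, n_heads=32, n_kv_heads=8,
                    d_ff=11008, vocab_size=32000, tie_embeddings=True) -> int:
    head_dim = d_model // n_heads
    kv_dim = n_kv_heads * head_dim
    attn = 2 * d_model * d_model + 2 * d_model * kv_dim  # Wq, Wo + Wk, Wv (GQA)
    ffn = 3 * d_model * d_ff                             # gate, up, down projections
    norms = 2 * d_model                                  # two RMSNorms per layer
    per_layer = attn + ffn + norms
    embed = vocab_size * d_model * (1 if tie_embeddings else 2)
    return n_layers * per_layer + embed + d_model        # + final RMSNorm

total = estimate_params()
print(f"{total / 1e9:.2f}B parameters")
```

With the default config this comes to about 5.80B parameters, below the 7B budget in the specification, because 8 KV heads make Wk and Wv a quarter of their multi-head size. Setting n_kv_heads=32 and untying the embeddings gives 6,738,415,616, the size of the original LLaMA-7B configuration, which is a useful sanity check on the formula.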
EXERCISE

Complete the LongContext Transformer implementation by:

  1. Implementing the _apply_rope method with scaling
  2. Implementing test_attention_scores and test_gradient_flow validation tests
  3. Running the validation suite and reporting results
  4. Creating Architecture Decision Records (ADRs) for 3 key architectural decisions

This completes the Custom LLM Architecture Design course. You now have the tools to design, implement, validate, and document custom transformer architectures optimized for your specific use case.