24. Custom Architecture Project
Chapter 24 of 24 · 35 min
This final chapter synthesizes all previous content into a complete architecture design project. Build a custom transformer from scratch with documented decisions.
Project Specification
"""
Custom Architecture Project: LongContext Transformer
Goal: Design an architecture optimized for 100K+ token contexts
Target use case: Document understanding, long-range reasoning
Design Requirements:
1. Memory efficiency: Must fit 100K context in 80GB GPU
2. Quality: Comparable to full attention on standard benchmarks
3. Speed: <2x slowdown vs 4K context for 100K generation
4. Training stability: Must train to completion without NaN
Design Decisions (to be made):
- Attention pattern: sliding + global + dilated?
- KV head configuration: how many for memory/quality balance?
- Depth vs width: given 7B parameter budget
- RoPE vs ALiBi: which positional encoding for long context?
"""
import time
from dataclasses import asdict, dataclass, field
from typing import Optional, List, Tuple

import torch
import torch.nn as nn
@dataclass
class LongContextConfig:
    """Configuration for LongContext Transformer."""
    # Model dimensions
    d_model: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: int = 8  # key/value heads for GQA; must divide n_heads
    d_ff: int = 11008
    # Context configuration
    max_seq_len: int = 102400
    sliding_window_size: int = 4096
    global_window_size: int = 128  # Intended spacing: one global token every N tokens
    # Vocabulary
    vocab_size: int = 32000
    # Positional encoding
    rope_theta: float = 10000.0
    rope_scaling: Optional[List[Tuple[int, float]]] = None  # [(layer, factor), ...]
    # Initialization: one of 'kaiming', 'xavier', 'normal'
    init_method: str = 'kaiming'
    init_std: float = 0.02

    # Validation
    def validate(self):
        """Check the configuration for internal consistency.

        Returns:
            True when every check passes.

        Raises:
            AssertionError: describing the first failed check. It is raised
                explicitly rather than via ``assert`` so validation still runs
                under ``python -O``.
        """
        def require(condition, message):
            if not condition:
                raise AssertionError(message)

        # Positivity first, so the divisibility checks below cannot hit a
        # ZeroDivisionError.
        for name in ('d_model', 'n_layers', 'n_heads', 'n_kv_heads', 'd_ff',
                     'max_seq_len', 'sliding_window_size', 'global_window_size',
                     'vocab_size'):
            value = getattr(self, name)
            require(value > 0, f'{name} must be positive, got {value}')
        require(self.d_model % self.n_heads == 0,
                f'd_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})')
        require(self.n_heads % self.n_kv_heads == 0,
                f'n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})')
        require(self.sliding_window_size <= self.max_seq_len,
                f'sliding_window_size ({self.sliding_window_size}) exceeds '
                f'max_seq_len ({self.max_seq_len})')
        require(self.init_method in ('kaiming', 'xavier', 'normal'),
                f"init_method must be 'kaiming', 'xavier' or 'normal', got {self.init_method!r}")
        return True

    def to_dict(self):
        """Return the configuration as a plain (recursively copied) dict."""
        return asdict(self)

    def estimate_params(self, tied_embeddings: bool = True) -> int:
        """Estimate the parameter count implied by this configuration.

        Assumes the following; the attention and FFN modules are defined
        elsewhere, so verify against them:
        - bias-free Q and O projections of d_model x d_model;
        - K and V projections of d_model x (n_kv_heads * head_dim);
        - a SwiGLU FFN with three d_model x d_ff matrices;
        - two RMSNorm gains per layer plus one final gain;
        - a token embedding shared with the output head when
          ``tied_embeddings`` is True.
        """
        head_dim = self.d_model // self.n_heads
        kv_dim = self.n_kv_heads * head_dim
        attention = 2 * self.d_model * self.d_model + 2 * self.d_model * kv_dim
        ffn = 3 * self.d_model * self.d_ff
        norms = 2 * self.d_model
        per_layer = attention + ffn + norms
        embedding = self.vocab_size * self.d_model
        output_head = 0 if tied_embeddings else embedding
        final_norm = self.d_model
        return self.n_layers * per_layer + embedding + output_head + final_norm
Implementation: LongContext Transformer
class LongContextTransformer(nn.Module):
    """
    Transformer architecture optimized for long contexts.

    Key features:
    - Sliding window attention for local patterns
    - Dilated attention layers for long-range dependencies
    - Full attention in the last layers for final aggregation
    - GQA (``n_kv_heads`` key/value heads) in every layer for memory efficiency
    - Positional-encoding hook (``_apply_rope``) for scaled RoPE

    Layer schedule (default ``n_layers=32``):
    - Layers 0-15: sliding-window attention (``config.sliding_window_size``)
    - Layers 16-27: dilated attention (window 1024, dilation 4)
    - Layers 28-31: full grouped-query attention
    Other depths keep the same proportions (see ``_create_layer``).

    ``config.global_window_size``, ``config.rope_theta`` and
    ``config.rope_scaling`` are not read by this class yet.
    """

    def __init__(self, config: LongContextConfig):
        super().__init__()
        self.config = config
        # Token embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        # Transformer layers with depth-dependent attention
        self.layers = nn.ModuleList(
            [self._create_layer(i, config) for i in range(config.n_layers)]
        )
        # Output projection
        self.norm = RMSNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Tie the output projection to the input embedding (one shared matrix).
        self.lm_head.weight = self.embed_tokens.weight
        # Initialize
        self.apply(self._init_weights)

    def _create_layer(self, layer_idx: int, config: LongContextConfig):
        """Build layer ``layer_idx`` with the attention type its depth calls for.

        The stack is split into three contiguous bands:
        - the first ``n_layers // 2`` layers use sliding-window attention;
        - the last ``max(1, n_layers // 8)`` layers use full grouped-query
          attention;
        - the layers in between use dilated attention (window 1024,
          dilation 4).
        With 32 layers this reproduces the 0-15 / 16-27 / 28-31 split.
        """
        sliding_end = config.n_layers // 2
        full_start = config.n_layers - max(1, config.n_layers // 8)
        if layer_idx < sliding_end:
            attention = SlidingWindowAttention(
                d_model=config.d_model,
                n_heads=config.n_heads,
                n_kv_heads=config.n_kv_heads,
                window_size=config.sliding_window_size,
            )
        elif layer_idx < full_start:
            attention = DilatedAttention(
                d_model=config.d_model,
                n_heads=config.n_heads,
                n_kv_heads=config.n_kv_heads,
                window_size=1024,
                dilation=4,
            )
        else:
            # Full attention for final aggregation
            attention = GroupedQueryAttention(
                d_model=config.d_model,
                n_heads=config.n_heads,
                n_kv_heads=config.n_kv_heads,
            )
        return TransformerLayer(
            attention=attention,
            feed_forward=SwiGLUFeedForward(config.d_model, config.d_ff),
            layer_idx=layer_idx,
        )

    def _init_weights(self, module):
        """Initialize Linear and Embedding weights according to the config.

        Linear weights use ``config.init_method``. Embeddings always use
        N(0, ``init_std``). The tied ``lm_head`` is skipped: ``self.apply``
        visits it after ``embed_tokens``, so re-initializing it here would
        overwrite the shared embedding matrix with the Linear scheme.

        Raises:
            ValueError: if ``config.init_method`` is not 'kaiming', 'xavier'
                or 'normal'.
        """
        if isinstance(module, nn.Linear):
            lm_head = getattr(self, 'lm_head', None)
            if module is lm_head and module.weight is self.embed_tokens.weight:
                return
            method = self.config.init_method
            if method == 'kaiming':
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
            elif method == 'xavier':
                nn.init.xavier_normal_(module.weight)
            elif method == 'normal':
                nn.init.normal_(module.weight, std=self.config.init_std)
            else:
                raise ValueError(f'Unknown init_method: {method!r}')
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=self.config.init_std)

    def forward(self, input_ids: torch.Tensor,
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                kv_cache: Optional[List] = None):
        """
        Forward pass.

        Args:
            input_ids: (batch, seq_len) token IDs
            attention_mask: passed unchanged to every layer as ``mask``
            position_ids: (batch, seq_len) position indices; defaults to
                0..seq_len-1 for every sequence
            kv_cache: list of per-layer cache entries, indexed by layer

        Returns:
            dict with 'logits' (last dim = vocab_size) and 'kv_cache'. The
            returned cache is a list of (k, v) tuples per layer, filled only
            for layers that received a cache entry; it is empty when
            ``kv_cache`` is None or empty. No loss is computed here.
        """
        B, T = input_ids.shape
        # Embed tokens
        h = self.embed_tokens(input_ids)
        # Create position IDs if not provided
        if position_ids is None:
            position_ids = torch.arange(T, device=input_ids.device).unsqueeze(0).expand(B, -1)
        # Positional-encoding hook (identity until RoPE is implemented)
        h = self._apply_rope(h, position_ids)
        # Process through layers
        new_kv_cache = []
        for layer_idx, layer in enumerate(self.layers):
            cache = kv_cache[layer_idx] if kv_cache else None
            h = layer(h, mask=attention_mask, cache=cache)
            if cache is not None:
                # TransformerLayer stores its attention module as `attention`.
                # NOTE(review): assumes attention modules expose k_cache/v_cache.
                new_kv_cache.append((layer.attention.k_cache, layer.attention.v_cache))
        # Final norm and projection
        h = self.norm(h)
        logits = self.lm_head(h)
        return {'logits': logits, 'kv_cache': new_kv_cache}

    def _apply_rope(self, x, position_ids):
        """Positional-encoding hook; returns ``x`` unchanged for now.

        Returning the input (rather than falling through to ``None``) keeps
        ``forward`` usable before RoPE is implemented.
        """
        # NOTE(review): RoPE's relative-position property only holds when the
        # rotation is applied to queries and keys inside attention, so a real
        # (scaled) implementation likely belongs in the attention modules
        # rather than on the residual stream here.
        return x
class TransformerLayer(nn.Module):
    """Pre-norm transformer block: residual attention, then residual FFN.

    Args:
        attention: attention sub-module; must expose ``d_model`` and accept
            ``(x, mask=..., cache=...)``.
        feed_forward: position-wise sub-module called as ``feed_forward(x)``.
        layer_idx: position of this block in the stack (stored as-is).
    """

    def __init__(self, attention, feed_forward, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        width = attention.d_model
        # Registration order is kept stable: it fixes parameter/state_dict order.
        self.attention_norm = RMSNorm(width)
        self.ffn_norm = RMSNorm(width)
        self.attention = attention
        self.feed_forward = feed_forward

    def forward(self, x, mask=None, cache=None):
        # Each sub-layer sees a normalized input; its output is added back
        # onto the un-normalized residual stream.
        attended = self.attention(self.attention_norm(x), mask=mask, cache=cache)
        residual = x + attended
        transformed = self.feed_forward(self.ffn_norm(residual))
        return residual + transformed
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Divides each vector by its RMS and multiplies by a learned per-channel
    gain. Unlike LayerNorm it neither subtracts the mean nor adds a bias.

    Args:
        d_model: size of the last dimension of the input.
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        # Compute the statistic in float32. In fp16, squaring activations above
        # ~256 overflows to inf and the output collapses to zero; in bf16 the
        # mean loses precision. The result is cast back to the input dtype
        # before applying the gain.
        x_fp32 = x.float()
        inv_rms = x_fp32.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        return (x_fp32 * inv_rms).to(x.dtype) * self.weight
Training Infrastructure
class LongContextTrainer:
    """
    Trainer for LongContext Transformer with custom infrastructure.

    Each ``train_step`` call processes one micro-batch under bf16 autocast.
    Once every ``grad_accumulation`` micro-batches, the global gradient norm
    is clipped to 1.0 and the optimizer steps.
    """

    def __init__(self, config: 'LongContextConfig'):
        self.config = config
        self.model = LongContextTransformer(config)
        # Distributed setup
        self.setup_distributed()
        # Optimizer
        self.optimizer = self.create_optimizer()
        # Gradient handling. bf16 has fp32's exponent range, so loss scaling
        # (GradScaler) is unnecessary; it is only required for fp16.
        self.grad_accumulation = 4
        # Micro-batches processed so far; decides when the optimizer steps.
        self.step = 0
        # Monitoring
        self.metrics = TrainingMetrics()

    def setup_distributed(self):
        """Initialize distributed training (placeholder; setup code from Chapter 22)."""
        pass

    def create_optimizer(self):
        """Create AdamW with weight decay on matrices only.

        Parameters with ndim >= 2 (projections, embeddings) decay at 0.1.
        1-D parameters such as RMSNorm gains are not decayed, since decay
        would pull them toward zero.
        """
        decay, no_decay = [], []
        for _, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if param.ndim >= 2:
                decay.append(param)
            else:
                no_decay.append(param)
        return torch.optim.AdamW(
            [
                {'params': decay, 'weight_decay': 0.1},
                {'params': no_decay, 'weight_decay': 0.0},
            ],
            lr=1e-4,
            betas=(0.9, 0.95),
        )

    @staticmethod
    def _next_token_loss(logits, labels):
        """Cross-entropy where position t is scored against labels[:, t + 1]; -100 is ignored."""
        vocab = logits.size(-1)
        return nn.functional.cross_entropy(
            logits[:, :-1].reshape(-1, vocab).float(),
            labels[:, 1:].reshape(-1),
            ignore_index=-100,
        )

    def train_step(self, batch):
        """Forward/backward one micro-batch with gradient accumulation.

        Args:
            batch: dict with 'input_ids' and 'labels', both (batch, seq_len).
                If the model output has no 'loss', the loss is next-token
                cross-entropy of the logits against ``labels`` shifted by one.
                This assumes labels are aligned with input_ids, not
                pre-shifted — TODO confirm against the data pipeline.

        Returns:
            dict with this micro-batch's 'loss' as a float, before division
            by ``grad_accumulation``.
        """
        input_ids = batch['input_ids']
        labels = batch['labels']
        with torch.autocast(device_type=input_ids.device.type, dtype=torch.bfloat16):
            outputs = self.model(input_ids)
        loss = outputs.get('loss')
        if loss is None:
            loss = self._next_token_loss(outputs['logits'], labels)
        # Divide so the accumulated gradient equals the mean over the window.
        (loss / self.grad_accumulation).backward()
        self.step += 1
        # Step only after a full accumulation window (micro-batches 4, 8, ...).
        if self.step % self.grad_accumulation == 0:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.optimizer.zero_grad(set_to_none=True)
        loss_value = loss.item()
        # Monitor metrics
        self.metrics.record('loss', loss_value)
        return {'loss': loss_value}
Validation and Benchmarking
class LongContextValidator:
    """
    Validation suite for LongContext Transformer.

    Args:
        model: module called as ``model(input_ids, ...)`` that returns a dict
            containing 'logits' and, optionally, 'kv_cache'.
        tokenizer: stored for text-based checks; none of the checks below
            use it.
    """

    # Architecture checks run by validate_architecture, in this order. A name
    # with no matching method is reported as "not implemented".
    ARCHITECTURE_TESTS = (
        'test_output_shape',
        'test_attention_scores',
        'test_kv_cache',
        'test_gradient_flow',
        'test_numerical_stability',
    )

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def _device(self):
        """Device of the model's first parameter (CPU if it has none)."""
        first = next(self.model.parameters(), None)
        return first.device if first is not None else torch.device('cpu')

    def _vocab_size(self):
        """``model.config.vocab_size`` when available, else 32000."""
        return getattr(getattr(self.model, 'config', None), 'vocab_size', 32000)

    def _random_tokens(self, batch, length):
        """Uniform random token ids of shape (batch, length) on the model's device."""
        return torch.randint(0, self._vocab_size(), (batch, length), device=self._device())

    def validate_architecture(self):
        """Run every architecture check and collect one result dict per check.

        A missing check is reported as failed ('not implemented'). An
        AssertionError marks a check as failed, and any other exception is
        recorded as that check's error. Neither aborts the rest of the suite.
        """
        results = []
        for name in self.ARCHITECTURE_TESTS:
            test = getattr(self, name, None)
            if test is None:
                results.append({'test': name, 'passed': False, 'error': 'not implemented'})
                continue
            try:
                result = test()
            except AssertionError as e:
                results.append({'test': name, 'passed': False, 'error': str(e)})
            except Exception as e:  # suite boundary: record, don't abort
                results.append({'test': name, 'passed': False,
                                'error': f'{type(e).__name__}: {e}'})
            else:
                results.append({'test': name, 'passed': True, 'result': result})
        return results

    def test_output_shape(self):
        """Check that logits have shape (batch, seq_len, vocab_size)."""
        x = self._random_tokens(2, 1024)
        with torch.no_grad():
            output = self.model(x)
        expected = (2, 1024, self._vocab_size())
        actual = tuple(output['logits'].shape)
        assert actual == expected, f"Expected shape {expected}, got {actual}"

    def test_attention_scores(self):
        """Test attention score properties."""
        # Exercise: implement attention score validation. Until then this
        # check passes vacuously.
        pass

    def test_kv_cache(self, n_new_tokens=50, memory_budget_bytes=4e9):
        """Greedy-decode ``n_new_tokens`` tokens, feeding the returned cache back.

        Once the model returns a non-empty cache, only the newest token is fed,
        together with its absolute position. If the model returns an empty
        cache, the full sequence is re-fed each step. On CUDA, asserts that
        peak memory growth during decoding, above what was already allocated
        (e.g. the weights), stays under ``memory_budget_bytes``. This is a
        resource check only; it does not compare cached and uncached outputs.
        """
        device = self._device()
        on_cuda = device.type == 'cuda'
        x = self._random_tokens(1, 100)
        if on_cuda:
            torch.cuda.synchronize(device)
            torch.cuda.reset_peak_memory_stats(device)
            baseline = torch.cuda.memory_allocated(device)
        cache = None
        with torch.no_grad():
            for _ in range(n_new_tokens):
                if cache:
                    position = torch.full((x.size(0), 1), x.size(1) - 1,
                                          dtype=torch.long, device=device)
                    output = self.model(x[:, -1:], position_ids=position, kv_cache=cache)
                else:
                    output = self.model(x, kv_cache=cache)
                cache = output.get('kv_cache')
                next_token = output['logits'][:, -1].argmax(dim=-1, keepdim=True)
                x = torch.cat([x, next_token], dim=1)
        if on_cuda:
            growth = torch.cuda.max_memory_allocated(device) - baseline
            assert growth < memory_budget_bytes, (
                f"Decoding grew peak memory by {growth / 1e9:.2f} GB "
                f"(budget {memory_budget_bytes / 1e9:.2f} GB)"
            )

    def test_numerical_stability(self):
        """Check that a forward pass yields only finite logits (no NaN/Inf)."""
        x = self._random_tokens(1, 256)
        with torch.no_grad():
            logits = self.model(x)['logits']
        assert torch.isfinite(logits).all(), "Logits contain NaN or Inf"

    def benchmark_long_context(self, lengths=(4096, 16384, 65536, 102400)):
        """Time one no-grad forward pass at each context length.

        Returns:
            list of dicts with 'length', 'time' (seconds), 'memory_gb'
            (peak CUDA memory; 0.0 when the model is not on CUDA) and
            'throughput' (tokens per second).
        """
        device = self._device()
        on_cuda = device.type == 'cuda'
        results = []
        for length in lengths:
            x = self._random_tokens(1, length)
            if on_cuda:
                # Flush queued work so it is not counted in this measurement.
                torch.cuda.synchronize(device)
                torch.cuda.reset_peak_memory_stats(device)
            start = time.perf_counter()
            with torch.no_grad():
                self.model(x)
            if on_cuda:
                torch.cuda.synchronize(device)
            elapsed = time.perf_counter() - start
            memory = torch.cuda.max_memory_allocated(device) / 1e9 if on_cuda else 0.0
            results.append({
                'length': length,
                'time': elapsed,
                'memory_gb': memory,
                'throughput': length / elapsed,
            })
        return results
Project Submission Checklist
"""
LongContext Transformer Project Checklist
Architecture:
[ ] Configuration schema with validation
[ ] Complete model implementation
[ ] All attention mechanisms implemented
[ ] Positional encoding (RoPE with scaling)
[ ] Weight initialization strategy
[ ] Parameter count documentation
Training:
[ ] Distributed training setup
[ ] Mixed precision training
[ ] Gradient clipping
[ ] Learning rate schedule
[ ] Checkpoint management
[ ] Recovery from failures
Validation:
[ ] Unit tests for all components
[ ] Integration tests
[ ] Benchmark suite
[ ] Memory profiling
[ ] Throughput profiling
Documentation:
[ ] Architecture decision records
[ ] Configuration documentation
[ ] API documentation
[ ] Test documentation
[ ] Performance characterization
Deployment:
[ ] Inference optimization
[ ] KV cache management
[ ] Quantization support
[ ] Production readiness
"""
# Run validation
def run_project_validation(config, tokenizer=None):
    """Build the model, run the architecture checks and the benchmark.

    Args:
        config: LongContextConfig; ``config.validate()`` is called first.
        tokenizer: optional tokenizer handed to the validator.

    Returns:
        dict with 'architecture_tests', 'benchmark', 'config' (as a plain
        dict) and 'param_count'. 'param_count' is the number of unique
        parameters actually instantiated, so a tied embedding/lm_head matrix
        is counted once.
    """
    config.validate()
    model = LongContextTransformer(config)
    validator = LongContextValidator(model, tokenizer)
    # Architecture tests
    arch_results = validator.validate_architecture()
    # Benchmark
    benchmark_results = validator.benchmark_long_context()
    # Report
    return {
        'architecture_tests': arch_results,
        'benchmark': benchmark_results,
        'config': asdict(config),
        'param_count': sum(p.numel() for p in model.parameters()),
    }
EXERCISE
Complete the LongContext Transformer implementation by:
- Implementing the `_apply_rope` method with scaling
- Implementing the `test_attention_scores` and `test_gradient_flow` validation tests
- Running the validation suite and reporting results
- Creating ADR documents for 3 key architectural decisions
This completes the Custom LLM Architecture Design course. You now have the tools to design, implement, validate, and document custom transformer architectures optimized for your specific use case.