19. Architecture Prototyping
Chapter 19 of 24 · 25 min
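Prototyping a custom architecture at small scale before committing to full-scale training avoids wasting large amounts of compute on a design that fails. A well-designed prototype reveals failure modes such as NaN losses, exploding gradients and memory blowups early, while they are still cheap to fix.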
Prototype Framework
import math
import random
import time
from typing import Dict, List, Optional

import torch
import torch.nn as nn
class ArchitecturePrototype:
    """
    Framework for rapid architecture prototyping.

    Test architectural ideas at small scale before committing to a full run.
    All benchmarks use synthetic token batches of shape
    ``(batch_size, seq_len)`` with ids drawn uniformly from ``[0, vocab_size)``.

    Args:
        seq_len: Sequence length of the synthetic batches.
        batch_size: Number of sequences per synthetic batch.
        vocab_size: Exclusive upper bound for sampled token ids; also passed
            as ``vocab_size`` in the config built by ``create_small_variant``.
    """

    def __init__(self, seq_len=512, batch_size=4, vocab_size=32000):
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.results = {}

    @staticmethod
    def _scale_config(config: Dict, scale_factor: int = 4) -> Dict:
        """
        Shrink the width, depth and head count of ``config`` by ``scale_factor``.

        If ``d_ff`` is present, it is scaled directly. Otherwise it is derived
        from the scaled ``d_model`` and ``config['d_ff_ratio']`` (default 4).
        The ratio form is what ``search_architecture_space`` samples.
        """
        d_model = max(1, config['d_model'] // scale_factor)
        if 'd_ff' in config:
            d_ff = max(1, config['d_ff'] // scale_factor)
        else:
            d_ff = d_model * config.get('d_ff_ratio', 4)
        # Integer division can reach 0 heads, and nn.MultiheadAttention needs
        # d_model % n_heads == 0, so fall back to the largest valid head count.
        n_heads = max(1, config['n_heads'] // scale_factor)
        while d_model % n_heads:
            n_heads -= 1
        return {
            'd_model': d_model,
            'n_layers': max(2, config['n_layers'] // scale_factor),
            'n_heads': n_heads,
            'd_ff': d_ff,
        }

    def create_small_variant(self, config: Dict, scale_factor: int = 4) -> nn.Module:
        """
        Create a scaled-down version of the architecture for testing.

        Scales by reducing layers, hidden size, heads and FFN width by
        ``scale_factor``. The config may give ``d_ff`` directly or a
        ``d_ff_ratio``.
        """
        small_config = self._scale_config(config, scale_factor)
        small_config['vocab_size'] = config.get('vocab_size', self.vocab_size)
        return self.build_architecture(small_config)

    def build_architecture(self, config: Dict) -> nn.Module:
        """Build architecture from config"""
        model = PrototypeTransformer(config)
        return model

    def _random_tokens(self, device) -> torch.Tensor:
        """Sample a ``(batch_size, seq_len)`` tensor of token ids on ``device``."""
        return torch.randint(0, self.vocab_size, (self.batch_size, self.seq_len), device=device)

    @staticmethod
    def _global_grad_norm(model: nn.Module) -> float:
        """
        Return the L2 norm over all parameter gradients (0.0 if none exist).

        This equals sqrt(sum of squared per-tensor norms), the same quantity
        ``clip_grad_norm_`` uses. Summing the per-tensor norms instead would
        overstate the norm and grow with the number of tensors, which makes a
        fixed threshold incomparable across configs.
        """
        norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
        if not norms:
            return 0.0
        return torch.stack(norms).norm(2).item()

    def run_forward_pass(self, model: nn.Module, iterations=50, warmup=10,
                         track_grad=False) -> Dict:
        """
        Benchmark the forward pass of ``model``.

        Args:
            model: Module called as ``model(x)`` on token ids. The device of
                its first parameter is used.
            iterations: Number of timed forward passes (must be >= 1).
            warmup: Untimed passes run first, so that one-time setup costs do
                not skew the timings.
            track_grad: If False (default), run with autograd disabled so the
                numbers reflect inference. Set True to include the activation
                memory that autograd retains for a backward pass.

        Returns:
            Dict with ``mean_time_ms``, ``p50_time_ms`` and ``peak_memory_gb``.
            On CUDA, timing uses CUDA events and memory comes from
            ``torch.cuda.max_memory_allocated``. On other devices, timing uses
            ``time.perf_counter`` and ``peak_memory_gb`` is 0.0 (not measured).

        Raises:
            ValueError: If ``iterations`` < 1.
        """
        if iterations < 1:
            raise ValueError(f"iterations must be >= 1, got {iterations}")
        device = next(model.parameters()).device
        on_cuda = device.type == 'cuda'
        x = self._random_tokens(device)
        times = []
        memory_usages = []
        with torch.set_grad_enabled(track_grad):
            for _ in range(warmup):
                model(x)
            if on_cuda:
                # Drain queued warmup kernels so they are not timed in pass #1.
                torch.cuda.synchronize(device)
            for _ in range(iterations):
                if on_cuda:
                    torch.cuda.reset_peak_memory_stats(device)
                    start = torch.cuda.Event(enable_timing=True)
                    end = torch.cuda.Event(enable_timing=True)
                    start.record()
                    model(x)
                    end.record()
                    torch.cuda.synchronize(device)
                    times.append(start.elapsed_time(end))
                    memory_usages.append(torch.cuda.max_memory_allocated(device))
                else:
                    t0 = time.perf_counter()
                    model(x)
                    times.append((time.perf_counter() - t0) * 1000.0)
                    memory_usages.append(0)
        return {
            'mean_time_ms': sum(times) / len(times),
            'p50_time_ms': sorted(times)[len(times) // 2],
            'peak_memory_gb': max(memory_usages) / 1e9,
        }

    def run_training_step(self, model: nn.Module, iterations=10, lr=1e-4) -> Dict:
        """
        Run ``iterations`` AdamW steps on one fixed random batch.

        ``model(x, labels=y)`` must return a dict whose ``'loss'`` entry is a
        tensor.

        Returns:
            On success, a dict with ``status='success'``, the per-step
            ``losses``, ``final_loss``, and ``max_grad`` / ``mean_grad``.
            ``max_grad`` and ``mean_grad`` summarize the global L2 gradient
            norm across steps. If the loss or the gradient norm becomes NaN or
            infinite, returns ``{'status': 'NaN', 'step': i}`` for that step.

        Raises:
            ValueError: If ``iterations`` < 1.
        """
        if iterations < 1:
            raise ValueError(f"iterations must be >= 1, got {iterations}")
        device = next(model.parameters()).device
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        x = self._random_tokens(device)
        y = self._random_tokens(device)
        losses = []
        grad_norms = []
        for step in range(iterations):
            optimizer.zero_grad()
            loss = model(x, labels=y)['loss']
            # isfinite also catches +/-inf, which would poison AdamW just like NaN.
            if not torch.isfinite(loss).all():
                return {'status': 'NaN', 'step': step}
            loss.backward()
            grad_norm = self._global_grad_norm(model)
            if not math.isfinite(grad_norm):
                return {'status': 'NaN', 'step': step}
            grad_norms.append(grad_norm)
            optimizer.step()
            losses.append(loss.item())
        return {
            'status': 'success',
            'losses': losses,
            'final_loss': losses[-1],
            'max_grad': max(grad_norms),
            'mean_grad': sum(grad_norms) / len(grad_norms),
        }
class PrototypeTransformer(nn.Module):
    """
    Simplified transformer language model for rapid prototyping.

    Config keys:
        d_model, n_layers, n_heads: required.
        d_ff: feed-forward hidden width (default ``4 * d_model``).
        vocab_size: embedding / output vocabulary size (default 32000).

    NOTE(review): there is no positional encoding, so any position
    information comes only from masking inside the blocks. Add one if the
    prototype should match a production model that has it.
    """

    def __init__(self, config: Dict):
        super().__init__()
        d_model = config['d_model']
        n_layers = config['n_layers']
        n_heads = config['n_heads']
        d_ff = config.get('d_ff', d_model * 4)
        vocab_size = config.get('vocab_size', 32000)
        # Token ids must become d_model-wide vectors before the first
        # LayerNorm/attention; integer ids cannot be fed to the blocks directly.
        self.embed = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            PrototypeTransformerBlock(d_model, n_heads, d_ff)
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x, labels=None):
        """
        Run the model on token ids.

        Args:
            x: Integer token ids, presumably of shape (batch, seq) (TODO
                confirm). They are embedded to ``(..., d_model)`` before the
                blocks run.
            labels: Optional targets compared position-by-position with the
                logits; no shift is applied here. Positions equal to -100 are
                ignored.

        Returns:
            Dict with ``logits`` of shape ``x.shape + (vocab_size,)`` and
            ``loss``, the mean cross-entropy, or None when ``labels`` is None.
        """
        h = self.embed(x)
        for layer in self.layers:
            h = layer(h)
        logits = self.lm_head(self.norm(h))
        loss = None
        if labels is not None:
            # reshape (not view) tolerates non-contiguous label tensors.
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        return {'logits': logits, 'loss': loss}
class PrototypeTransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.GELU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
def forward(self, x):
# Pre-norm architecture
x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
x = x + self.ffn(self.norm2(x))
return x
Automated Hyperparameter Search
def search_architecture_space(search_space: Dict, budget=10, seed: Optional[int] = None,
                              device=None) -> List[Dict]:
    """
    Automated random search over an architecture hyperparameter space.

    Useful for finding good configurations before full training. Each sampled
    config is shrunk with ``ArchitecturePrototype.create_small_variant``,
    trained for a few steps, benchmarked, and scored with ``score_config``.

    Args:
        search_space: Maps 'd_model', 'n_layers', 'n_heads' and 'd_ff_ratio'
            to lists of candidate values.
        budget: Number of configurations to sample and evaluate.
        seed: If given, sampling uses a private ``random.Random(seed)`` so the
            search is reproducible. Otherwise the global ``random`` module
            state is used.
        device: Device to place each model on. Defaults to 'cuda' when CUDA is
            available, else 'cpu'.

    Returns:
        One result dict per sampled config (``config``, ``training``,
        ``forward``, ``score``), sorted by ``score`` from best to worst.
    """
    rng = random if seed is None else random.Random(seed)
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    results = []
    for _ in range(budget):
        # Random sample from search space
        config = {
            'd_model': rng.choice(search_space['d_model']),
            'n_layers': rng.choice(search_space['n_layers']),
            'n_heads': rng.choice(search_space['n_heads']),
            'd_ff_ratio': rng.choice(search_space['d_ff_ratio']),
        }
        # create_small_variant scales an absolute FFN width ('d_ff'), so
        # derive it from the sampled ratio.
        config['d_ff'] = config['d_model'] * config['d_ff_ratio']
        prototype = ArchitecturePrototype()
        # Modules are created on the default device; move the model so the
        # benchmarks run where the CUDA timers expect it.
        model = prototype.create_small_variant(config).to(device)
        # Test training stability
        training_result = prototype.run_training_step(model)
        forward_result = prototype.run_forward_pass(model)
        results.append({
            'config': config,
            'training': training_result,
            'forward': forward_result,
            'score': score_config(training_result, forward_result),
        })
        # Drop this model before the next one is built, to keep peak memory down.
        del model
    return sorted(results, key=lambda r: r['score'], reverse=True)
def score_config(training_result, forward_result, *, nan_penalty=-1000,
                 memory_weight=0.1, grad_threshold=10, stable_bonus=100,
                 unstable_penalty=-10):
    """
    Score a configuration based on stability and efficiency. Higher is better.

    Args:
        training_result: Dict with 'status'. When the status is 'success',
            it also needs 'final_loss' and 'max_grad'.
        forward_result: Dict with 'peak_memory_gb'. It is only read when
            training succeeded.
        nan_penalty: Score returned when training did not succeed.
        memory_weight: Points subtracted per GB of peak memory.
        grad_threshold: ``max_grad`` strictly below this counts as stable.
        stable_bonus: Points added for stable gradients.
        unstable_penalty: Points added (normally negative) otherwise. A NaN
            ``max_grad`` also lands here, because NaN < threshold is False.

    Returns:
        The combined score. With the defaults it is identical to the original
        fixed weights.
    """
    if training_result['status'] != 'success':
        return nan_penalty  # Penalize NaN / failed training
    # Reward low loss, low memory, stable gradients
    score = 0
    score -= training_result['final_loss']  # Lower is better
    score -= forward_result['peak_memory_gb'] * memory_weight
    if training_result['max_grad'] < grad_threshold:
        score += stable_bonus
    else:
        score += unstable_penalty
    return score
# Candidate values for each architecture hyperparameter, sampled by
# search_architecture_space(). These are full-size values; the prototype
# shrinks them before building a model.
search_space = {
    'd_model': [256, 512, 768, 1024],   # model width
    'n_layers': [2, 4, 8, 12],          # transformer depth
    'n_heads': [4, 8, 16],              # attention heads
    'd_ff_ratio': [2, 4, 8],            # feed-forward width ratio
}
Failure Mode: Prototype Not Representative
# BUG: Testing at small scale doesn't predict large-scale behavior
class UnrepresentativePrototype:
    """
    Small prototype that masks issues present at scale.

    Common issues:
    - Batch norm statistics unreliable with small batches
    - Gradient accumulation effects not captured
    - Memory patterns different at scale
    """

    def test_architecture(self):
        """Return the deliberately too-small test settings this prototype uses."""
        # Testing at seq_len=128 won't reveal 16K context issues
        seq_len = 128  # Too short
        # Small batch masks memory issues
        batch_size = 2  # Too small
        # Only testing 100 steps misses long-term instability
        steps = 100  # Too few
        # Results look great, but full-scale training fails
        return {'seq_len': seq_len, 'batch_size': batch_size, 'steps': steps}
# FIX: Design representative prototype
class RepresentativePrototype:
    """Prototype whose test settings mirror the production training run."""

    def test_architecture(self, seq_len=2048, effective_batch=32, micro_batch=4, steps=1000):
        """
        Return test settings representative of full-scale training.

        Args:
            seq_len: Maximum sequence length expected in production. For
                example, use 16384 if the target is 16K context; 2048 is only
                a default.
            effective_batch: Production batch size, simulated via gradient
                accumulation.
            micro_batch: Per-step batch that actually fits in memory.
            steps: Training steps; at least ~1000 to catch loss spikes.

        Returns:
            Dict with the settings plus ``grad_accum_steps``, the number of
            micro-batches accumulated per optimizer step.

        Raises:
            ValueError: If ``effective_batch`` or ``micro_batch`` is < 1.
        """
        if effective_batch < 1 or micro_batch < 1:
            raise ValueError("effective_batch and micro_batch must be >= 1")
        # Ceiling division so the simulated batch is never smaller than requested.
        grad_accum_steps = -(-effective_batch // micro_batch)
        # Also test with a realistic data distribution, not uniform random tokens.
        return {
            'seq_len': seq_len,
            'effective_batch': effective_batch,
            'micro_batch': micro_batch,
            'grad_accum_steps': grad_accum_steps,
            'steps': steps,
        }
EXERCISE
Design a prototype that tests the training stability of sliding-window attention at sequence lengths of 512, 2048, and 8192. At each length, measure the global gradient norm (the L2 norm over all parameters' gradients), and identify the sequence length at which instability first appears.