23. Architecture Documentation
Chapter 23 of 24 · 30 min
Documentation converts implicit knowledge into explicit specifications. Without it, architectural decisions become mysterious and unmaintainable.
Architecture Decision Records
# Documentation should capture WHY, not just WHAT
"""
Architecture Decision Record: Grouped Query Attention
Date: 2025-11-15
Status: Approved
Authors: Architecture Team

Context:
  - Model target: 70B parameters
  - Context window: 32K tokens
  - Memory constraint: 8x A100 per replica

Decision:
  Implement Grouped Query Attention with 8 KV heads instead of 32.

Consequences:
  - KV cache memory reduced by 4x (per layer, per 32K-token sequence in fp16:
    512 MB -> 128 MB)
  - Slight quality degradation: ~0.5% on MMLU (acceptable)
  - Inference throughput improvement: ~2.5x (the smaller cache fits longer
    contexts and larger batches)
  - Implementation complexity: moderate (requires cache management)

Alternatives Considered:
  1. Multi-Query Attention (1 KV head)
     - Memory savings: 32x vs. MHA (8x relative to the chosen 8-KV-head GQA)
     - Quality loss: ~2% (unacceptable for our use case)
  2. Flash Attention with full KV cache
     - No reduction in KV cache size
     - Flash Attention avoids materializing the T x T attention matrix (less
       activation memory and memory traffic), but it does not shrink the cache
  3. Paged Attention (vLLM style)
     - Complementary to GQA, not mutually exclusive
     - Recommended: combine with GQA for production

References:
  - Ainslie et al. 2023: GQA paper
  - Llama 2 technical report
  - Internal benchmarking results (doc/b1/benchmarks)
"""
import torch.nn as nn


class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention implementation.

    Configuration:
      - n_query_heads: 32 (full query capacity)
      - n_kv_heads: 8 (4x reduction from standard 32)

    Memory Analysis (fp16, head_dim=128, per layer):
      - Per-token KV cache: 8 heads * 128 dims * 2 (K and V) * 2 bytes = 4 KB
        (vs 16 KB for MHA)
      - Full 32K context: 32,768 tokens * 4 KB = 128 MB (vs 512 MB for MHA)

    Quality Impact:
      - Benchmarked on internal eval set: -0.3% accuracy
      - Impact concentrated in long-range dependency tasks
      - Mitigation: increased depth (40 layers vs 32)

    See: docs/decisions/gqa_implementation.md for detailed analysis
    """
    # Implementation omitted; this example focuses on the docstring.
Configuration Schema
# Configuration should be self-documenting
import json
from dataclasses import dataclass, asdict
from typing import Any, Dict, Optional


@dataclass
class ArchitectureConfig:
    """
    Complete architecture configuration with validation.
    Example YAML:
    ```yaml
    architecture:
      d_model: 4096
      n_layers: 32
      n_heads: 32
      n_kv_heads: 8
      d_ff: 11008
      vocab_size: 32000
      max_seq_len: 32768
      rope_theta: 10000
    ```
    """
    d_model: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: int = 8
    d_ff: int = 11008
    vocab_size: int = 32000
    max_seq_len: int = 32768
    rope_theta: float = 10000.0

    # Optional components
    attention_dropout: float = 0.0
    ffn_dropout: float = 0.0
    use_gated_geglu: bool = True
    use_rms_norm: bool = True

    # Advanced options
    custom_attention: Optional[str] = None
    init_method: str = 'kaiming'

    def __post_init__(self):
        """Validate configuration"""
        self._validate_dimensions()
        self._validate_attention_config()
        self._validate_ffn_config()

    def _validate_dimensions(self):
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})"
            )
        if self.d_model % self.n_kv_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by n_kv_heads ({self.n_kv_heads})"
            )

    def _validate_attention_config(self):
        if self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
            )
        if not 0.0 <= self.attention_dropout <= 1.0:
            raise ValueError("attention_dropout must be in [0, 1]")

    def _validate_ffn_config(self):
        if self.d_ff <= 0:
            raise ValueError("d_ff must be positive")
        if not 0.0 <= self.ffn_dropout <= 1.0:
            raise ValueError("ffn_dropout must be in [0, 1]")

    def to_dict(self) -> Dict[str, Any]:
        """Export configuration as a flat dictionary of all fields"""
        return asdict(self)

    def to_yaml(self) -> str:
        """Export as YAML string (requires PyYAML)"""
        import yaml
        return yaml.dump(self.to_dict(), default_flow_style=False)

    def estimate_params(self) -> int:
        """Estimate total parameters"""
        # See Chapter 21 for implementation
        raise NotImplementedError("See Chapter 21")

    @classmethod
    def _from_mapping(cls, data: Dict[str, Any]) -> 'ArchitectureConfig':
        # Accept a flat mapping or one nested under an 'architecture' key,
        # as in the YAML example in the class docstring
        if 'architecture' in data:
            data = data['architecture']
        return cls(**data)

    @classmethod
    def from_json(cls, json_str: str) -> 'ArchitectureConfig':
        """Load from JSON string"""
        return cls._from_mapping(json.loads(json_str))

    @classmethod
    def from_yaml(cls, yaml_str: str) -> 'ArchitectureConfig':
        """Load from YAML string (requires PyYAML)"""
        import yaml
        return cls._from_mapping(yaml.safe_load(yaml_str))
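The config leaves `estimate_params` to Chapter 21. For a quick sanity check while editing a config, a rough count fits in a few lines. The sketch below is a hypothetical stand-alone helper, not the chapter's implementation, and it assumes a Llama-style block: bias-free linear layers, a gated FFN with three `d_model x d_ff` matrices (as implied by `use_gated_geglu=True`), two RMSNorm weight vectors per layer plus a final norm, and an output projection that is not tied to the token embedding.

```python
def estimate_params(d_model: int, n_layers: int, n_heads: int, n_kv_heads: int,
                    d_ff: int, vocab_size: int, gated_ffn: bool = True,
                    tie_embeddings: bool = False) -> int:
    head_dim = d_model // n_heads
    # Attention: Q and output projections are full width; K and V shrink with n_kv_heads
    attn = 2 * d_model * (n_heads * head_dim) + 2 * d_model * (n_kv_heads * head_dim)
    # FFN: gate + up + down for a gated unit, up + down otherwise
    ffn = (3 if gated_ffn else 2) * d_model * d_ff
    norms = 2 * d_model  # pre-attention and pre-FFN RMSNorm weights
    per_layer = attn + ffn + norms

    embeddings = vocab_size * d_model
    lm_head = 0 if tie_embeddings else vocab_size * d_model
    final_norm = d_model
    return n_layers * per_layer + embeddings + lm_head + final_norm


# Defaults from ArchitectureConfig, with 8 KV heads vs. full multi-head attention
gqa = estimate_params(4096, 32, 32, 8, 11008, 32000)
mha = estimate_params(4096, 32, 32, 32, 11008, 32000)
print(f"{gqa:,}")  # 5,933,109,248
print(f"{mha:,}")  # 6,738,415,616
```

The difference between the two (805,306,368 parameters) is exactly the K and V projection weights that GQA removes. Under these assumptions the MHA count also matches the commonly reported size of Llama 2 7B, whose shape the config defaults follow apart from `n_kv_heads`.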
Component Documentation Template
"""
Component: Sliding Window Attention
File: src/attention/sliding_window.py
Description:
Sliding window attention with configurable window size.
Each token attends to a fixed-size window centered on its position.
This implementation uses chunked processing to maintain O(T * W) complexity
where W is window size, rather than O(T^2).
Used in layers 0-15 for local pattern recognition.
Layers 16-31 use full attention (no sliding) for global context.
Configuration:
window_size: 512 (default)
- 512 tokens on each side
- Total receptive field per layer: 1025 tokens
dilation: 1 (default)
- Set to 2, 4, 8 in subsequent layers for dilated attention
- Increases effective receptive field without compute increase
Implementation Notes:
- Mask is computed per-chunk, not per-sequence (memory efficient)
- Supports variable-length sequences via cu_seqlens
- Compatible with Flash Attention backend when available
Failure Modes:
1. Window overflow at sequence boundaries
- Fixed: padding with zeros, masking in softmax
2. Dilation causing gaps in attention
- Fixed: ensure dilation * window_size >= max_position
3. Gradient instability with large dilation
- Fixed: gradient clipping at 1.0, warmup 2000 steps
Performance:
- Memory: O(T * W) vs O(T^2) for full attention
- Throughput: ~3x faster than full attention for T=4096, W=512
- Quality: Equivalent to full attention for NLP tasks (local dependencies)
- Quality loss: ~1% on tasks requiring long-range reasoning
"""
class SlidingWindowAttention(nn.Module):
# Implementation...
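The `window_size` and `dilation` settings in the template determine which key positions each query may attend to. The sketch below builds that boolean mask in pure Python for a small sequence. It assumes the bidirectional, centered window described above (`window_size` tokens on each side) and Longformer-style dilation, where a query attends to every `dilation`-th position within `window_size * dilation`. The function name is hypothetical, and it materializes the full T x T mask only for illustration; the chunked implementation described in the template would avoid that.

```python
from typing import List


def sliding_window_mask(seq_len: int, window_size: int, dilation: int = 1) -> List[List[bool]]:
    """mask[i][j] is True when query i may attend to key j."""
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            offset = abs(i - j)
            in_window = offset <= window_size * dilation
            on_stride = offset % dilation == 0  # dilation skips intermediate positions
            row.append(in_window and on_stride)
        mask.append(row)
    return mask


mask = sliding_window_mask(seq_len=10, window_size=2)
print([sum(row) for row in mask])  # [3, 4, 5, 5, 5, 5, 5, 5, 4, 3]

dilated = sliding_window_mask(seq_len=10, window_size=2, dilation=2)
print([j for j, ok in enumerate(dilated[5]) if ok])  # [1, 3, 5, 7, 9]
```

Interior rows always contain `2 * window_size + 1` allowed positions regardless of sequence length, which is where the O(T * W) cost comes from; with `window_size=512` that is the 1025-token receptive field quoted above. With dilation, each row keeps the same number of entries but spans a wider range of positions, so the receptive field grows without extra compute. Rows near the boundaries have fewer valid positions, which is the boundary case the failure-mode notes handle with padding and masking.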
API Documentation
"""
API: CustomAttentionFactory
Factory for creating attention mechanisms with validated configurations.
Methods:
create(config: AttentionConfig) -> nn.Module
Create attention module from configuration.
Args:
config: AttentionConfig with type, parameters, and options
Returns:
Attention module (FullAttention, GQA, SlidingWindow, etc.)
Raises:
ValueError: If configuration is invalid
NotImplementedError: If attention type not supported
Example:
>>> config = AttentionConfig(type='gqa', n_heads=32, n_kv_heads=8)
>>> attn = CustomAttentionFactory.create(config)
>>> output = attn(x, kv_cache=None)
register_variant(name: str, cls: Type)
Register custom attention variant.
Args:
name: Identifier for the variant
cls: Attention module class
Example:
>>> CustomAttentionFactory.register_variant('my_attention', MyAttention)
list_variants() -> List[str]
List all available attention variants.
Configuration Schema:
AttentionConfig:
type: Literal['full', 'gqa', 'mqa', 'sliding', 'dilated']
n_heads: int (required)
n_kv_heads: int (optional, default=n_heads)
window_size: int (optional, for sliding window)
head_dim: int (optional, default=128)
dropout: float (optional, default=0.0)
"""
class AttentionConfig:
"""Configuration for attention mechanisms."""
SUPPORTED_TYPES = ['full', 'gqa', 'mqa', 'sliding', 'dilated']
def __init__(self, type: str, n_heads: int, **kwargs):
if type not in self.SUPPORTED_TYPES:
raise ValueError(f"type must be one of {self.SUPPORTED_TYPES}")
self.type = type
self.n_heads = n_heads
# ... validate and store other parameters
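The factory API documented above can be backed by a plain dictionary registry. The sketch below is a hypothetical, framework-free illustration: `FullAttention` and `GQAttention` here are placeholder classes standing in for real `nn.Module` implementations, the `cls` argument of `register_variant` is renamed `variant_cls` because `cls` is taken by the classmethod, and `create` only assumes a config object exposing `type`, as `AttentionConfig` does. Config validation (the documented `ValueError`) stays in the config class; the factory raises `NotImplementedError` for unregistered types.

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Type


class CustomAttentionFactory:
    """Registry-backed factory (sketch; real variants would be nn.Module subclasses)."""

    _registry: Dict[str, Type] = {}

    @classmethod
    def register_variant(cls, name: str, variant_cls: Type) -> None:
        if name in cls._registry:
            raise ValueError(f"variant '{name}' is already registered")
        cls._registry[name] = variant_cls

    @classmethod
    def list_variants(cls) -> List[str]:
        return sorted(cls._registry)

    @classmethod
    def create(cls, config: Any):
        try:
            variant_cls = cls._registry[config.type]
        except KeyError:
            raise NotImplementedError(
                f"attention type '{config.type}' is not supported") from None
        return variant_cls(config)


# Hypothetical placeholder variants that just keep their config
class FullAttention:
    def __init__(self, config):
        self.config = config


class GQAttention(FullAttention):
    pass


CustomAttentionFactory.register_variant('full', FullAttention)
CustomAttentionFactory.register_variant('gqa', GQAttention)
print(CustomAttentionFactory.list_variants())  # ['full', 'gqa']

attn = CustomAttentionFactory.create(SimpleNamespace(type='gqa', n_heads=32, n_kv_heads=8))
print(type(attn).__name__)  # GQAttention
```

Rejecting duplicate names keeps a plugin from silently shadowing a built-in variant, and keeping the registry at class level makes variants registered anywhere visible to every caller of `create`.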
Testing Documentation
"""
Test Suite: Custom Attention Mechanisms
Location: tests/test_custom_attention.py
Coverage:
1. Forward pass correctness
- Output shape matches expected dimensions
- Attention scores sum to 1 (row-wise)
- Deterministic behavior with fixed seed
2. Backward pass correctness
- Gradients exist for all learnable parameters
- Gradient norms are finite (not NaN/Inf)
- Gradient shapes match parameter shapes
3. Numerical stability
- No NaN in forward/backward for valid inputs
- No Inf in attention scores
- Stable for extreme input values
4. Memory efficiency
- Memory usage scales correctly with sequence length
- KV cache correctly managed
- No memory leaks in repeated calls
5. Integration
- Works within full transformer stack
- Compatible with FSDP training
- Compatible with generation loop
Test Fixtures:
- small_config: d_model=256, n_layers=2 (unit tests)
- medium_config: d_model=1024, n_layers=8 (integration tests)
- large_config: d_model=4096, n_layers=32 (performance tests)
Run Tests:
pytest tests/test_custom_attention.py -v
pytest tests/test_custom_attention.py -v --benchmark
"""
EXERCISE
Create an ADR document for a custom architectural choice you would make (e.g., custom activation function, unusual depth/width ratio, or novel attention pattern). Include context, decision, consequences, and alternatives considered.