13. Comparing Architectures
Chapter 13 of 24 · 25 min
Architecture selection is the highest-leverage decision in LLM development. A poor choice compounds through every subsequent training hour and inference cycle.
Transformer Variants
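The original Transformer architecture used multi-head attention (MHA): every head has its own queries, keys, and values, and every head attends to every position. Computing the full attention matrix therefore costs O(n²) compute and memory in the sequence length n. During autoregressive generation, the model must also cache the keys and values of every head for every past position. Modern variants make deliberate trade-offs: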
import torch
import torch.nn as nn
class FullAttention(nn.Module):
    """Original O(n²) multi-head attention: every head scores every (query, key) pair.

    The forward pass materializes a (B, n_heads, T, T) score tensor, so memory
    and compute grow quadratically with the sequence length T.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of heads. Each head gets ``d_model // n_heads``
            features and its own keys and values.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        # Fail at construction instead of with an opaque view() error in forward.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # One fused projection produces queries, keys and values.
        self.qkv = nn.Linear(d_model, 3 * d_model)
        # Output projection that mixes the concatenated heads.
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, *, causal=False):
        """Attend over ``x`` and return a tensor of the same shape.

        Args:
            x: Input of shape (B, T, d_model).
            causal: If True, position t only attends to positions <= t
                (decoder-style masking). Default False: every token attends
                to every token.

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (B, T, C) -> (B, H, T, head_dim) so the matmuls batch over heads.
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # Full attention matrix: (B, H, T, T)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if causal:
            # True strictly above the diagonal = future positions to hide.
            future = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(1)
            attn = attn.masked_fill(future, float("-inf"))
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, C)
        # The output projection was previously defined but never applied.
        return self.proj(out)
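Grouped Query Attention (GQA) keeps the full set of query heads but reduces the number of key-value heads, so each group of query heads shares one key-value head. The KV cache stores only the key-value heads, so it shrinks by a factor of n_query_heads / n_kv_heads. This dramatically improves KV cache efficiency during generation: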
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention: M query heads share N <= M key-value heads.

    Query heads are split into ``n_kv_heads`` groups of
    ``n_query_heads // n_kv_heads`` consecutive heads. Every head in a group
    reads the same key/value head, so query head i uses KV head
    ``i // group_size``.

    Args:
        d_model: Input/output width.
        n_query_heads: Number of query heads (M).
        n_kv_heads: Number of key/value heads (N); must be positive and
            divide ``n_query_heads``.
        head_dim: Per-head feature size.

    Raises:
        ValueError: If ``n_kv_heads`` is not a positive divisor of
            ``n_query_heads``.
    """

    def __init__(self, d_model, n_query_heads, n_kv_heads, head_dim=64):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under -O.
        if n_kv_heads <= 0 or n_query_heads % n_kv_heads != 0:
            raise ValueError(
                f"n_kv_heads ({n_kv_heads}) must be a positive divisor of "
                f"n_query_heads ({n_query_heads})"
            )
        self.n_query_heads = n_query_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.q = nn.Linear(d_model, n_query_heads * head_dim)
        # Fused key/value projection: only n_kv_heads heads each.
        self.kv = nn.Linear(d_model, 2 * n_kv_heads * head_dim)
        self.o = nn.Linear(n_query_heads * head_dim, d_model)

    def forward(self, x, cache=None, *, causal=False):
        """Attend over ``x`` and return a (B, T, d_model) tensor.

        Args:
            x: Input of shape (B, T, d_model).
            cache: Accepted for interface compatibility but currently ignored;
                no keys/values are read from or written to a cache.
            causal: If True, position t only attends to positions <= t.

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, _ = x.shape
        # Move heads in front of time: (B, heads, T, head_dim). Without this
        # transpose, repeat_interleave(dim=1) below would repeat time steps.
        q = self.q(x).view(B, T, self.n_query_heads, self.head_dim)
        q = q.transpose(1, 2)
        k, v = self.kv(x).chunk(2, dim=-1)
        k = k.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # Expand K/V along the head axis so query head i sees KV head
        # i // group_size and a plain batched matmul can be used.
        group_size = self.n_query_heads // self.n_kv_heads
        k = k.repeat_interleave(group_size, dim=1)
        v = v.repeat_interleave(group_size, dim=1)
        # Standard scaled dot-product attention: (B, Hq, T, T) scores.
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.head_dim ** -0.5
        if causal:
            future = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(1)
            attn = attn.masked_fill(future, float("-inf"))
        attn = attn.softmax(dim=-1)
        out = torch.matmul(attn, v)
        # Back to (B, T, Hq * head_dim); reshape handles the non-contiguous
        # layout left by transpose, then apply the output projection.
        out = out.transpose(1, 2).reshape(B, T, self.n_query_heads * self.head_dim)
        return self.o(out)
Comparative Analysis Matrix
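| Architecture | KV cache per layer | Speed (Inference) | Quality Loss | Best For |
|---|---|---|---|---|
| Full MHA | O(n × h × d) | Slow | None | Research, small models |
| GQA | O(n × g × d) | Fast | Minimal | Production LLMs |
| MQA | O(n × d) | Fastest | Moderate | Memory-constrained |
| Sliding Window | O(w × h × d) | Moderate | No direct attention beyond the window | Long documents |

Here n = sequence length, h = query heads, g = key-value heads, d = head dimension, and w = window size. MHA, GQA, and MQA all still compute an O(n²) attention-score matrix over the full sequence; sliding-window attention reduces that to O(n × w).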
Common Failure Mode: Head Dimension Mismatch
# INCORRECT - crashes at runtime
class BrokenAttention(nn.Module):
    """Deliberately broken example: ``head_dim`` ends up a float (768 / 12 == 64.0)."""

    def __init__(self, d_model=768, n_heads=12):
        super().__init__()
        self.head_dim = d_model / n_heads  # Float division in Python 3: 64.0, not 64
        self.qkv = nn.Linear(d_model, 3 * d_model)
        # Later: q.view(B, T, n_heads, self.head_dim)
        # TypeError: view() accepts only integer sizes, and 64.0 is a float
        # (PyTorch reports that 'size' must be a tuple of ints but found a float).
# CORRECT
class CorrectAttention(nn.Module):
    """Fixed version: integer head size, with divisibility checked up front.

    ``//`` alone is not enough: if ``d_model % n_heads != 0`` it silently
    truncates, and the later ``view(B, T, n_heads, head_dim)`` still fails
    because ``n_heads * head_dim != d_model``.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model=768, n_heads=12):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.head_dim = d_model // n_heads  # Integer division
        self.qkv = nn.Linear(d_model, 3 * d_model)
Benchmarking Comparison
# Compare peak forward-pass (prefill) activation memory of attention variants.
# Requires a CUDA GPU: torch.cuda.* statistics only track CUDA allocations, so
# the modules and the input must live on the GPU (on CPU they report ~0).
# Both variants still materialize a (1, 32, T, T) score tensor, so prefill peaks
# are similar; GQA's big saving is the 4x smaller KV cache during decoding.
python -c "
import torch
from your_module import FullAttention, GroupedQueryAttention

if not torch.cuda.is_available():
    raise SystemExit('This benchmark needs a CUDA device.')

seq_len = 4096
d_model = 4096
n_heads = 32
device = 'cuda'

def peak_forward_gb(make_module):
    # Fresh module per measurement; it is freed when this function returns,
    # so one variant's allocations do not leak into the next measurement.
    module = make_module().to(device)
    x = torch.randn(1, seq_len, d_model, device=device)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # Subtract weights + input so only forward-pass activations are counted.
    baseline = torch.cuda.memory_allocated()
    with torch.inference_mode():
        module(x)
    torch.cuda.synchronize()
    return (torch.cuda.max_memory_allocated() - baseline) / 1e9

full_mem = peak_forward_gb(lambda: FullAttention(d_model, n_heads))
# Same query-head count and head size as FullAttention, but 8 shared K/V heads.
gqa_mem = peak_forward_gb(
    lambda: GroupedQueryAttention(d_model, n_heads, 8, head_dim=d_model // n_heads))

print(f'Full: {full_mem:.2f}GB, GQA: {gqa_mem:.2f}GB')
print(f'Savings: {(1 - gqa_mem/full_mem)*100:.1f}%')
"
EXERCISE
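Implement both FullAttention and GQA variants. Generation requires causal masking and a key-value cache, so make sure both variants support them before timing. Measure inference throughput (tokens/second) for a 2048-token generation run using torch.cuda.Event timing: record a start and an end event, then call torch.cuda.synchronize() before reading elapsed_time(). Document the throughput ratio in a comparison table.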