13. Comparing Architectures

Chapter 13 of 24 · 25 min

Architecture selection is the highest-leverage decision in LLM development. A poor choice compounds through every subsequent training hour and inference cycle.

Transformer Variants

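The original Transformer used multi-head attention (MHA), in which every query head has its own key head and value head. Every position attends to every other position, so the attention score matrix costs O(n²) compute and memory in the sequence length n, and during generation the KV cache must hold keys and values for every head. Modern variants make deliberate trade-offs against this baseline: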

import torch
import torch.nn as nn

class FullAttention(nn.Module):
    """Original O(n²) attention - every token attends to everything"""
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        
        # Full attention matrix: (B, H, T, T)
        attn = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn = attn.softmax(dim=-1)
        
        # Merge heads back to (B, T, C), then apply the output projection
        out = (attn @ v).transpose(1, 2).reshape(B, T, C)
        return self.proj(out)

Grouped Query Attention (GQA) shares each key-value head across a group of query heads. The KV cache shrinks by a factor of n_query_heads / n_kv_heads, which reduces memory use and memory traffic during generation:

class GroupedQueryAttention(nn.Module):
    """M query heads, N << M key-value heads"""
    def __init__(self, d_model, n_query_heads, n_kv_heads, head_dim=64):
        super().__init__()
        self.n_query_heads = n_query_heads
        self.n_kv_heads = n_kv_heads
        assert n_query_heads % n_kv_heads == 0
        
        self.q = nn.Linear(d_model, n_query_heads * head_dim)
        self.kv = nn.Linear(d_model, 2 * n_kv_heads * head_dim)
        self.o = nn.Linear(n_query_heads * head_dim, d_model)
    
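    def forward(self, x, cache=None):
        # cache is not used in this simplified version; a generation loop
        # would store k and v here, before they are expanded below
        B, T, C = x.shape
        # Reshape to (B, heads, T, head_dim) so heads act as a batch dimension
        q = self.q(x).view(B, T, self.n_query_heads, -1).transpose(1, 2)
        k, v = self.kv(x).chunk(2, dim=-1)
        k = k.view(B, T, self.n_kv_heads, -1).transpose(1, 2)
        v = v.view(B, T, self.n_kv_heads, -1).transpose(1, 2)

        # Repeat each k,v head so every query head in a group sees the same k,v
        q_groups = self.n_query_heads // self.n_kv_heads
        k = k.repeat_interleave(q_groups, dim=1)
        v = v.repeat_interleave(q_groups, dim=1)

        # Standard scaled dot-product attention; scores are (B, H, T, T)
        scale = k.shape[-1] ** -0.5
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn = attn.softmax(dim=-1)

        # Merge heads back to (B, T, H * head_dim) and project to d_model
        out = torch.matmul(attn, v).transpose(1, 2).reshape(B, T, -1)
        return self.o(out)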
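
To make the KV cache saving concrete, the arithmetic can be sketched as below. The helper name kv_cache_bytes and the model shape (32 layers, 32 query heads, head_dim 128, fp16, 4096-token context) are assumptions chosen for illustration, not figures from any particular model.

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim,
                   bytes_per_elem=2, batch=1):
    """Bytes needed to cache keys and values for a batch of sequences."""
    # Factor 2: one tensor for keys and one for values, per layer.
    return 2 * batch * n_layers * seq_len * n_kv_heads * head_dim * bytes_per_elem

# Hypothetical model shape, used only to illustrate the ratio between variants
cfg = dict(seq_len=4096, n_layers=32, head_dim=128)
mha = kv_cache_bytes(n_kv_heads=32, **cfg)  # one KV head per query head
gqa = kv_cache_bytes(n_kv_heads=8, **cfg)   # 4 query heads share each KV head
mqa = kv_cache_bytes(n_kv_heads=1, **cfg)   # all query heads share one KV head

for name, size in [("MHA", mha), ("GQA", gqa), ("MQA", mqa)]:
    print(f"{name}: {size / 2**30:.4f} GiB (MHA/{name} ratio: {mha / size:.0f}x)")
```

Under these assumptions the cache size scales linearly with n_kv_heads, so moving from 32 to 8 KV heads cuts it by a factor of 4 regardless of how many query heads remain.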

Comparative Analysis Matrix

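Architecture   | KV cache memory   | Inference speed | Quality loss                 | Best for
Full MHA       | O(n × n_heads)    | Slow            | None                         | Research, small models
GQA            | O(n × n_kv_heads) | Fast            | Minimal                      | Production LLMs
MQA            | O(n × 1)          | Fastest         | Moderate                     | Memory-constrained
Sliding Window | O(w × n_heads)    | Moderate        | No attention beyond window w | Long documents

Here n is the context length and w the window size. KV cache memory is per layer, up to a factor of head_dim. The attention score matrix is a separate cost: O(n²) per head for the first three rows and O(n × w) for sliding window.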
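
The sliding-window row can be illustrated with a mask that lets each query see only itself and the previous w - 1 positions. This is a minimal sketch: the function name sliding_window_mask is hypothetical, and it still builds a dense n × n mask for clarity, whereas a memory-efficient implementation would avoid materializing the full matrix.

```python
import torch

def sliding_window_mask(seq_len, window):
    """Boolean (seq_len, seq_len) mask; True where query i may attend to key j."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions as a column
    j = torch.arange(seq_len).unsqueeze(0)  # key positions as a row
    # Causal (j <= i) and within the most recent `window` positions
    return (j <= i) & (i - j < window)

mask = sliding_window_mask(seq_len=6, window=3)

# Apply the mask to raw scores before softmax; masked entries get zero weight
scores = torch.randn(6, 6)
scores = scores.masked_fill(~mask, float("-inf"))
attn = scores.softmax(dim=-1)
```

Every row keeps at least its diagonal entry, so the softmax never sees a row made entirely of -inf.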

Common Failure Mode: Head Dimension Mismatch

# INCORRECT - crashes at runtime
class BrokenAttention(nn.Module):
    def __init__(self, d_model=768, n_heads=12):
        super().__init__()
        self.head_dim = d_model / n_heads  # Float division in Python 3
        self.qkv = nn.Linear(d_model, 3 * d_model)
        # Later: q.view(B, T, n_heads, self.head_dim)
        # TypeError: view() requires integer sizes, but self.head_dim is 64.0

# CORRECT
class CorrectAttention(nn.Module):
    def __init__(self, d_model=768, n_heads=12):
        super().__init__()
        self.head_dim = d_model // n_heads  # Integer division
        self.qkv = nn.Linear(d_model, 3 * d_model)
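
Integer division removes the type error, but it silently floors when d_model is not a multiple of n_heads, and the later view() call then fails with a shape error. A small guard, sketched here with the hypothetical helper name head_dim_for, surfaces the problem at construction time instead:

```python
def head_dim_for(d_model, n_heads):
    """Return the per-head width, failing early if the split is uneven."""
    if d_model % n_heads != 0:
        raise ValueError(
            f"d_model={d_model} is not divisible by n_heads={n_heads}"
        )
    return d_model // n_heads

head_dim = head_dim_for(768, 12)  # even split, so no error is raised

# An uneven split is rejected before any tensors are created
try:
    head_dim_for(770, 12)
    uneven_rejected = False
except ValueError:
    uneven_rejected = True
```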

Benchmarking Comparison

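# Compare peak forward-pass memory of attention variants (requires a CUDA GPU)
python -c "
import torch
from your_module import FullAttention, GroupedQueryAttention

seq_len = 4096
d_model = 4096

full = FullAttention(d_model, 32).cuda()
# head_dim=128 matches d_model // n_heads in FullAttention
gqa = GroupedQueryAttention(d_model, 32, 8, head_dim=128).cuda()
x = torch.randn(1, seq_len, d_model, device='cuda')

def peak_gb(model):
    # Reset the peak counter so each model is measured on its own
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        model(x)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 1e9

# The peak includes the weights of both modules and the (B, H, T, T) score
# matrix, so expect a modest gap here; the KV cache saving appears in generation
full_mem = peak_gb(full)
gqa_mem = peak_gb(gqa)

print(f'Full: {full_mem:.2f}GB, GQA: {gqa_mem:.2f}GB')
print(f'Savings: {(1 - gqa_mem/full_mem)*100:.1f}%')
"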
EXERCISE

Implement both FullAttention and GQA variants. Measure inference throughput (tokens/second) for a 2048-token generation run using torch.cuda.Event timing. Document the throughput ratio in a comparison table.
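
As a starting point for the timing part, the sketch below shows one way to use torch.cuda.Event, falling back to time.perf_counter when no GPU is present. The helper name time_forward is hypothetical, the nn.Linear layer is only a stand-in for an attention module, and the measurement covers a single full-sequence forward pass rather than token-by-token generation, which the exercise asks you to build.

```python
import time
import torch

def time_forward(fn, n_iters=10, warmup=3):
    """Average time in seconds for one call to fn()."""
    for _ in range(warmup):  # warm-up runs exclude one-time setup costs
        fn()
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        for _ in range(n_iters):
            fn()
        end.record()
        torch.cuda.synchronize()  # wait for queued kernels before reading
        return start.elapsed_time(end) / 1000.0 / n_iters  # ms to seconds
    t0 = time.perf_counter()
    for _ in range(n_iters):
        fn()
    return (time.perf_counter() - t0) / n_iters

device = "cuda" if torch.cuda.is_available() else "cpu"
layer = torch.nn.Linear(256, 256).to(device)  # stand-in for an attention module
x = torch.randn(1, 128, 256, device=device)

with torch.no_grad():
    sec = time_forward(lambda: layer(x))
tokens_per_sec = x.shape[1] / sec
print(f"{tokens_per_sec:.0f} tokens/second")
```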