RUNLOCALAIv38
OP·Fredoline Eruo
© 2026 runlocalai.co · Independently operated
Custom LLM Architecture Design

15. Grouped Query Attention

Chapter 15 of 24 · 25 min
KEY INSIGHT

GQA offers a strong quality/memory tradeoff for production models. The most common configuration is 8 KV heads (Mistral 7B and Llama 2 70B both use 8), and some models go as low as 4. Reducing KV heads below 4 typically costs an unacceptable amount of quality.

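Grouped Query Attention (GQA) has become the dominant attention variant in production LLMs. It splits the query heads into groups, and every query head in a group shares one key/value head. With as many KV heads as query heads it reduces to standard multi-head attention (MHA); with a single KV head it becomes multi-query attention (MQA). Understanding its implementation and trade-offs is essential.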
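A minimal sketch of the grouping, assuming the contiguous layout that `repeat_interleave` produces in the implementation below: query head `h` reads KV head `h // (n_query_heads // n_kv_heads)`. The helper names are illustrative, not from any library.

```python
def kv_head_for_query_head(h: int, n_query_heads: int, n_kv_heads: int) -> int:
    """Index of the KV head shared by query head h (contiguous grouping, illustrative)."""
    assert n_query_heads % n_kv_heads == 0, "query heads must split evenly into groups"
    group_size = n_query_heads // n_kv_heads
    return h // group_size


def group_map(n_query_heads: int, n_kv_heads: int) -> list[int]:
    """KV head index for every query head."""
    return [kv_head_for_query_head(h, n_query_heads, n_kv_heads) for h in range(n_query_heads)]


print(group_map(8, 8))  # MHA: [0, 1, 2, 3, 4, 5, 6, 7]
print(group_map(8, 2))  # GQA: [0, 0, 0, 0, 1, 1, 1, 1]
print(group_map(8, 1))  # MQA: [0, 0, 0, 0, 0, 0, 0, 0]
```

MHA and MQA fall out as the two extremes of the same mapping, which is why one GQA module with a configurable `n_kv_heads` can cover all three.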

GQA Implementation Details

import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedQueryAttentionV2(nn.Module):
    """
    Production GQA implementation with inference optimization.
    Key insight: during long-context generation, the KV cache often dominates memory.
    Reducing KV heads from 32 to 8 shrinks the KV cache 4x.
    """
    def __init__(self, d_model, n_query_heads, n_kv_heads, head_dim=128):
        super().__init__()
        assert n_query_heads % n_kv_heads == 0, "n_query_heads must be divisible by n_kv_heads"
        self.n_query_heads = n_query_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        
        # Query projections: full width for all query heads
        self.q_proj = nn.Linear(d_model, n_query_heads * head_dim, bias=False)
        
        # Key-Value projections: narrower, shared across query groups
        self.k_proj = nn.Linear(d_model, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * head_dim, bias=False)
        
        # Output projection
        self.o_proj = nn.Linear(n_query_heads * head_dim, d_model, bias=False)
        
        self.rope = None  # Optional callable rope(x, sin, cos), set externally
    
    def forward(self, x, kv_cache=None, sin=None, cos=None):
        B, T, _ = x.shape
        
        # Compute Q, K, V and reshape to (B, T, n_heads, head_dim)
        q = self.q_proj(x).view(B, T, self.n_query_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim)
        
        # Apply RoPE to queries and keys; the caller supplies sin/cos
        # for the absolute positions of these T tokens
        if self.rope is not None:
            q = self.rope(q, sin, cos)
            k = self.rope(k, sin, cos)
        
        # Handle KV cache during generation (dim=1 is the sequence axis)
        if kv_cache is not None:
            past_k, past_v = kv_cache
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)
        
        # Cache the compact (B, S, n_kv_heads, head_dim) tensors. Caching the
        # expanded tensors would erase GQA's memory savings and break the
        # concatenation on the next step.
        new_cache = (k, v)
        
        # Expand K,V along the head axis (dim=2) to match the Q heads:
        # each KV head is replicated q_groups times, once per query in its group
        q_groups = self.n_query_heads // self.n_kv_heads
        k = k.repeat_interleave(q_groups, dim=2)  # (B, S, n_q_heads, head_dim)
        v = v.repeat_interleave(q_groups, dim=2)
        
        # Scaled attention scores: (B, n_q_heads, T, S)
        scale = self.head_dim ** -0.5
        attn = torch.einsum('bqhd,bkhd->bhqk', q, k) * scale
        
        # Causal mask: the new tokens occupy absolute positions S-T .. S-1,
        # and each may attend only to keys at or before its own position
        S = k.shape[1]
        q_pos = torch.arange(S - T, S, device=x.device).unsqueeze(1)  # (T, 1)
        k_pos = torch.arange(S, device=x.device).unsqueeze(0)         # (1, S)
        attn = attn.masked_fill(k_pos > q_pos, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        
        out = torch.einsum('bhqk,bkhd->bqhd', attn, v)  # (B, T, n_q_heads, head_dim)
        out = self.o_proj(out.reshape(B, T, self.n_query_heads * self.head_dim))
        return out, new_cache  # Output and updated (compact) KV cache
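The `repeat_interleave` expansion above materializes a full-size copy of K and V for every query head on each forward pass. A sketch of an alternative, assuming the same `(B, T, heads, head_dim)` layout: reshape the queries into `(n_kv_heads, q_groups)` groups and contract each group against its shared KV head directly. The function names are hypothetical; the final line checks that both routes give the same attention scores.

```python
import torch


def gqa_scores_expanded(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Scores via explicit K expansion. q: (B, T, Hq, D), k: (B, S, Hkv, D)."""
    q_groups = q.shape[2] // k.shape[2]
    k_exp = k.repeat_interleave(q_groups, dim=2)        # (B, S, Hq, D), a real copy
    return torch.einsum("bqhd,bkhd->bhqk", q, k_exp)    # (B, Hq, T, S)


def gqa_scores_grouped(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Same scores without copying K: group the query heads instead."""
    B, T, Hq, D = q.shape
    S, Hkv = k.shape[1], k.shape[2]
    # Query head h = n * G + g, matching repeat_interleave's contiguous grouping
    qg = q.view(B, T, Hkv, Hq // Hkv, D)                # (B, T, Hkv, G, D)
    s = torch.einsum("btngd,bsnd->bngts", qg, k)        # (B, Hkv, G, T, S)
    return s.reshape(B, Hq, T, S)


torch.manual_seed(0)
q = torch.randn(2, 5, 8, 16)   # 8 query heads, 5 new tokens
k = torch.randn(2, 7, 2, 16)   # 2 KV heads, 7 cached positions
print(torch.allclose(gqa_scores_expanded(q, k), gqa_scores_grouped(q, k), atol=1e-5))  # True
```

The sketch only shows the index arithmetic; many fused attention kernels perform this grouping internally, so whether the reshape is faster in plain PyTorch depends on the backend.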

KV Cache Management

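class KVCacheManager:
    """
    Pre-allocated KV cache for ONE attention layer (batch size 1).
    With 8 KV heads instead of 32, the same memory holds ~4x more tokens.
    """
    def __init__(self, n_kv_heads, head_dim, max_seq_len, dtype=torch.bfloat16, device=None):
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.length = 0  # number of positions written so far
        
        # Pre-allocate once on the compute device (normally the GPU) so decode
        # steps never re-allocate. Pinned host memory (pin_memory=True) only
        # helps when the cache is offloaded to CPU, and needs a CUDA build.
        shape = (1, max_seq_len, n_kv_heads, head_dim)
        self.k_cache = torch.zeros(shape, dtype=dtype, device=device)
        self.v_cache = torch.zeros(shape, dtype=dtype, device=device)
    
    def update(self, pos, k_new, v_new):
        """Write k_new, v_new of shape (1, n, n_kv_heads, head_dim) starting at pos."""
        n = k_new.shape[1]
        if pos + n > self.max_seq_len:
            raise ValueError(f"cache overflow: {pos + n} > {self.max_seq_len}")
        self.k_cache[:, pos:pos + n] = k_new
        self.v_cache[:, pos:pos + n] = v_new
        self.length = max(self.length, pos + n)
    
    def get_cache(self):
        """Return only the filled prefix; the rest of the buffer is unused."""
        return self.k_cache[:, :self.length], self.v_cache[:, :self.length]
    
    def memory_usage_gb(self):
        """Allocated size of this layer's cache in GB (1e9 bytes); multiply by n_layers for the model."""
        per_token = self.n_kv_heads * self.head_dim * 2 * self.k_cache.element_size()  # K+V
        return per_token * self.max_seq_len / 1e9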
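The ~4x context claim in the docstring above follows from simple arithmetic. A sketch, assuming a hypothetical 32-layer model with 32 query heads, `head_dim` 128, bf16 (2-byte) cache entries, batch size 1, and 8 GiB reserved for the KV cache: the number of tokens that fit scales inversely with `n_kv_heads`. Both helper names are illustrative.

```python
def kv_bytes_per_token(n_layers: int, n_kv_heads: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    """Cache bytes one token occupies across all layers (K and V)."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem


def max_context_tokens(budget_bytes: int, n_layers: int, n_kv_heads: int, head_dim: int,
                       bytes_per_elem: int = 2, batch: int = 1) -> int:
    """Longest context whose KV cache fits in budget_bytes (ignores weights and activations)."""
    per_token = kv_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_elem) * batch
    return budget_bytes // per_token


budget = 8 * 2**30  # assumed: 8 GiB reserved for the KV cache
for n_kv in (32, 8):
    print(n_kv, max_context_tokens(budget, n_layers=32, n_kv_heads=n_kv, head_dim=128))
# 32 16384
# 8 65536
```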

Memory Comparison: MHA vs GQA vs MQA

def compare_attention_memory(d_model, seq_len, n_heads=32, n_kv_heads=8,
                             n_layers=32, bytes_per_elem=2):
    """Compare total KV cache size (all layers, batch 1) for MHA, GQA, and MQA"""
    
    head_dim = d_model // n_heads
    
    # Bytes per cached position: K and V (x2), every layer, bytes_per_elem each
    def kv_bytes(kv_heads):
        return 2 * kv_heads * head_dim * n_layers * bytes_per_elem
    
    mha_kv = kv_bytes(n_heads)     # Multi-Head Attention: one KV head per query head
    gqa_kv = kv_bytes(n_kv_heads)  # Grouped Query Attention: n_kv_heads shared KV heads
    mqa_kv = kv_bytes(1)           # Multi-Query Attention: a single KV head
    
    # Memory for full context
    mha_total = mha_kv * seq_len
    gqa_total = gqa_kv * seq_len
    mqa_total = mqa_kv * seq_len
    
    mib = 2 ** 20
    print(f"Sequence length: {seq_len}")
    print(f"MHA:  {mha_total / mib:.0f} MiB")
    print(f"GQA:  {gqa_total / mib:.0f} MiB ({mha_total / gqa_total:.1f}x reduction)")
    print(f"MQA:  {mqa_total / mib:.0f} MiB ({mha_total / mqa_total:.1f}x reduction)")

# Example: d_model 4096, 4096-token sequence, 32 heads (head_dim 128),
# 8 KV heads, 32 layers, bf16 cache
compare_attention_memory(4096, 4096, 32)
# Sequence length: 4096
# MHA:  2048 MiB
# GQA:  512 MiB (4.0x reduction)
# MQA:  64 MiB (32.0x reduction)
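The same comparison reduces to one closed-form expression once the layer count and element size are included. Writing $L$ for the number of layers, $n_{kv}$ for the KV head count, $d_h$ for `head_dim`, $s$ for sequence length, $b$ for batch size, and $p$ for bytes per element (the factor 2 counts K and V):

```latex
M_{\text{KV}} = 2 \cdot L \cdot n_{kv} \cdot d_h \cdot s \cdot b \cdot p,
\qquad
\frac{M_{\text{MHA}}}{M_{\text{GQA}}} = \frac{n_{\text{heads}}}{n_{kv}}
```

Every factor except $n_{kv}$ cancels in the ratio, so the reduction depends only on how many query heads share each KV head: 32 query heads over 8 KV heads gives 4x regardless of model depth or precision.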

Failure Mode: Incorrect Group Expansion

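# BUG: wrong repeat count passed to repeat_interleave
class BrokenGQA(nn.Module):
    def __init__(self, n_query=32, n_kv=8):
        super().__init__()
        self.n_query = n_query
        self.n_kv = n_kv
        # ...
    
    def forward(self, q, k, v):
        # q: (B, T, 32, H), k,v: (B, T, 8, H); dim=1 is sequence, dim=2 is heads
        
        # WRONG: repeats each KV head n_query (32) times instead of n_query // n_kv (4)
        k_expanded = k.repeat_interleave(self.n_query, dim=2)
        # Head dimension grows from 8 to 8*32 = 256 instead of 32, so the
        # attention einsum against q's 32 heads fails with a size mismatch
        
        # FIX: dim=2 (heads) was correct; the repeat count was not
        q_groups = self.n_query // self.n_kv  # 4 query heads per KV head
        k_expanded = k.repeat_interleave(q_groups, dim=2)  # (B, T, 32, H)
        v_expanded = v.repeat_interleave(q_groups, dim=2)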
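A quieter variant of the same bug is swapping `repeat_interleave` for `Tensor.repeat`. Both turn 8 KV heads into 32, so no shape error is raised, but they assign query heads to KV heads in a different order, and weights trained with one ordering silently produce wrong attention under the other. A small check, assuming 8 query heads and 2 KV heads with one scalar per head so the ordering is visible; the `expand` + `reshape` form is a common equivalent of `repeat_interleave`.

```python
import torch

n_q, n_kv = 8, 2
q_groups = n_q // n_kv
# One value per KV head, laid out as (B=1, T=1, n_kv, head_dim=1)
k = torch.arange(n_kv, dtype=torch.float32).view(1, 1, n_kv, 1)

interleaved = k.repeat_interleave(q_groups, dim=2)  # contiguous groups
tiled = k.repeat(1, 1, q_groups, 1)                  # round-robin groups
expanded = k[:, :, :, None, :].expand(1, 1, n_kv, q_groups, 1).reshape(1, 1, n_q, 1)

print(interleaved.flatten().tolist())  # [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0]
print(tiled.flatten().tolist())        # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
print(torch.equal(interleaved, expanded))  # True
```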
EXERCISE

Implement GQA with a configurable KV head count (set at init). Benchmark KV cache memory and decode throughput for 32 query heads paired with 32, 8, 4, and 1 KV heads (MHA, two GQA settings, and MQA). Plot the memory-throughput tradeoff curve.
