15. Grouped Query Attention
Chapter 15 of 24 · 25 min
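Grouped Query Attention (GQA) has become the dominant attention variant in production LLMs (e.g., Llama 2 70B, Llama 3, Mistral 7B). GQA partitions the query heads into groups, and all query heads in a group share a single key/value head. This places it between multi-head attention (MHA, one KV head per query head) and multi-query attention (MQA, one KV head in total). Only the KV heads are cached during generation, so GQA shrinks the KV cache by the group size while retaining most of MHA's quality. Understanding its implementation and trade-offs is essential.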
GQA Implementation Details
import torch
import torch.nn as nn
import torch.nn.functional as F
class GroupedQueryAttentionV2(nn.Module):
    """
    Production-style GQA implementation with KV-cache support for inference.

    Query heads are split into ``n_kv_heads`` contiguous groups. Every query
    head in a group attends with the same key/value head. Only the compact
    ``n_kv_heads`` K/V tensors are cached, so during generation the KV cache
    is ``n_query_heads // n_kv_heads`` times smaller than with multi-head
    attention. For example, going from 32 to 8 KV heads saves 4x KV memory.

    Args:
        d_model: Input/output feature size.
        n_query_heads: Number of query heads.
        n_kv_heads: Number of key/value heads; must divide ``n_query_heads``.
        head_dim: Per-head feature size.
    """

    def __init__(self, d_model, n_query_heads, n_kv_heads, head_dim=128):
        super().__init__()
        assert n_query_heads % n_kv_heads == 0, (
            "n_query_heads must be a multiple of n_kv_heads"
        )
        self.n_query_heads = n_query_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        # Query projection: full width for all query heads.
        self.q_proj = nn.Linear(d_model, n_query_heads * head_dim, bias=False)
        # Key/value projections: narrower, one head per query group.
        self.k_proj = nn.Linear(d_model, n_kv_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, n_kv_heads * head_dim, bias=False)
        # Output projection back to the model width.
        self.o_proj = nn.Linear(n_query_heads * head_dim, d_model, bias=False)
        # Optional rotary embedding, set externally; called as rope(t, sin, cos).
        self.rope = None

    def forward(self, x, kv_cache=None, sin=None, cos=None, causal=True):
        """
        Attend from the T new positions in ``x`` to all cached + new positions.

        Args:
            x: (B, T, d_model) input for the T new positions.
            kv_cache: Optional ``(past_k, past_v)``, each
                (B, T_past, n_kv_heads, head_dim), as returned by a previous call.
            sin, cos: Rotary tables forwarded to ``self.rope``. They must cover
                the absolute positions of the T new tokens. They are ignored
                when ``self.rope`` is None.
            causal: If True, new position i may only attend to keys at absolute
                positions <= T_past + i.

        Returns:
            ``(out, (k, v))``:
            - ``out`` has shape (B, T, d_model).
            - ``k`` and ``v`` are the updated, *un-expanded* caches, each of
              shape (B, T_past + T, n_kv_heads, head_dim).
        """
        B, T, _ = x.shape

        # Project and split into heads: (B, T, n_heads, head_dim).
        q = self.q_proj(x).view(B, T, self.n_query_heads, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim)

        # Rotate only the new queries/keys; cached keys were rotated when stored.
        if self.rope is not None:
            q = self.rope(q, sin, cos)
            k = self.rope(k, sin, cos)

        # Append the new positions to the cache along the sequence axis.
        if kv_cache is not None:
            past_k, past_v = kv_cache
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)

        # Cache the compact n_kv_heads tensors, *before* expansion. Caching the
        # expanded tensors would erase GQA's memory saving. It would also break
        # the next concatenation, since the new k/v only have n_kv_heads heads.
        new_cache = (k, v)

        # Expand K/V so that query head h reads KV head h // q_groups.
        q_groups = self.n_query_heads // self.n_kv_heads
        k = k.repeat_interleave(q_groups, dim=2)  # (B, T_k, n_q_heads, head_dim)
        v = v.repeat_interleave(q_groups, dim=2)

        scale = self.head_dim ** -0.5
        attn = torch.einsum('bqhd,bkhd->bhqk', q, k) * scale
        if causal:
            # The new queries sit at absolute positions past_len .. past_len+T-1,
            # so query i may see keys j <= i + past_len.
            n_keys = k.shape[1]
            past_len = n_keys - T
            allowed = torch.ones(
                T, n_keys, dtype=torch.bool, device=attn.device
            ).tril(diagonal=past_len)
            attn = attn.masked_fill(~allowed, float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = torch.einsum('bhqk,bkhd->bqhd', attn, v)

        # Merge heads and project back to d_model.
        out = out.reshape(B, T, self.n_query_heads * self.head_dim)
        out = self.o_proj(out)
        return out, new_cache
KV Cache Management
class KVCacheManager:
    """
    Pre-allocated K/V cache for a single sequence (batch dimension of 1).

    Only ``n_kv_heads`` heads are stored per position. With GQA (e.g. 8 KV
    heads instead of 32 for MHA), the same memory budget therefore holds
    proportionally more positions (~4x in that example).

    Args:
        n_kv_heads: Number of key/value heads stored per position.
        head_dim: Per-head feature size.
        max_seq_len: Capacity of the cache in positions.
        dtype: Element type of the cache buffers.
        pin_memory: Allocate page-locked host memory for faster host->GPU
            copies. ``None`` (the default) pins only when CUDA is available,
            because pinned allocation fails on machines without CUDA.
    """

    def __init__(self, n_kv_heads, head_dim, max_seq_len, dtype=torch.bfloat16,
                 pin_memory=None):
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        if pin_memory is None:
            pin_memory = torch.cuda.is_available()
        shape = (1, max_seq_len, n_kv_heads, head_dim)
        self.k_cache = torch.zeros(shape, dtype=dtype, pin_memory=pin_memory)
        self.v_cache = torch.zeros(shape, dtype=dtype, pin_memory=pin_memory)
        # One past the highest position written so far (0 when empty).
        self.seq_len = 0

    def update(self, pos, k_new, v_new):
        """
        Write ``k_new``/``v_new`` at positions ``pos .. pos + n - 1``.

        Both inputs are expected as (1, n, n_kv_heads, head_dim).

        Raises:
            ValueError: if the write would fall outside ``[0, max_seq_len)``.
        """
        n = k_new.shape[1]
        if pos < 0 or pos + n > self.max_seq_len:
            raise ValueError(
                f"cannot write {n} positions at pos={pos}: "
                f"max_seq_len={self.max_seq_len}"
            )
        self.k_cache[0, pos:pos + n] = k_new.squeeze(0)
        self.v_cache[0, pos:pos + n] = v_new.squeeze(0)
        self.seq_len = max(self.seq_len, pos + n)

    def get_cache(self, trim=False):
        """
        Return ``(k_cache, v_cache)``.

        By default the full pre-allocated buffers are returned, including any
        unwritten (zero) positions. With ``trim=True``, views of only the
        first ``seq_len`` positions are returned. Use those when feeding
        attention, so it does not attend to empty slots.
        """
        if trim:
            return self.k_cache[:, :self.seq_len], self.v_cache[:, :self.seq_len]
        return self.k_cache, self.v_cache

    def memory_usage_gb(self):
        """Size of the K and V buffers together, in decimal GB (1e9 bytes)."""
        # Factor 2 for K and V; element size follows the actual cache dtype.
        bytes_per_token = 2 * self.n_kv_heads * self.head_dim * self.k_cache.element_size()
        return bytes_per_token * self.max_seq_len / 1e9
Memory Comparison: MHA vs GQA vs MQA
def compare_attention_memory(d_model, seq_len, n_heads=32, n_kv_heads_gqa=8,
                             n_layers=1, bytes_per_elem=2, batch_size=1):
    """
    Compare KV-cache memory for MHA, GQA and MQA and print a summary.

    Args:
        d_model: Model width; ``head_dim = d_model // n_heads``.
        seq_len: Cached positions per sequence.
        n_heads: Number of query heads (also the KV-head count for MHA).
        n_kv_heads_gqa: Number of KV heads for the GQA variant.
        n_layers: Number of attention layers, each holding its own cache.
        bytes_per_elem: Bytes per cached element (2 for fp16/bf16, 4 for fp32).
        batch_size: Number of sequences cached concurrently.

    Returns:
        dict mapping ``"MHA"``, ``"GQA"`` and ``"MQA"`` to total cache bytes.
    """
    head_dim = d_model // n_heads

    def kv_cache_bytes(n_kv_heads):
        # Factor 2: both K and V are cached for every KV head.
        bytes_per_position = 2 * n_kv_heads * head_dim * bytes_per_elem
        return bytes_per_position * seq_len * n_layers * batch_size

    mha_total = kv_cache_bytes(n_heads)         # every head has its own K/V
    gqa_total = kv_cache_bytes(n_kv_heads_gqa)  # one K/V head per query group
    mqa_total = kv_cache_bytes(1)               # a single shared K/V head
    print(f"Sequence length: {seq_len}")
    print(f"MHA: {mha_total/1e9:.2f} GB")
    print(f"GQA: {gqa_total/1e9:.2f} GB ({mha_total/gqa_total:.1f}x reduction)")
    print(f"MQA: {mqa_total/1e9:.2f} GB ({mha_total/mqa_total:.1f}x reduction)")
    return {"MHA": mha_total, "GQA": gqa_total, "MQA": mqa_total}


# Example: 7B-class shape -- d_model 4096, 32 heads (head_dim 128), 32 layers,
# 4096-token context, bf16 cache (2 bytes/element), batch size 1.
compare_attention_memory(4096, 4096, 32, n_layers=32)
# Sequence length: 4096
# MHA: 2.15 GB
# GQA: 0.54 GB (4.0x reduction)
# MQA: 0.07 GB (32.0x reduction)
Failure Mode: Incorrect Group Expansion
# BUG: wrong repeat COUNT passed to repeat_interleave (the dimension is right).
class BrokenGQA(nn.Module):
    """Deliberately broken GQA snippet illustrating a group-expansion bug."""
    def __init__(self, n_query=32, n_kv=8):
        super().__init__()
        self.n_query = n_query
        self.n_kv = n_kv
        # ...
    def forward(self, q, k, v):
        # q: (B, T, 32, H), k,v: (B, T, 8, H) -- dim 2 is the head axis.
        # WRONG: each KV head is repeated n_query (32) times instead of
        # q_groups = n_query // n_kv (4) times.
        k_expanded = k.repeat_interleave(self.n_query, dim=2)
        # Dimension 2 goes from 8 to 32*8=256 instead of 32
        # dim=2 (heads) is the correct axis here; dim=1 is the sequence axis.
        # Fix: k.repeat_interleave(self.n_query // self.n_kv, dim=2)
        # (snippet is truncated: v is not expanded and no attention is computed)
EXERCISE
Implement GQA with a configurable number of KV heads (set at initialization). Benchmark memory and throughput for these (query heads, KV heads) configurations: (32, 32) [equivalent to MHA], (32, 8), (32, 4), and (32, 1) [equivalent to MQA]. Plot the resulting memory-vs-throughput trade-off curve.