Transformer & LLM components

Grouped-Query Attention (GQA)

Grouped-Query Attention (GQA) is a variant of multi-head attention that reduces memory and compute costs by sharing key-value (KV) heads across multiple query heads. In standard multi-head attention, each query head has its own key and value head. GQA groups query heads so that a single KV head serves several of them, typically in a ratio such as 2:1 or 8:1. Because the KV cache stores one key and one value per KV head, an N:1 ratio shrinks it to 1/N of the multi-head size (half for 2:1), usually with minimal quality loss. That makes longer context windows practical on consumer GPUs with limited VRAM.
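
The head-sharing scheme can be sketched in plain Python. This is a minimal, unoptimized illustration, not any library's implementation: it assumes contiguous grouping (query head h reads KV head h // group_size), takes already-projected per-head vectors, and omits batching and the causal mask. The function names are hypothetical.

```python
import math


def kv_head_for(q_head, n_q_heads, n_kv_heads):
    """Return the KV head that query head `q_head` reads, assuming contiguous groups."""
    if n_q_heads % n_kv_heads != 0:
        raise ValueError("n_q_heads must be a multiple of n_kv_heads")
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gqa_attention(q, k, v):
    """Scaled dot-product attention where several query heads share one K/V head.

    q:    [n_q_heads][seq_len][head_dim]
    k, v: [n_kv_heads][seq_len][head_dim]
    Returns [n_q_heads][seq_len][head_dim]. No causal mask, batching, or projections.
    """
    n_q, n_kv = len(q), len(k)
    out = []
    for h in range(n_q):
        g = kv_head_for(h, n_q, n_kv)  # the shared K/V head for this query head
        keys, values = k[g], v[g]
        d = len(q[h][0])
        head_out = []
        for qi in q[h]:
            scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in keys]
            weights = softmax(scores)
            head_out.append([sum(w * vj[c] for w, vj in zip(weights, values)) for c in range(d)])
        out.append(head_out)
    return out


# 8 query heads sharing 2 KV heads (4:1): heads 0-3 read KV head 0, heads 4-7 read KV head 1.
print([kv_head_for(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```

Setting the number of KV heads equal to the number of query heads recovers standard MHA, and setting it to 1 recovers MQA. Only the K and V lists, which are what the KV cache holds, get smaller.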

Deeper dive

GQA was proposed in a 2023 paper by Ainslie et al. at Google Research as a compromise between multi-head attention (MHA) and multi-query attention (MQA), and was adopted in models such as Llama 2 70B. MQA uses one KV head for all query heads, saving the most memory but sometimes degrading quality. GQA uses a small number of KV heads (e.g., 8) shared among many query heads (e.g., 32), offering a middle ground. The KV cache size scales linearly with the number of KV heads, so reducing KV heads directly reduces VRAM usage during inference. For operators, this means models using GQA can support longer context windows or run on smaller GPUs without offloading. Llama 3.1 8B uses 8 KV heads with 32 query heads (a 4:1 ratio), while Llama 2 7B uses full MHA (32 KV heads for 32 query heads). The switch to GQA in newer models is a deliberate design choice to improve inference efficiency.
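
The linear scaling follows from the cache layout: every layer stores one key vector and one value vector of length head_dim per KV head per token, so the total is 2 × layers × KV heads × head_dim × tokens × bytes per element. A minimal sketch of that arithmetic, assuming an FP16 cache (2 bytes per element), a single sequence, and no cache quantization; the function name is hypothetical:

```python
GIB = 1024 ** 3


def kv_cache_bytes(n_layers, n_kv_heads, head_dim, n_ctx, bytes_per_elem=2):
    """Bytes for the K and V caches of one sequence.

    The leading 2 counts keys and values; bytes_per_elem=2 assumes an FP16 cache.
    """
    return 2 * n_layers * n_kv_heads * head_dim * n_ctx * bytes_per_elem


# 32 layers, 32 query heads, head_dim 128 (the attention shape of Llama 2 7B and
# Llama 3.1 8B), varying only the number of KV heads.
for label, n_kv in [("MHA, 32 KV heads", 32), ("GQA, 8 KV heads", 8), ("MQA, 1 KV head", 1)]:
    size = kv_cache_bytes(n_layers=32, n_kv_heads=n_kv, head_dim=128, n_ctx=8192)
    print(f"{label}: {size / GIB:.3f} GiB at 8K context")
# MHA, 32 KV heads: 4.000 GiB at 8K context
# GQA, 8 KV heads: 1.000 GiB at 8K context
# MQA, 1 KV head: 0.125 GiB at 8K context
```

Every factor other than the KV head count is fixed here, so cutting KV heads from 32 to 8 cuts the cache by exactly 4×.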

Practical example

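A 70B model with full MHA (80 layers, 64 KV heads, head dimension 128) needs about 10 GB for the KV cache at 4K context, assuming an FP16 cache. With GQA (8 KV heads instead of 64), the cache drops to about 1.25 GB. The cache alone then fits easily on a 24 GB RTX 4090, although the 70B weights do not, even at 4-bit quantization. Llama 3.1 70B uses GQA with 8 KV heads, so its KV cache at 32K context is about 10 GB; together with 4-bit weights (roughly 40 GB) that fits on a single 80 GB A100, whereas a full-MHA 70B model would need about 80 GB for the 32K cache alone.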
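
These figures can be reproduced with a rough budget check. The sketch below assumes the 70B attention shape quoted above (80 layers, head_dim 128), about 70e9 parameters, 4.5 bits per weight as a stand-in for a 4-bit quantization (real formats add some overhead), and treats the card's quoted capacity as GiB. It ignores activations and runtime buffers, so it is a lower bound, not a guarantee. All names are hypothetical.

```python
GIB = 1024 ** 3


def kv_cache_gib(n_layers, n_kv_heads, head_dim, n_ctx, bytes_per_elem=2):
    """FP16 K+V cache for one sequence, in GiB."""
    return 2 * n_layers * n_kv_heads * head_dim * n_ctx * bytes_per_elem / GIB


def weights_gib(n_params, bits_per_weight):
    """Rough weight footprint; ignores quantization metadata and runtime buffers."""
    return n_params * bits_per_weight / 8 / GIB


# Assumed 70B shape (as in Llama 2 70B / Llama 3.1 70B): 80 layers, head_dim 128.
N_LAYERS, HEAD_DIM, N_PARAMS = 80, 128, 70e9


def fits(n_kv_heads, n_ctx, bits_per_weight, vram_gib):
    """Return (weights GiB, KV cache GiB, fits on one GPU?) as a lower-bound estimate."""
    w = weights_gib(N_PARAMS, bits_per_weight)
    kv = kv_cache_gib(N_LAYERS, n_kv_heads, HEAD_DIM, n_ctx)
    return round(w, 1), round(kv, 2), w + kv <= vram_gib


scenarios = [
    ("MHA, 4K ctx, 4.5-bit, 24 GiB", 64, 4096, 4.5, 24),
    ("GQA, 4K ctx, 4.5-bit, 24 GiB", 8, 4096, 4.5, 24),
    ("GQA, 32K ctx, FP16, 80 GiB", 8, 32768, 16, 80),
    ("GQA, 32K ctx, 4.5-bit, 80 GiB", 8, 32768, 4.5, 80),
    ("MHA, 32K ctx, 4.5-bit, 80 GiB", 64, 32768, 4.5, 80),
]
for label, kv_heads, ctx, bits, vram in scenarios:
    print(label, fits(kv_heads, ctx, bits, vram))
```

Only the GQA model with 4-bit weights on an 80 GiB card passes. On a 24 GiB card the weights alone (about 37 GiB at 4.5 bits) exceed VRAM regardless of how small the cache is, and at FP16 the weights (about 130 GiB) exceed even 80 GiB.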

Workflow example

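When you run llama-cli -m llama3.1-8b.Q4_K_M.gguf -c 8192 with llama.cpp, the runtime allocates the KV cache based on the model's KV head count. For Llama 3.1 8B (32 layers, 8 KV heads, head dimension 128), the default FP16 KV cache at 8K context is about 1 GB. The same context with a non-GQA model such as Llama 2 7B (32 KV heads) needs about 4 GB, which on top of roughly 4 GB of Q4_K_M weights can exceed VRAM on an 8 GB card (8K is also beyond Llama 2's native 4K context). The KV head count is stored in the GGUF metadata under llama.attention.head_count_kv, and llama.cpp prints it as n_head_kv in the model metadata it logs at load time.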
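
The KV head count can also be read directly from a GGUF file's header, where it sits under the key <arch>.attention.head_count_kv. The sketch below is a minimal reader written against the GGUF format description (little-endian, version 2 or later, string and array encodings as in the spec). It is not llama.cpp's own parser, its function names are hypothetical, and it assumes the head counts are stored as scalars (some architectures store per-layer arrays). The gguf Python package maintained in the llama.cpp repository provides a complete reader.

```python
import struct

# GGUF metadata value type IDs mapped to little-endian struct formats
# for the fixed-size scalar types.
_SCALAR_FORMATS = {
    0: "<B",   # uint8
    1: "<b",   # int8
    2: "<H",   # uint16
    3: "<h",   # int16
    4: "<I",   # uint32
    5: "<i",   # int32
    6: "<f",   # float32
    7: "<?",   # bool (1 byte)
    10: "<Q",  # uint64
    11: "<q",  # int64
    12: "<d",  # float64
}
_TYPE_STRING, _TYPE_ARRAY = 8, 9


def _read(f, fmt):
    size = struct.calcsize(fmt)
    data = f.read(size)
    if len(data) != size:
        raise EOFError("truncated GGUF header")
    return struct.unpack(fmt, data)[0]


def _read_string(f):
    length = _read(f, "<Q")  # a string is a uint64 byte length followed by UTF-8 bytes
    return f.read(length).decode("utf-8", errors="replace")


def _read_value(f, value_type):
    if value_type in _SCALAR_FORMATS:
        return _read(f, _SCALAR_FORMATS[value_type])
    if value_type == _TYPE_STRING:
        return _read_string(f)
    if value_type == _TYPE_ARRAY:
        elem_type = _read(f, "<I")  # element type, then uint64 element count
        count = _read(f, "<Q")
        return [_read_value(f, elem_type) for _ in range(count)]
    raise ValueError(f"unknown GGUF value type {value_type}")


def read_gguf_metadata(f):
    """Return the metadata key/value pairs of a GGUF (v2+) file opened in binary mode."""
    if f.read(4) != b"GGUF":
        raise ValueError("not a GGUF file")
    version = _read(f, "<I")
    if version < 2:
        raise ValueError("GGUF v1 (32-bit counts) is not handled by this sketch")
    _tensor_count = _read(f, "<Q")
    kv_count = _read(f, "<Q")
    metadata = {}
    for _ in range(kv_count):
        key = _read_string(f)
        metadata[key] = _read_value(f, _read(f, "<I"))
    return metadata


def attention_heads(metadata):
    """Return (query heads, KV heads); a missing head_count_kv is treated as full MHA."""
    arch = metadata["general.architecture"]
    n_head = metadata[f"{arch}.attention.head_count"]
    n_head_kv = metadata.get(f"{arch}.attention.head_count_kv", n_head)
    return n_head, n_head_kv
```

To use it, open the model file with open(path, "rb"), pass the handle to read_gguf_metadata, and pass the result to attention_heads. For a Llama 3.1 8B GGUF this should report 32 query heads and 8 KV heads. Because it walks every metadata entry, including the tokenizer vocabulary, it can take a moment on large files.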

Reviewed by Fredoline Eruo.