06. KV Cache and VRAM
The KV cache is the hidden memory cost that catches many people off guard. Understanding it lets you correctly estimate VRAM for inference and avoid OOM errors during long conversations.
What is the KV cache:
During autoregressive generation, the model computes attention over all previous tokens. Caching the Key and Value matrices for each layer avoids recomputing them:
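# Simplified single-head attention with a KV cache (NumPy sketch)
import numpy as np

def attention_with_cache(x, key_cache, value_cache, w_q, w_k, w_v):
    # Project only the current token; x has shape (1, d_model)
    query, new_key, new_value = x @ w_q, x @ w_k, x @ w_v
    # Append the new K/V to the cache instead of recomputing earlier tokens
    keys = np.concatenate([key_cache, new_key], axis=0)
    values = np.concatenate([value_cache, new_value], axis=0)
    # Scaled dot-product attention over all cached positions
    scores = query @ keys.T / np.sqrt(keys.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ values, keys, values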
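To make the saving concrete, the sketch below decodes a short sequence twice, once recomputing every Key and Value at each step and once appending to a cache, then checks that both give the same outputs while counting K/V projections. It uses plain Python lists so it runs without extra packages. The helper names, toy sizes, and random weights are illustrative assumptions, not any real model's implementation.

```python
import math
import random

def project(vec, weight):
    # Multiply a length-d_in vector by a d_in x d_out matrix (list of rows)
    return [sum(v * row[j] for v, row in zip(vec, weight))
            for j in range(len(weight[0]))]

def attend(query, keys, values):
    # Scaled dot-product attention for one query over all given keys/values
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [sum(e / total * val[j] for e, val in zip(exps, values))
            for j in range(len(values[0]))]

def decode_without_cache(tokens, w_q, w_k, w_v):
    outputs, projections = [], 0
    for t in range(len(tokens)):
        # Recompute K and V for every token seen so far at every step
        keys = [project(tok, w_k) for tok in tokens[:t + 1]]
        values = [project(tok, w_v) for tok in tokens[:t + 1]]
        projections += 2 * (t + 1)
        outputs.append(attend(project(tokens[t], w_q), keys, values))
    return outputs, projections

def decode_with_cache(tokens, w_q, w_k, w_v):
    key_cache, value_cache, outputs, projections = [], [], [], 0
    for tok in tokens:
        # Only the newest token is projected; earlier K/V come from the cache
        key_cache.append(project(tok, w_k))
        value_cache.append(project(tok, w_v))
        projections += 2
        outputs.append(attend(project(tok, w_q), key_cache, value_cache))
    return outputs, projections

def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
d_model, d_head, seq_len = 8, 4, 6  # hypothetical toy sizes
tokens = rand_matrix(seq_len, d_model)
w_q, w_k, w_v = (rand_matrix(d_model, d_head) for _ in range(3))

slow_out, slow_proj = decode_without_cache(tokens, w_q, w_k, w_v)
fast_out, fast_proj = decode_with_cache(tokens, w_q, w_k, w_v)
matches = all(math.isclose(a, b, abs_tol=1e-12)
              for row_a, row_b in zip(slow_out, fast_out)
              for a, b in zip(row_a, row_b))
print(matches, slow_proj, fast_proj)  # True 42 12
```

For 6 tokens the uncached loop performs 42 K/V projections against 12 with the cache. The gap grows quadratically with sequence length, and the price is the memory needed to hold the cached tensors.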
KV cache size calculation:
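For a model with L layers, H key/value (KV) heads, and head dimension d:
kv_cache_per_token = 2 x L x H x d x bytes_per_element
The factor of 2 covers the separate Key and Value tensors. H counts KV heads, which equals the number of attention heads only under standard multi-head attention. Models that use grouped-query attention (GQA) keep fewer KV heads and cache proportionally less: Llama 3.1 8B, for example, has 32 attention heads but only 8 KV heads.
For Llama 2 7B, which uses standard multi-head attention (L=32, H=32, d=128, 16-bit cache = 2 bytes):
32 layers x 2 (K+V) x 32 heads x 128 head_dim x 2 bytes
= 524,288 bytes per token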
≈ 0.5 MB per token in cache
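The formula above translates directly into a small helper. This is a minimal sketch: the function and parameter names are made up for illustration, and the second call uses a hypothetical 8-KV-head configuration to show how fewer KV heads shrink the cache.

```python
def kv_cache_bytes_per_token(num_layers, num_kv_heads, head_dim, bytes_per_element=2):
    # Factor of 2 covers the separate Key and Value tensors in each layer
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_element

# Configuration from the worked example: 32 layers, 32 KV heads, head_dim 128, 2 bytes
full_heads = kv_cache_bytes_per_token(32, 32, 128)
# Hypothetical variant with 8 KV heads and everything else unchanged
fewer_kv_heads = kv_cache_bytes_per_token(32, 8, 128)

print(full_heads, full_heads / 2**20)          # 524288 0.5
print(fewer_kv_heads, fewer_kv_heads / 2**10)  # 131072 128.0
```

Cutting the KV heads from 32 to 8 reduces the per-token cost from 0.5 MB to 128 KB. Storing the cache in 8 bits instead of 16 (bytes_per_element=1) would halve either figure.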
VRAM for context processing:
To hold a 4096-token context in the cache, before generating a single new token, you need:
0.5 MB x 4096 = 2 GB for KV cache alone
That is on top of the model weights, which is why long contexts require significant VRAM even for inference.
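Because the cache grows linearly with context length, it can help to tabulate a few sizes. The sketch below assumes the 524,288 bytes-per-token figure from the worked example and treats 1 GB as 2**30 bytes to match the arithmetic above. The constant and function names are illustrative.

```python
BYTES_PER_TOKEN = 2 * 32 * 32 * 128 * 2  # 524,288 bytes, the per-token figure above

def kv_cache_gib(context_length, bytes_per_token=BYTES_PER_TOKEN):
    # Cache size for one sequence, in binary gigabytes (2**30 bytes)
    return context_length * bytes_per_token / 2**30

for context_length in (2048, 4096, 8192, 32768):
    print(f"{context_length:>6} tokens -> {kv_cache_gib(context_length):5.1f} GB")
```

Under these assumptions the four contexts need 1, 2, 4, and 16 GB respectively, so at 32K tokens the cache alone would exceed the memory of many consumer GPUs.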
Batch inference VRAM:
When running multiple requests simultaneously, the KV cache multiplies:
kv_cache_per_sequence = kv_cache_per_token x max_context_length
total_cache = kv_cache_per_sequence x batch_size
A batch of 8 sequences at 4096 max context:
2 GB x 8 = 16 GB just for KV cache
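The same relationship can be turned around to ask how many sequences fit in a given card. The following is a rough sketch under stated assumptions: every sequence reserves its full max context, the weights and overhead figures are placeholders you would replace with your own measurements, and the function name is hypothetical.

```python
def max_batch_size(vram_budget_gib, weights_gib, per_token_bytes,
                   max_context, overhead_gib=0.5):
    # Memory one sequence reserves if it fills its whole context window
    per_sequence_gib = per_token_bytes * max_context / 2**30
    # Whatever is left after weights and a fixed overhead goes to the cache
    free_gib = vram_budget_gib - weights_gib - overhead_gib
    return max(0, int(free_gib // per_sequence_gib))

# Placeholder numbers: 24 GB card, 4.1 GB of weights, 0.5 MB per token, 4096 context
print(max_batch_size(24, 4.1, 524288, 4096))  # 9
print(max_batch_size(8, 4.1, 524288, 4096))   # 1
```

With these placeholder numbers a 24 GB card fits 9 concurrent sequences, while an 8 GB card fits only one.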
Practical memory breakdown for a 7B Q4 model (single sequence):
Model weights (Q4_K_M):      4.1 GB
KV cache at 4096 context:    2.0 GB
Activation memory:           0.5 GB
-----------------------------------
Minimum for inference:       6.6 GB
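The breakdown above can be reproduced with a short estimator that adds weights, KV cache, and activation memory. This is a sketch only: the 0.5 GB activation figure is taken from the table as a fixed assumption, real runtimes add their own overhead, and the function and key names are invented for illustration.

```python
def estimate_vram_gib(weights_gib, num_layers, num_kv_heads, head_dim,
                      context_length, batch_size=1, bytes_per_element=2,
                      activations_gib=0.5):
    # KV cache: 2 tensors (K and V) per layer, per KV head, per token, per sequence
    kv_bytes = (2 * num_layers * num_kv_heads * head_dim * bytes_per_element
                * context_length * batch_size)
    kv_gib = kv_bytes / 2**30
    return {
        "weights": weights_gib,
        "kv_cache": kv_gib,
        "activations": activations_gib,
        "total": weights_gib + kv_gib + activations_gib,
    }

# Figures from the table: 4.1 GB of Q4_K_M weights, 32 layers, 32 KV heads, 4096 context
breakdown = estimate_vram_gib(4.1, 32, 32, 128, context_length=4096)
print(f"KV cache: {breakdown['kv_cache']:.1f} GB, total: {breakdown['total']:.1f} GB")
# KV cache: 2.0 GB, total: 6.6 GB
```

Changing num_kv_heads, context_length, or batch_size in the call shows how each factor moves the total.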
Exercise: Calculate the KV cache size per token for Mistral 7B (L=32, GQA with 8 KV heads, d=128). Then calculate the total VRAM needed to run the model at Q4_K_M with an 8192-token context at batch_size=4.