15. KV Cache Optimization
The KV cache stores the attention keys and values of every processed token so they are not recomputed at each decoding step. During autoregressive generation the cache grows with every new token, so its size largely determines the maximum context length and batch size, and therefore throughput.
Memory calculation:
```python
def kv_cache_memory(layer_count, kv_heads, head_dim, batch_size, max_seq_len, bytes_per_param=2):
    """Calculate KV cache memory in GiB."""
    # Each layer stores 2 tensors (K and V), each of shape (batch, kv_heads, seq, head_dim)
    bytes_per_token = 2 * layer_count * kv_heads * head_dim * bytes_per_param
    total_bytes = bytes_per_token * max_seq_len * batch_size
    return total_bytes / (1024 ** 3)

# Example: Llama-2-70B (80 layers, 8 KV heads, head_dim 128), batch=1, 32K context
memory = kv_cache_memory(
    layer_count=80,
    kv_heads=8,  # GQA: 8 KV heads, not the 64 query heads
    head_dim=128,
    batch_size=1,
    max_seq_len=32768,
    bytes_per_param=2,  # FP16
)
print(f"KV cache memory: {memory:.1f} GiB")
# Output: KV cache memory: 10.0 GiB
```
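Because the cache scales with the product of batch size and sequence length, a fixed memory budget can be spent either on long contexts or on large batches. The sketch below assumes a hypothetical 40 GiB budget reserved for the KV cache and the same Llama-2-70B shape as above; `kv_bytes_per_token` and `max_batch_for_budget` are illustrative helpers, not library functions.

```python
# Per-token KV cache cost: K and V for every layer and every KV head
def kv_bytes_per_token(layer_count, kv_heads, head_dim, bytes_per_param=2):
    return 2 * layer_count * kv_heads * head_dim * bytes_per_param

def max_batch_for_budget(budget_gib, seq_len, **shape):
    """Hypothetical helper: largest batch whose full-length KV cache fits the budget."""
    budget_bytes = budget_gib * 1024 ** 3
    return int(budget_bytes // (kv_bytes_per_token(**shape) * seq_len))

# Llama-2-70B-style shape (80 layers, 8 KV heads, head_dim 128), FP16 cache
shape = dict(layer_count=80, kv_heads=8, head_dim=128, bytes_per_param=2)
print(kv_bytes_per_token(**shape))  # 327680 bytes = 320 KiB per token
for seq_len in (4096, 8192, 32768):
    print(seq_len, max_batch_for_budget(40, seq_len, **shape))
# 4096 32 / 8192 16 / 32768 4
```

Each 4x increase in context length cuts the feasible batch by 4x, which is why long-context serving tends to trade away throughput.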
Grouped Query Attention (GQA) shares each KV head across a group of query heads, which scales the KV cache by the factor kv_heads / query_heads. Llama-2-70B, with 8 KV heads and 64 query heads, needs 1/8 the KV cache memory of standard multi-head attention.
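The 1/8 figure follows directly from the memory formula, since only the number of KV heads changes. A quick check, assuming the Llama-2-70B shape (80 layers, head_dim 128) at 32K context in FP16; `kv_cache_gib` is a hypothetical helper:

```python
def kv_cache_gib(layer_count, kv_heads, head_dim, seq_len, batch_size=1, bytes_per_param=2):
    # K and V tensors for every layer, each (batch, kv_heads, seq_len, head_dim)
    total = 2 * layer_count * kv_heads * head_dim * seq_len * batch_size * bytes_per_param
    return total / 1024 ** 3

mha = kv_cache_gib(80, kv_heads=64, head_dim=128, seq_len=32768)  # one K/V head per query head
gqa = kv_cache_gib(80, kv_heads=8, head_dim=128, seq_len=32768)   # 8 query heads share each K/V head
print(f"MHA: {mha:.1f} GiB, GQA: {gqa:.1f} GiB, saving: {mha / gqa:.0f}x")
# MHA: 80.0 GiB, GQA: 10.0 GiB, saving: 8x
```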
```python
# Verify GQA configuration (the meta-llama repository is gated and requires access)
from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-2-70b-hf")
print(f"Query heads: {config.num_attention_heads}")  # 64
print(f"KV heads: {config.num_key_value_heads}")  # 8
print(f"GQA ratio: {config.num_attention_heads // config.num_key_value_heads}x")  # 8x
```
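Once the head counts are known, the per-token cache cost can be derived from the config itself. A minimal sketch, assuming Llama-style field names (`num_hidden_layers`, `hidden_size`, `num_attention_heads`, `num_key_value_heads`) and using a `SimpleNamespace` stand-in so it runs offline; the stand-in values are Llama-2-70B's (hidden size 8192), and `kv_bytes_per_token_from_config` is a hypothetical helper:

```python
from types import SimpleNamespace

def kv_bytes_per_token_from_config(config, bytes_per_param=2):
    """Per-token KV cache bytes from Llama-style Hugging Face config fields."""
    heads = config.num_attention_heads
    # Configs without num_key_value_heads use standard multi-head attention
    kv_heads = getattr(config, "num_key_value_heads", None) or heads
    head_dim = config.hidden_size // heads
    return 2 * config.num_hidden_layers * kv_heads * head_dim * bytes_per_param

# Stand-in for AutoConfig.from_pretrained("meta-llama/Llama-2-70b-hf")
config = SimpleNamespace(num_hidden_layers=80, hidden_size=8192,
                         num_attention_heads=64, num_key_value_heads=8)
print(kv_bytes_per_token_from_config(config))  # 327680
```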
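KV cache quantization stores keys and values at lower precision: 8-bit formats halve KV cache memory relative to FP16, and 4-bit schemes cut it by 75%, typically at a small accuracy cost: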
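```python
from vllm import LLM

# vLLM KV cache quantization: store K/V in FP8 instead of FP16/BF16
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    kv_cache_dtype="fp8",  # Quantize KV cache to FP8
    gpu_memory_utilization=0.95,
    tensor_parallel_size=4,  # 70B weights need several GPUs; adjust to your setup
)

# Alternative (create only one LLM per process): pick the FP8 format explicitly.
# e4m3 keeps more precision, e5m2 more range; both use 8 bits per value.
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    kv_cache_dtype="fp8_e4m3",
    tensor_parallel_size=4,
)
```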
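vLLM handles the quantization internally. To see where the savings come from, here is a generic illustration using symmetric per-token INT8 quantization in NumPy. This scheme is an assumption chosen for simplicity, not vLLM's FP8 implementation: each row keeps 1-byte codes plus one scale, so the cache shrinks to roughly half of FP16, and the rounding error per element is at most half a quantization step.

```python
import numpy as np

def quantize_int8(x):
    """Symmetric per-token INT8 quantization: one float32 scale per row."""
    scale = np.abs(x).max(axis=-1, keepdims=True).astype(np.float32) / 127.0
    scale = np.where(scale == 0, np.float32(1.0), scale)  # guard all-zero rows
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
keys = rng.standard_normal((4096, 128)).astype(np.float16)  # (seq_len, head_dim), one head
q, scale = quantize_int8(keys)
err = np.abs(dequantize_int8(q, scale) - keys.astype(np.float32)).max()
print(f"FP16: {keys.nbytes} B, INT8 + scales: {q.nbytes + scale.nbytes} B, max error: {err:.4f}")
```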
Eviction policies determine which tokens to retain when memory is full:
```python
# Common eviction strategies
eviction_policies = {
    "lru": "Least recently used: evict the tokens that were least recently attended to",
    "sliding_window": "Keep only the most recent N tokens",
    "sink": "Keep the first few 'attention sink' tokens plus the most recent tokens",
    "random": "Evict random tokens (simplest baseline, usually less accurate)",
}
```
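The `sliding_window` and `sink` policies reduce to choosing which token positions stay in the cache. A minimal sketch, assuming the sink variant keeps the first `n_sink` positions plus the most recent `window` positions as in StreamingLLM; `tokens_to_keep` and its defaults are illustrative, and `n_sink=0` gives a plain sliding window:

```python
def tokens_to_keep(seq_len, n_sink=4, window=1024):
    """Positions retained by a sink + sliding-window eviction policy."""
    if seq_len <= n_sink + window:
        return list(range(seq_len))  # cache not full yet: keep everything
    sinks = list(range(n_sink))  # initial tokens that act as attention sinks
    recent = list(range(seq_len - window, seq_len))  # most recent window
    return sinks + recent

print(tokens_to_keep(10, n_sink=2, window=3))  # [0, 1, 7, 8, 9]
print(tokens_to_keep(10, n_sink=0, window=3))  # [7, 8, 9]
```

Keeping the sinks costs only a few slots, while dropping them (plain sliding window) removes the positions that many models attend to heavily.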
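vLLM does not implement attention-sink eviction or drop individual tokens. It applies sliding-window attention when the model config defines one (Mistral-7B-v0.1 uses a 4,096-token window), allocates the cache in fixed-size blocks, and preempts whole requests when blocks run out:
```python
from vllm import LLM

llm = LLM(
    model="mistralai/Mistral-7B-v0.1",  # config defines sliding_window=4096
    block_size=16,  # PagedAttention block size, in tokens
    enable_prefix_caching=True,  # reuse KV blocks across requests that share a prompt prefix
)
```
Support for prefix caching together with sliding-window models depends on the vLLM version.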
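With `block_size=16`, the cache is allocated in 16-token blocks instead of one contiguous buffer per request, so per-request waste is limited to the unfilled tail of the last block. A simplified model of that accounting (not vLLM's allocator; `blocks_needed` and `wasted_slots` are hypothetical helpers):

```python
import math

def blocks_needed(seq_len, block_size=16):
    """Number of fixed-size KV blocks a sequence occupies under paged allocation."""
    return math.ceil(seq_len / block_size)

def wasted_slots(seq_len, block_size=16):
    # Unused token slots in the sequence's last, partially filled block
    return blocks_needed(seq_len, block_size) * block_size - seq_len

for n in (1000, 4096, 4097):
    print(n, blocks_needed(n), wasted_slots(n))
# 1000 63 8 / 4096 256 0 / 4097 257 15
```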
KV cache compression research explores more aggressive techniques:
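```python
import numpy as np

# Toy LSH bucketing of the KV cache, loosely inspired by HyperAttention.
# Not the published algorithm and not production-ready; it only illustrates the direction.
class HyperAttentionCache:
    def __init__(self, k=16, hash_size=128, seed=0):
        self.k = k  # number of LSH buckets (a power of two)
        self.hash_size = hash_size  # dimension of the hashed keys (head_dim)
        rng = np.random.default_rng(seed)
        # Random hyperplanes: the sign pattern of a key's projections is its bucket id
        self.planes = rng.standard_normal((hash_size, int(np.log2(k))))
    def _hash_keys(self, keys):
        bits = (keys @ self.planes) > 0  # (seq_len, log2(k)) booleans
        return bits.astype(np.int64) @ (1 << np.arange(bits.shape[1]))  # ids in [0, k)
    def _average_buckets(self, x, bucket_ids):
        # One mean vector per non-empty bucket, ordered by bucket id
        return np.stack([x[bucket_ids == b].mean(axis=0) for b in np.unique(bucket_ids)])
    def compress(self, keys, values):
        # keys, values: (seq_len, head_dim) arrays for a single attention head
        bucket_ids = self._hash_keys(keys)
        compressed_k = self._average_buckets(keys, bucket_ids)
        compressed_v = self._average_buckets(values, bucket_ids)
        return compressed_k, compressed_v
```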
Profile KV cache memory usage at various context lengths. Identify the context length where KV cache exceeds 50% of available VRAM. Test whether quantization enables longer contexts without OOM.
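Before measuring on hardware, the analytic formula gives a baseline to compare against. A sketch, assuming a hypothetical 80 GiB of total VRAM, batch size 1, and the Llama-2-70B shape used earlier; `context_at_fraction` is an illustrative helper. It shows that moving the cache from FP16 to an 8-bit format doubles the context length at which the cache reaches 50% of VRAM.

```python
def context_at_fraction(vram_gib, fraction=0.5, layer_count=80, kv_heads=8,
                        head_dim=128, bytes_per_param=2):
    """Context length (batch 1) at which the KV cache reaches `fraction` of VRAM."""
    per_token_bytes = 2 * layer_count * kv_heads * head_dim * bytes_per_param
    return int(vram_gib * 1024 ** 3 * fraction // per_token_bytes)

VRAM_GIB = 80  # assumed total VRAM; substitute your own
for name, nbytes in (("FP16", 2), ("FP8", 1)):
    print(name, context_at_fraction(VRAM_GIB, bytes_per_param=nbytes))
# FP16 131072 / FP8 262144
```

On real hardware, compare these estimates with measured peak memory (for example PyTorch's `torch.cuda.max_memory_allocated`), since weights, activations, and allocator overhead also compete for VRAM.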