16. Prompt Caching
Chapter 16 of 18 · 20 min
When many requests begin with the same prompt prefix, recomputing the attention keys and values (the KV cache) for that prefix on every request is wasted computation. Prompt caching detects these shared prefixes and reuses the already computed KV cache across requests.
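To make the idea concrete, here is a minimal sketch of a block-based prefix cache. It is a toy model, not vLLM's implementation: the names `ToyPrefixCache`, `block_hashes`, and the block size of 4 are assumptions chosen for illustration. Each full block of tokens is hashed together with the hash of the block before it, so two prompts produce the same hash chain only as long as their tokens are identical from the start.

```python
import hashlib
from typing import Dict, List

BLOCK_SIZE = 4  # hypothetical, tiny so the example is easy to follow


def block_hashes(tokens: List[int], block_size: int = BLOCK_SIZE) -> List[str]:
    """Hash each full block together with the hash of the preceding block."""
    hashes: List[str] = []
    parent = ""
    full = len(tokens) - len(tokens) % block_size  # ignore a trailing partial block
    for start in range(0, full, block_size):
        block = tokens[start:start + block_size]
        digest = hashlib.sha256(f"{parent}|{block}".encode()).hexdigest()
        hashes.append(digest)
        parent = digest
    return hashes


class ToyPrefixCache:
    """Maps chained block hashes to a placeholder for a stored KV block."""

    def __init__(self) -> None:
        self.blocks: Dict[str, int] = {}

    def lookup_and_insert(self, tokens: List[int]) -> int:
        """Return the number of leading tokens already cached, then cache the rest."""
        hashes = block_hashes(tokens)
        hit_blocks = 0
        for h in hashes:
            if h not in self.blocks:
                break
            hit_blocks += 1
        for h in hashes[hit_blocks:]:
            self.blocks[h] = len(self.blocks)  # stand-in for a real KV block id
        return hit_blocks * BLOCK_SIZE


cache = ToyPrefixCache()
system = [1, 2, 3, 4, 5, 6, 7, 8]  # shared prefix: two full blocks
first = cache.lookup_and_insert(system + [9, 10, 11, 12])
second = cache.lookup_and_insert(system + [20, 21, 22, 23])
print(first, second)  # 0 8
```

The first request finds nothing in the cache and stores three blocks. The second request reuses the two blocks that cover the shared prefix and only has to compute its own suffix.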
Common caching opportunities:
- System prompts (same for all requests)
- Instruction templates (same structure, different content)
- Document prefixes (same document, different queries)
- Conversation history (shared across turns)
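All of these opportunities depend on the shared content sitting at the start of the prompt. The sketch below compares two hypothetical prompt layouts (`static_first` and `query_first` are illustrative names, and character counts stand in for token counts) to show how much of the prompt is shared in each case.

```python
from os.path import commonprefix  # works on any list of strings, not only paths

SYSTEM = "You are a helpful coding assistant.\n"
DOCUMENT = "def add(a, b):\n    return a + b\n"


def static_first(query: str) -> str:
    """Static parts first, per-request content last: long shared prefix."""
    return f"{SYSTEM}{DOCUMENT}Question: {query}"


def query_first(query: str) -> str:
    """Per-request content first: the shared text is no longer a prefix."""
    return f"Question: {query}\n{SYSTEM}{DOCUMENT}"


queries = ["Explain this function.", "Debug this code."]

shared_static_first = len(commonprefix([static_first(q) for q in queries]))
shared_query_first = len(commonprefix([query_first(q) for q in queries]))

print("static-first shared chars:", shared_static_first)
print("query-first shared chars:", shared_query_first)
```

With the static parts first, the system prompt and document are shared in full. With the query first, the prompts diverge almost immediately and a prefix cache has little to reuse.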
Implementation in vLLM:
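from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    enable_prefix_caching=True,  # Enable prompt (prefix) caching
)

# Requests with an identical system prompt can share KV cache
requests = [
    {
        "prompt": "You are a helpful coding assistant. "  # Shared prefix
                  "Explain this function:",               # Unique suffix
        "prompt_token_ids": [...]  # Can pass pre-tokenized to ensure exact match
    },
    {
        "prompt": "You are a helpful coding assistant. "  # Identical prefix!
                  "Debug this code:",                     # Different suffix
    },
]

# Submit the text prompts; the shared prefix is detected automatically
outputs = llm.generate(
    [r["prompt"] for r in requests],
    SamplingParams(max_tokens=128),
)

# First request computes KV for "You are a helpful..."
# Second request reuses the cached KV instead of recomputing it
# ~30% speedup from prefix sharing; the actual gain depends on how much
# of each prompt is shared. Caching works on whole KV blocks, so a real
# shared prefix should be at least one block long to be reused.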
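How much time a shared prefix saves depends on what fraction of each prompt it covers, so a figure like the ~30% above is workload-specific. The following back-of-the-envelope sketch estimates the fraction of prefill work that can be skipped. It assumes reuse happens in whole blocks of 16 tokens and that only prefill (not decoding) benefits; the function names are hypothetical.

```python
def cached_tokens(shared_prefix_len: int, block_size: int = 16) -> int:
    """Tokens that can be reused when only full blocks are cached."""
    return (shared_prefix_len // block_size) * block_size


def prefill_fraction_saved(prompt_len: int, shared_prefix_len: int,
                           block_size: int = 16) -> float:
    """Fraction of prompt tokens whose prefill work is skipped on a cache hit."""
    if prompt_len <= 0:
        return 0.0
    reusable = min(cached_tokens(shared_prefix_len, block_size), prompt_len)
    return reusable / prompt_len


# An 8-token system prompt is shorter than one block: nothing is reusable.
print(cached_tokens(8))                    # 0
# A 500-token shared document in a 1000-token prompt: 31 full blocks reused.
print(cached_tokens(500))                  # 496
print(prefill_fraction_saved(1000, 500))   # 0.496
```

The estimate covers prompt processing only. End-to-end speedup is smaller when most of the request time is spent generating output tokens.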
Token-level caching beyond prefixes:
# Sub-prefix caching - count how many prompts share each token prefix
def sub_prefix_caching(model, prompts):
    """Find token prefixes shared by more than one prompt."""
    prefixes = {}
    for prompt in prompts:
        tokens = model.tokenizer.encode(prompt)
        # Record every prefix of this prompt, including the full prompt
        for i in range(1, len(tokens) + 1):
            prefix = tuple(tokens[:i])
            if prefix not in prefixes:
                prefixes[prefix] = []
            prefixes[prefix].append(prompt)
    # Return shared prefixes with their usage count
    return {k: len(v) for k, v in prefixes.items() if len(v) > 1}
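Storing every prefix as its own tuple grows quadratically with prompt length. A token trie gives the same counts with one node per distinct prefix. The sketch below is one possible alternative, not part of any library: `TrieNode`, `build_trie`, and `longest_shared_prefixes` are hypothetical names, and whitespace splitting stands in for a real tokenizer.

```python
from typing import Dict, Hashable, Iterable, List, Sequence, Tuple


class TrieNode:
    """One node per distinct token prefix; count = prompts passing through it."""

    def __init__(self) -> None:
        self.children: Dict[Hashable, "TrieNode"] = {}
        self.count = 0


def build_trie(token_lists: Iterable[Sequence[Hashable]]) -> TrieNode:
    root = TrieNode()
    for tokens in token_lists:
        node = root
        for tok in tokens:
            if tok not in node.children:
                node.children[tok] = TrieNode()
            node = node.children[tok]
            node.count += 1
    return root


def longest_shared_prefixes(root: TrieNode,
                            min_count: int = 2) -> List[Tuple[tuple, int]]:
    """Return (prefix, count) for each maximal prefix shared by >= min_count prompts."""
    results: List[Tuple[tuple, int]] = []
    stack: List[Tuple[TrieNode, tuple]] = [(root, ())]
    while stack:
        node, prefix = stack.pop()
        shared = [(t, c) for t, c in node.children.items() if c.count >= min_count]
        if prefix and not shared:
            results.append((prefix, node.count))  # cannot be extended further
        for tok, child in shared:
            stack.append((child, prefix + (tok,)))
    return results


prompts = [
    "You are a helpful assistant. Explain this",
    "You are a helpful assistant. Debug this",
    "Translate to French: hello",
]
trie = build_trie(p.split() for p in prompts)  # whitespace "tokenizer" for the demo
shared = longest_shared_prefixes(trie)
print(shared)
```

For a real analysis, replace `p.split()` with the token IDs from the model's own tokenizer, since cache matches are decided on token IDs rather than on text.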
Memory vs speed tradeoff:
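# Aggressive caching - more GPU memory for KV blocks, more prefixes stay cached
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    enable_prefix_caching=True,
    block_size=64,  # Larger blocks = less bookkeeping, but prefixes match only in
                    # whole blocks (check that your backend supports this size)
    gpu_memory_utilization=0.80,  # Fraction of GPU memory vLLM may use;
                                  # higher leaves more room for KV blocks
)

# Conservative caching - less memory, finer-grained sharing
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    enable_prefix_caching=True,
    block_size=16,  # Smaller blocks = less fragmentation, shorter prefixes can match
    gpu_memory_utilization=0.70,  # Smaller vLLM footprint; leaves GPU headroom but
                                  # fewer KV blocks for cache and concurrent requests
)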
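The memory side of this tradeoff can be estimated with simple arithmetic. The sketch below computes the KV cache size per token and the number of blocks that fit in a given budget. The model dimensions (80 layers, 8 KV heads, head dimension 128, 2-byte values) are assumed values for a Llama-2-70B-style model, and the 10 GiB budget is hypothetical; check your model's config and remember that tensor parallelism splits this across GPUs.

```python
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                       dtype_bytes: int = 2) -> int:
    """Bytes of KV cache for one token across all layers."""
    # Factor 2: one key vector and one value vector per layer and KV head.
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


def blocks_in_budget(budget_bytes: int, block_size: int, bytes_per_token: int) -> int:
    """Number of whole KV blocks that fit in the memory budget."""
    return budget_bytes // (block_size * bytes_per_token)


# Assumed dimensions for a Llama-2-70B-style model in fp16
per_token = kv_bytes_per_token(num_layers=80, num_kv_heads=8, head_dim=128)
budget = 10 * 1024**3  # hypothetical 10 GiB left for KV cache after weights

blocks_16 = blocks_in_budget(budget, block_size=16, bytes_per_token=per_token)
blocks_32 = blocks_in_budget(budget, block_size=32, bytes_per_token=per_token)

print(per_token)             # 327680 bytes, i.e. 320 KiB per token
print(blocks_16, blocks_32)  # 2048 1024
print(blocks_16 * 16)        # 32768 tokens of total KV capacity
```

Block size changes how the capacity is divided, not how many tokens fit. The budget itself is what `gpu_memory_utilization` controls, and cached prefixes compete with active requests for the same pool of blocks.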
Verification and monitoring:
# Monitor cache hit rates (assumes a vLLM server exposing /metrics on localhost:8000)
import requests

response = requests.get("http://localhost:8000/metrics", timeout=5)
response.raise_for_status()
for line in response.text.splitlines():
    if 'prefix_cache' in line:
        print(line)

# Expected metrics (exact names vary between vLLM versions):
# vllm:prefix_cache_hit_ratio - fraction of tokens served from cache
# vllm:prefix_cache_num_hits - absolute cache hit count
# vllm:prefix_cache_num_queries - total queries (hits + misses)
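When only the raw counters are exposed, the hit rate can be derived from them. The sketch below parses Prometheus text output and divides hits by queries. The sample text and the metric names in it are illustrative, copied from the list above, and may not match what your vLLM version reports. The parser is deliberately minimal and assumes samples carry no trailing timestamp.

```python
from typing import Dict


def parse_prometheus(text: str) -> Dict[str, float]:
    """Sum sample values per metric name, ignoring labels and comment lines."""
    totals: Dict[str, float] = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name_and_labels, _, value = line.rpartition(" ")
        name = name_and_labels.split("{", 1)[0]
        try:
            totals[name] = totals.get(name, 0.0) + float(value)
        except ValueError:
            continue  # skip lines that do not end in a number
    return totals


def hit_rate(totals: Dict[str, float], hits_name: str, queries_name: str) -> float:
    """Hits divided by queries, or 0.0 when nothing has been queried yet."""
    queries = totals.get(queries_name, 0.0)
    return totals.get(hits_name, 0.0) / queries if queries else 0.0


# Illustrative sample; real output comes from the server's /metrics endpoint
SAMPLE = """\
# HELP vllm:prefix_cache_num_queries Total queries
# TYPE vllm:prefix_cache_num_queries counter
vllm:prefix_cache_num_queries{model_name="demo"} 2000.0
vllm:prefix_cache_num_hits{model_name="demo"} 1500.0
"""

totals = parse_prometheus(SAMPLE)
rate = hit_rate(totals, "vllm:prefix_cache_num_hits", "vllm:prefix_cache_num_queries")
print(f"prefix cache hit rate: {rate:.2%}")
```

Counters are cumulative since server start, so for a per-interval hit rate, take two scrapes and divide the difference in hits by the difference in queries.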
EXERCISE
Analyze your typical workload for prefix sharing opportunities. Implement prompt caching and measure throughput improvement for workloads with high vs low prefix overlap.