16. Caching Strategies
Chapter 16 of 22 · 25 min
RAG systems repeatedly retrieve the same or similar content. Caching reduces latency and API costs for both embeddings and LLM responses.
Embedding Cache
Cache embedding results using a hash of the input text as the key, so identical text is never embedded twice.
import hashlib
import json
import math
import time
from functools import lru_cache
class EmbeddingCache:
def __init__(self, backend=None, max_size: int = 10000):
self.cache = backend or {}
self.max_size = max_size
self.hits = 0
self.misses = 0
def _hash(self, text: str) -> str:
"""Hash text for cache key."""
return hashlib.sha256(text.encode()).hexdigest()[:32]
def get_embedding(self, text: str) -> list | None:
"""Get cached embedding or return None."""
key = self._hash(text)
if key in self.cache:
self.hits += 1
return self.cache[key]
self.misses += 1
return None
def set_embedding(self, text: str, embedding: list):
"""Cache an embedding."""
key = self._hash(text)
# Simple eviction: remove oldest if full
if len(self.cache) >= self.max_size:
# Remove first item (FIFO, could use LRU)
first_key = next(iter(self.cache))
del self.cache[first_key]
self.cache[key] = embedding
def hit_rate(self) -> float:
"""Return cache hit rate."""
total = self.hits + self.misses
return self.hits / total if total > 0 else 0
Redis Cache for Production
For distributed systems, use a shared Redis instance and give each entry a time-to-live (TTL) so that old embeddings expire automatically.
import redis
import json
class RedisEmbeddingCache:
def __init__(self, redis_url: str, ttl_seconds: int = 3600):
self.redis = redis.from_url(redis_url)
self.ttl = ttl_seconds
def _hash(self, text: str) -> str:
return hashlib.sha256(text.encode()).hexdigest()
def get(self, text: str) -> list | None:
key = f"embed:{self._hash(text)}"
cached = self.redis.get(key)
if cached:
return json.loads(cached)
return None
def set(self, text: str, embedding: list):
key = f"embed:{self._hash(text)}"
self.redis.setex(key, self.ttl, json.dumps(embedding))
Semantic Cache
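Exact-match caching misses semantically equivalent queries, such as "How do I reset my password?" and "Password reset steps". Semantic caching embeds each incoming query. It reuses a stored response when an earlier query's embedding is similar enough, meaning its cosine similarity is at or above a threshold.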
class SemanticCache:
    """Response cache that reuses answers for semantically similar queries.

    Each lookup embeds the query (reusing an exact-text embedding cache) and
    scans stored query embeddings; the first entry whose cosine similarity is
    >= ``threshold`` is returned. The scan is linear in the number of entries.
    """

    def __init__(self, embed_model, threshold: float = 0.95):
        self.embed_model = embed_model
        self.threshold = threshold
        # tuple(query_embedding) -> (query, response, timestamp)
        self.cache = {}
        self.embedding_cache = EmbeddingCache()

    @staticmethod
    def _cosine(a, b) -> float:
        """Return the cosine similarity of two equal-length vectors.

        Returns 0.0 when either vector has zero norm.
        """
        dot = math.fsum(x * y for x, y in zip(a, b))
        norm_a = math.sqrt(math.fsum(x * x for x in a))
        norm_b = math.sqrt(math.fsum(y * y for y in b))
        if norm_a == 0.0 or norm_b == 0.0:
            return 0.0
        return dot / (norm_a * norm_b)

    def get_or_compute(self, query: str, compute_fn):
        """Return ``(response, was_cached)`` for ``query``.

        On a semantic hit the stored response is returned with True; otherwise
        ``compute_fn(query)`` is called, its result cached, and returned with
        False.
        """
        query_emb = self.embedding_cache.get_embedding(query)
        # Explicit None check: numpy arrays raise ValueError on truthiness,
        # and an empty vector is still a cached value.
        if query_emb is None:
            query_emb = self.embed_model.encode(query)
            self.embedding_cache.set_embedding(query, query_emb)

        for cached_emb, (_cached_query, response, _ts) in self.cache.items():
            if self._cosine(query_emb, cached_emb) >= self.threshold:
                return response, True  # (response, cached=True)

        response = compute_fn(query)
        # Tuples are hashable, so the embedding itself can serve as the key.
        self.cache[tuple(query_emb)] = (query, response, time.time())
        return response, False
Cache Invalidation
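Documents change, and cached chunks must not outlive the content they were derived from. Give each document a version number and include it in the cache key. When a document is updated, bump its version and drop that document's cached chunks, so stale versions can no longer be served.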
class VersionedCache:
def __init__(self):
self.cache = {}
self.doc_versions = {}
def get_chunk(self, chunk_id: str, version: str) -> str | None:
key = (chunk_id, version)
if key in self.cache:
return self.cache[key]
# Check if document was updated
current_version = self.get_current_version(chunk_id)
if version != current_version:
return None # Stale
# Load from storage
chunk = load_chunk(chunk_id)
self.cache[key] = chunk
return chunk
def invalidate_doc(self, doc_id: str):
"""Invalidate all cached chunks for a document."""
new_version = increment_version(self.doc_versions.get(doc_id, "0"))
self.doc_versions[doc_id] = new_version
# Remove old versions from cache
keys_to_remove = [k for k in self.cache if k[0].startswith(f"{doc_id}:")]
for k in keys_to_remove:
del self.cache[k]
EXERCISE
Implement a caching layer for your RAG system. Measure latency reduction and cache hit rate over 100 repeated queries.