12. PQ Encoding and Search
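Product Quantization (PQ) makes approximate nearest neighbor search feasible at massive scale by avoiding full-precision vector comparisons. Each database vector is split into m subvectors, each subvector is replaced by the index of its nearest centroid in a per-subspace codebook, and search then works on these compact codes instead of the original vectors.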
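The encoding step is short once codebooks exist. The sketch below assumes the codebooks are stored as a single NumPy array of shape (m, k, D/m) and are already trained (normally with k-means per subspace; random centroids stand in for them here). It also assumes k ≤ 256 so each code fits in one byte. `pq_encode` is a hypothetical helper name, not a library API.

```python
import numpy as np

def pq_encode(x: np.ndarray, codebooks: np.ndarray) -> np.ndarray:
    """Encode x of shape (n, D) into (n, m) uint8 codes (hypothetical helper).

    codebooks: (m, k, D // m) array; assumes k <= 256 so each id fits in a byte.
    """
    m, k, dsub = codebooks.shape
    n = x.shape[0]
    sub = x.reshape(n, m, dsub)  # split each vector into m contiguous subvectors
    # Squared distance from each subvector to every centroid of its own subspace:
    # (n, m, 1, dsub) - (1, m, k, dsub) -> (n, m, k) after summing over dsub
    d2 = ((sub[:, :, None, :] - codebooks[None, :, :, :]) ** 2).sum(axis=-1)
    # Nearest centroid per subspace; memory is O(n*m*k*dsub), so chunk large n
    return d2.argmin(axis=-1).astype(np.uint8)

rng = np.random.default_rng(0)
codebooks = rng.normal(size=(8, 256, 16)).astype(np.float32)  # stand-in for trained codebooks
x = rng.normal(size=(100, 128)).astype(np.float32)
codes = pq_encode(x, codebooks)
print(codes.shape, codes.dtype, x.nbytes // codes.nbytes)  # (100, 8) uint8 64
```

With D = 128 float32 dimensions (512 bytes per vector) and m = 8 one-byte codes, each vector shrinks to 8 bytes, a 64× reduction.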
Asymmetric Distance Computation
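The standard approach is Asymmetric Distance Computation (ADC): keep the query at full precision and compare it directly against the centroids selected by each database code. It is "asymmetric" because only the database side is quantized, and no database vector is ever decompressed.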
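For query q and code c = (c₁, c₂, ..., c_m), where q_i is the i-th subvector of q and centroid_i(c_i) is centroid c_i of the i-th codebook:
d(q, c) = Σᵢ ‖q_i − centroid_i(c_i)‖²,  i = 1, ..., m
This equals the squared Euclidean distance between q and the reconstructed database vector (the concatenation of the selected centroids), so it is exact given the quantized representation, yet it never builds that reconstruction for any database entry.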
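A quick numerical check of that claim: the squared Euclidean norm splits into a sum over disjoint coordinate blocks, so summing per-subspace distances gives the same number as measuring against the fully reconstructed vector. The (m, k, D/m) codebook layout and the helper names are assumptions for this sketch.

```python
import numpy as np

def adc_distance(q: np.ndarray, code: np.ndarray, codebooks: np.ndarray) -> float:
    """Sum over subspaces of ||q_i - centroid_i(c_i)||^2; codebooks is (m, k, D // m)."""
    m, _, dsub = codebooks.shape
    chosen = codebooks[np.arange(m), code]  # (m, dsub): centroid c_i of codebook i
    return float(((q.reshape(m, dsub) - chosen) ** 2).sum())

def reconstruct(code: np.ndarray, codebooks: np.ndarray) -> np.ndarray:
    """Concatenate the selected centroids into one D-dimensional vector."""
    m = codebooks.shape[0]
    return codebooks[np.arange(m), code].reshape(-1)

rng = np.random.default_rng(1)
codebooks = rng.normal(size=(8, 256, 16))
q = rng.normal(size=128)
code = rng.integers(0, 256, size=8)
per_subspace = adc_distance(q, code, codebooks)
full = float(((q - reconstruct(code, codebooks)) ** 2).sum())
print(np.isclose(per_subspace, full))  # True
```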
Precomputed Distance Tables
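Evaluating that formula directly costs O(D) arithmetic per database vector. But q is fixed for the whole search and each subspace has only k centroids, so there are just m·k distinct terms. Compute them once per query as m lookup tables of k entries each; every database distance then reduces to m table lookups and additions: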
import numpy as np

# ProductQuantizer (previous section) is assumed to expose: m (subspaces),
# k (centroids per subspace), subspace_dim (D // m), and codebooks
# (m arrays of shape (k, subspace_dim)).
class PQScalarQuantizationSearch:
    def __init__(self, pq: "ProductQuantizer"):
        self.pq = pq
        self.m = pq.m
        self.k = pq.k

    def compute_distance_tables(self, query: np.ndarray) -> list[np.ndarray]:
        """
        For each subspace, compute squared L2 distances from the query
        subvector to all centroids. Returns m float32 tables of shape (k,).
        """
        distance_tables = []
        for i in range(self.m):
            start = i * self.pq.subspace_dim
            end = start + self.pq.subspace_dim
            query_subspace = query[start:end]
            # Squared distance to all k centroids (the ADC terms), not the plain norm
            diff = query_subspace - self.pq.codebooks[i]
            distances = np.sum(diff * diff, axis=1)
            distance_tables.append(distances.astype(np.float32))
        return distance_tables

    def search(self, codes: np.ndarray, distance_tables: list[np.ndarray],
               k: int = 10) -> tuple[np.ndarray, np.ndarray]:
        """
        Search (n_codes, m) codes using precomputed distance tables.
        Returns top-k indices and distances, sorted ascending.
        """
        n_codes = len(codes)
        k = min(k, n_codes)  # argpartition's kth must be a valid index
        distances = np.zeros(n_codes, dtype=np.float32)
        # Sum distances across all subspaces using table lookup
        for i in range(self.m):
            distances += distance_tables[i][codes[:, i]]
        # O(n) partial selection: the k smallest land in the first k slots, unordered
        top_k_indices = np.argpartition(distances, k - 1)[:k]
        top_k_distances = distances[top_k_indices]
        # Sort only the k survivors by distance
        sorted_order = np.argsort(top_k_distances)
        return top_k_indices[sorted_order], top_k_distances[sorted_order]
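The per-subspace loops above can also be written as one broadcasted expression and one fancy-indexing call. This sketch assumes the codebooks are stacked into a single (m, k, D/m) array rather than a list. Building the table costs k·D operations per query; after that, each database vector costs m lookups instead of D multiply-adds. The last line cross-checks the result against brute-force distances to the reconstructed vectors.

```python
import numpy as np

def distance_tables_vectorized(query: np.ndarray, codebooks: np.ndarray) -> np.ndarray:
    """(m, k) float32 table; row i holds squared distances from q_i to every centroid of codebook i."""
    m, k, dsub = codebooks.shape
    q_sub = query.reshape(m, 1, dsub)  # broadcasts against (m, k, dsub)
    return ((q_sub - codebooks) ** 2).sum(axis=-1).astype(np.float32)

def adc_lookup(codes: np.ndarray, tables: np.ndarray) -> np.ndarray:
    """ADC distance for every code: tables[i, codes[r, i]] summed over subspaces i."""
    m = tables.shape[0]
    # np.arange(m) broadcasts against the (n, m) codes, giving an (n, m) gather
    return tables[np.arange(m), codes].sum(axis=1)

rng = np.random.default_rng(2)
m, k, dsub, n = 8, 256, 4, 1000
codebooks = rng.normal(size=(m, k, dsub)).astype(np.float32)
codes = rng.integers(0, k, size=(n, m)).astype(np.uint8)
q = rng.normal(size=m * dsub).astype(np.float32)
d = adc_lookup(codes, distance_tables_vectorized(q, codebooks))
# Cross-check against brute force on the reconstructed vectors
recon = codebooks[np.arange(m), codes].reshape(n, -1)
print(np.allclose(d, ((q - recon) ** 2).sum(axis=1), rtol=1e-4, atol=1e-4))  # True
```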
Handling SIMD in Lookup Tables
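Modern CPUs apply one instruction to several values at once: with 256-bit AVX registers, one instruction operates on 8 float32 values. NumPy's vectorized array operations can use such instructions internally, but from Python you do not control them directly. What you do control is the working set: processing codes in batches keeps each batch of codes, its partial distances, and the m small distance tables resident in cache.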
import numpy as np

# Batched distance accumulation: each `+=` is a vectorized NumPy operation,
# and batching keeps the per-batch working set small enough for cache.
def search_batch_simd(codes: np.ndarray, distance_tables: list[np.ndarray],
                      k: int = 10, batch_size: int = 1024) -> tuple[np.ndarray, np.ndarray]:
    """
    Batch search over (n_codes, m) codes.
    Processes codes in batches for better cache utilization.
    Returns top-k indices and distances, sorted ascending.
    """
    n_codes = len(codes)
    k = min(k, n_codes)
    results = []
    for batch_start in range(0, n_codes, batch_size):
        batch_end = min(batch_start + batch_size, n_codes)
        batch_codes = codes[batch_start:batch_end]
        # Accumulate distances subspace by subspace
        batch_distances = np.zeros(len(batch_codes), dtype=np.float32)
        for i, table in enumerate(distance_tables):
            # Gather: one table read per code, vectorized across the batch
            batch_distances += table[batch_codes[:, i]]
        # Batch top-k; the last batch may hold fewer than k codes
        kb = min(k, len(batch_distances))
        batch_top_k = np.argpartition(batch_distances, kb - 1)[:kb]
        results.append((batch_start + batch_top_k, batch_distances[batch_top_k]))
    # Merge: the global top-k is contained in the union of per-batch top-k
    all_indices = np.concatenate([r[0] for r in results])
    all_distances = np.concatenate([r[1] for r in results])
    global_top_k = np.argpartition(all_distances, k - 1)[:k]
    global_top_k = global_top_k[np.argsort(all_distances[global_top_k])]
    return all_indices[global_top_k], all_distances[global_top_k]
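One layout detail the batched loop inherits: with codes stored as an (n, m) row-major array, `codes[:, i]` reads every m-th byte. Storing codes subspace-major, as an (m, n) array, makes each subspace's codes one contiguous run. This sketch assumes the tables are stacked into an (m, k) float32 array. Both functions produce identical sums; whether the transposed layout is measurably faster depends on n, m, and your cache hierarchy, so benchmark it rather than assume.

```python
import numpy as np

def accumulate_row_major(codes: np.ndarray, tables: np.ndarray) -> np.ndarray:
    """codes: (n, m) C-contiguous; codes[:, i] is a strided view (step of m bytes)."""
    out = np.zeros(codes.shape[0], dtype=np.float32)
    for i in range(tables.shape[0]):
        out += tables[i][codes[:, i]]
    return out

def accumulate_subspace_major(codes_t: np.ndarray, tables: np.ndarray) -> np.ndarray:
    """codes_t: (m, n) C-contiguous; codes_t[i] is one contiguous run of n codes."""
    out = np.zeros(codes_t.shape[1], dtype=np.float32)
    for i in range(tables.shape[0]):
        out += tables[i][codes_t[i]]
    return out

rng = np.random.default_rng(3)
m, k, n = 16, 256, 100_000
tables = rng.random((m, k)).astype(np.float32)
codes = rng.integers(0, k, size=(n, m)).astype(np.uint8)
codes_t = np.ascontiguousarray(codes.T)  # one-time re-layout, e.g. at index build time
same = np.array_equal(accumulate_row_major(codes, tables),
                      accumulate_subspace_major(codes_t, tables))
print(same)  # True: identical additions in identical order
```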
Failure Modes
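Cold codebook problem: codebooks are trained on a sample of the data. If your queries or database vectors come from a different distribution than that training sample, the centroids cover the relevant regions poorly, quantization error grows, and PQ distances stop ranking neighbors reliably. Always validate on representative query samples by comparing PQ results against exact search.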
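One way to run that validation, assuming the query sample is small enough for brute-force search: compute exact neighbors on the original vectors, then measure recall@k of the PQ results against them. Both helpers are hypothetical; pass `recall_at_k` the indices your PQ search returned for the same queries. The distance expansion drops the ‖q‖² term, which is constant per query and so does not change the ranking.

```python
import numpy as np

def exact_top_k(database: np.ndarray, queries: np.ndarray, k: int) -> np.ndarray:
    """Brute-force ground truth: (n_queries, k) ids of the true nearest neighbors."""
    # ||x - q||^2 = ||x||^2 - 2 x.q + ||q||^2; ||q||^2 is constant per query, so omit it
    d2 = (database ** 2).sum(axis=1)[None, :] - 2.0 * queries @ database.T
    return np.argsort(d2, axis=1)[:, :k]

def recall_at_k(approx_ids: np.ndarray, exact_ids: np.ndarray) -> float:
    """Fraction of true top-k neighbors that also appear in the approximate top-k."""
    hits = sum(len(np.intersect1d(a, e)) for a, e in zip(approx_ids, exact_ids))
    return hits / exact_ids.size
```

A recall that is high on held-out training-like queries but drops on production queries is a concrete sign of the distribution mismatch described above.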
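Overflow in distance accumulation: with float32 accumulators, overflow is not a practical risk, since a sum of m squared subspace distances stays far below float32's maximum (about 3.4 × 10³⁸) for reasonably scaled data. The risk appears if you quantize the tables to small integers to speed up lookups: a sum of m uint8 entries can exceed 255 and silently wrap, so accumulate in a wider type such as uint16 or uint32. Partition-based selection (np.argpartition) instead of a full sort is a speed choice, O(n) rather than O(n log n), not a numerical safeguard.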
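The integer overflow case is easy to reproduce. This sketch assumes (m, k) float tables quantized to uint8 with one shared offset and scale, which is only one simple scheme among several. Summing 16 such entries in a uint8 accumulator wraps modulo 256, while uint16 holds any sum as long as m·255 ≤ 65535.

```python
import numpy as np

def quantize_tables(tables: np.ndarray) -> tuple[np.ndarray, float, float]:
    """Map an (m, k) float table onto 0..255 using one shared offset and scale."""
    lo = float(tables.min())
    scale = float(tables.max() - lo) / 255.0 or 1.0  # guard against a constant table
    q = np.round((tables - lo) / scale).astype(np.uint8)
    return q, lo, scale

def accumulate(codes: np.ndarray, q_tables: np.ndarray, acc_dtype) -> np.ndarray:
    """Sum the m uint8 entries selected by each code in an accumulator of acc_dtype."""
    acc = np.zeros(codes.shape[0], dtype=acc_dtype)
    for i in range(q_tables.shape[0]):
        acc += q_tables[i][codes[:, i]]  # in-place add keeps acc_dtype; uint8 wraps mod 256
    return acc

rng = np.random.default_rng(4)
m, k, n = 16, 256, 500
tables = rng.uniform(0.0, 10.0, size=(m, k)).astype(np.float32)
codes = rng.integers(0, k, size=(n, m))
q_tables, lo, scale = quantize_tables(tables)
wide = accumulate(codes, q_tables, np.uint16)   # holds up to 16 * 255 = 4080
narrow = accumulate(codes, q_tables, np.uint8)  # silently wraps
print(int(wide.max()) > 255, np.array_equal(narrow, wide % 256))  # True True
```

To recover approximate float distances, dequantize the wide sum as `lo * m + scale * wide`; each entry contributes at most `scale / 2` of rounding error.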
Implement PQ search and measure query throughput (queries/second) with different batch sizes. Profile cache misses using perf or similar tooling. Find the batch size that maximizes throughput on your hardware.
# Target: find optimal batch size for your CPU architecture
# Profile with: perf stat -e cache-misses,cache-references python your_script.py
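A minimal timing harness for this exercise, assuming per-query tables are precomputed as (m, k) arrays so that only the scan is timed. `batched_search` mirrors the per-batch top-k plus merge strategy of `search_batch_simd`; `measure_qps`, the dataset size, and the batch sizes are illustrative choices. Save it as a script and run it under the `perf stat` command above to pair throughput with cache-miss counts.

```python
import time
import numpy as np

def batched_search(codes, tables, k, batch_size):
    """Per-batch top-k plus global merge; codes is (n, m), tables is (m, k_centroids)."""
    m = tables.shape[0]
    cand_ids, cand_d = [], []
    for start in range(0, len(codes), batch_size):
        block = codes[start:start + batch_size]
        d = tables[np.arange(m), block].sum(axis=1)
        kb = min(k, len(d))
        top = np.argpartition(d, kb - 1)[:kb]
        cand_ids.append(start + top)
        cand_d.append(d[top])
    ids, dist = np.concatenate(cand_ids), np.concatenate(cand_d)
    order = np.argsort(dist)[:k]
    return ids[order], dist[order]

def measure_qps(codes, tables_per_query, k=10, batch_sizes=(256, 1024, 4096, 16384)):
    """Return {batch_size: queries per second}, timing only the scan."""
    batched_search(codes, tables_per_query[0], k, batch_sizes[0])  # warm-up run
    results = {}
    for bs in batch_sizes:
        t0 = time.perf_counter()
        for tables in tables_per_query:
            batched_search(codes, tables, k, bs)
        results[bs] = len(tables_per_query) / max(time.perf_counter() - t0, 1e-9)
    return results

if __name__ == "__main__":
    rng = np.random.default_rng(5)
    m, k_centroids, n = 8, 256, 200_000  # scale n up toward your real index size
    codes = rng.integers(0, k_centroids, size=(n, m)).astype(np.uint8)
    queries = [rng.random((m, k_centroids)).astype(np.float32) for _ in range(10)]
    for bs, qps in measure_qps(codes, queries).items():
        print(f"batch_size={bs:>6}: {qps:8.1f} queries/s")
```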