12. PQ Encoding and Search

Chapter 12 of 18 · 25 min

Product Quantization enables approximate nearest neighbor search at massive scale by avoiding full vector comparisons. Instead of comparing against original vectors, you compare compressed codes.

Asymmetric Distance Computation

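The standard approach is Asymmetric Distance Computation (ADC): the query stays uncompressed and is compared directly against the centroids that each database code refers to. Only the database side is quantized, so the query contributes no quantization error of its own.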

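For a query q split into m subvectors q_1, ..., q_m and a code c = (c₁, c₂, ..., c_m):

d(q, c) = Σᵢ || q_i - centroid_i(c_i) ||²

where centroid_i(c_i) is entry c_i of the i-th subspace codebook.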

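This is the exact squared Euclidean distance between q and the vector that c decodes to (the only error is the quantization of the database vector), and it is computed without reconstructing any database entry.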
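
A small numeric check can make this claim concrete. The sketch below is illustrative only: it uses random stand-in codebooks (real ones would come from k-means training), and the helper names `reconstruct` and `adc_distance` are hypothetical, not part of any library.

```python
import numpy as np

def reconstruct(code, codebooks):
    """Concatenate the centroid selected by each sub-code (hypothetical helper)."""
    return np.concatenate([codebooks[i][c] for i, c in enumerate(code)])

def adc_distance(query, code, codebooks):
    """Sum of squared subspace distances between the raw query and a PQ code."""
    m, _, sub_dim = codebooks.shape
    total = 0.0
    for i in range(m):
        q_i = query[i * sub_dim:(i + 1) * sub_dim]
        total += float(np.sum((q_i - codebooks[i][code[i]]) ** 2))
    return total

rng = np.random.default_rng(0)
m, k, sub_dim = 4, 16, 2                      # illustrative sizes, not recommendations
codebooks = rng.normal(size=(m, k, sub_dim))  # random stand-ins for trained codebooks
query = rng.normal(size=m * sub_dim)
code = rng.integers(0, k, size=m)

adc = adc_distance(query, code, codebooks)
full = float(np.sum((query - reconstruct(code, codebooks)) ** 2))
print(np.isclose(adc, full))
```

The two values agree up to floating-point rounding because the squared Euclidean distance decomposes into a sum over disjoint subspaces.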

Precomputed Distance Tables

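Evaluating this sum from scratch for every database code recomputes the same m × k query-to-centroid distances over and over. Compute them once per query and store them in lookup tables: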

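import numpy as np

# ProductQuantizer is assumed to be defined elsewhere, exposing m, k,
# subspace_dim, and codebooks (one (k, subspace_dim) array per subspace).
class PQScalarQuantizationSearch: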
    def __init__(self, pq: ProductQuantizer):
        self.pq = pq
        self.m = pq.m
        self.k = pq.k
    
    def compute_distance_tables(self, query: np.ndarray) -> list[np.ndarray]:
        """
        For each subspace, compute squared L2 distances from the query
        subvector to all centroids.
        Returns m tables of shape (k,).
        """
        distance_tables = []
        
        for i in range(self.m):
            start = i * self.pq.subspace_dim
            end = start + self.pq.subspace_dim
            query_subspace = query[start:end]
            
            # Squared L2 distance from the query subvector to all k centroids.
            # Squared terms are required so the per-subspace values sum to d(q, c).
            diff = query_subspace - self.pq.codebooks[i]
            distances = np.sum(diff ** 2, axis=1)
            distance_tables.append(distances)
        
        return distance_tables
    
    def search(self, codes: np.ndarray, distance_tables: list[np.ndarray], 
               k: int = 10) -> tuple[np.ndarray, np.ndarray]:
        """
        Search using precomputed distance tables.
        Returns top-k indices and distances.
        """
        n_codes = len(codes)
        distances = np.zeros(n_codes, dtype=np.float32)
        
        # Sum distances across all subspaces using table lookup
        for i in range(self.m):
            distances += distance_tables[i][codes[:, i]]
        
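        # Select the k smallest distances (k is clamped to the number of codes,
        # since argpartition raises if kth is out of bounds)
        k = min(k, n_codes)
        top_k_indices = np.argpartition(distances, k - 1)[:k]
        top_k_distances = distances[top_k_indices]
        
        # Sort the selected candidates by distance, nearest first
        sorted_order = np.argsort(top_k_distances)
        return top_k_indices[sorted_order], top_k_distances[sorted_order]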
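
As a sanity check on table-based search, the following standalone sketch compares table lookups with brute force over reconstructed vectors. It does not use the class above: `build_tables` and `adc_search` are hypothetical names, the codebooks are random stand-ins for trained ones, and the tables are stored as one (m, k) array so that all lookups can be gathered in a single indexing call.

```python
import numpy as np

def build_tables(query, codebooks):
    """Return an (m, k) array of squared distances from each query subvector to each centroid."""
    m, k, sub_dim = codebooks.shape
    q = query.reshape(m, 1, sub_dim)           # broadcasts against (m, k, sub_dim)
    return np.sum((q - codebooks) ** 2, axis=2)

def adc_search(codes, tables, top_k):
    """Sum table entries per code, then return the top_k nearest indices and distances."""
    m = tables.shape[0]
    # Element [j, i] of the gathered array is tables[i, codes[j, i]]
    dists = tables[np.arange(m), codes].sum(axis=1)
    top_k = min(top_k, len(dists))
    idx = np.argpartition(dists, top_k - 1)[:top_k]
    idx = idx[np.argsort(dists[idx])]
    return idx, dists[idx]

rng = np.random.default_rng(1)
m, k, sub_dim, n = 8, 256, 4, 1000             # illustrative sizes
codebooks = rng.normal(size=(m, k, sub_dim))   # random stand-ins for trained codebooks
codes = rng.integers(0, k, size=(n, m))
query = rng.normal(size=m * sub_dim)

idx, dists = adc_search(codes, build_tables(query, codebooks), top_k=5)

# Brute force: decode every code and compute the squared L2 distance directly
recon = codebooks[np.arange(m), codes].reshape(n, m * sub_dim)
brute = np.sum((recon - query) ** 2, axis=1)
print(np.allclose(dists, np.sort(brute)[:5]))
```

The brute-force path touches n × m × sub_dim floats per query, while the table path reads only n × m table entries after a one-time m × k setup.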

Handling SIMD in Lookup Tables

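Modern CPUs can operate on several values per instruction. With AVX2, a 256-bit register holds 8 float32 values, so a single instruction can add 8 partial distances at once. NumPy does not expose SIMD instructions directly, but accumulating one subspace at a time over a contiguous batch keeps the additions inside NumPy's compiled loops and the working set small enough to stay in cache: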

# SIMD-friendly distance accumulation
def search_batch_simd(codes: np.ndarray, distance_tables: list[np.ndarray], 
                       k: int = 10, batch_size: int = 1024) -> tuple[np.ndarray, np.ndarray]:
    """
    Batch search optimized for SIMD processing.
    Processes codes in batches for better cache utilization.
    """
    n_codes = len(codes)
    results = []
    
    for batch_start in range(0, n_codes, batch_size):
        batch_end = min(batch_start + batch_size, n_codes)
        batch_codes = codes[batch_start:batch_end]
        
        # Accumulate distances subspace by subspace
        batch_distances = np.zeros(len(batch_codes), dtype=np.float32)
        
        for i, table in enumerate(distance_tables):
            # One table gather per code, followed by a vectorized add
            batch_distances += table[batch_codes[:, i]]
        
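        # Keep at most k candidates per batch (the last batch may hold fewer than k codes)
        batch_k = min(k, len(batch_codes))
        batch_top_k = np.argpartition(batch_distances, batch_k - 1)[:batch_k]
        results.append((batch_start + batch_top_k, batch_distances[batch_top_k]))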
    
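    # Merge results from all batches
    all_indices = np.concatenate([r[0] for r in results])
    all_distances = np.concatenate([r[1] for r in results])
    k = min(k, len(all_distances))
    global_top_k = np.argpartition(all_distances, k - 1)[:k]
    # Sort the final candidates by distance, nearest first
    global_top_k = global_top_k[np.argsort(all_distances[global_top_k])]
    
    return all_indices[global_top_k], all_distances[global_top_k]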
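
Merging per-batch candidates loses nothing, because every member of the global top-k is also among the k best of its own batch. The sketch below checks this on precomputed distances; `top_k_sorted` and `batched_top_k` are hypothetical helper names, and the random distances stand in for real ADC sums.

```python
import numpy as np

def top_k_sorted(dists, k):
    """Indices of the k smallest entries of dists, nearest first."""
    k = min(k, len(dists))
    idx = np.argpartition(dists, k - 1)[:k]
    return idx[np.argsort(dists[idx])]

def batched_top_k(dists, k, batch_size):
    """Select top-k per batch, then re-select top-k among the pooled candidates."""
    candidates = []
    for start in range(0, len(dists), batch_size):
        local = top_k_sorted(dists[start:start + batch_size], k)
        candidates.append(start + local)       # convert batch-local to global indices
    candidates = np.concatenate(candidates)
    return candidates[top_k_sorted(dists[candidates], k)]

rng = np.random.default_rng(2)
dists = rng.random(10_000)                     # stand-in for accumulated ADC distances
single = top_k_sorted(dists, 10)
batched = batched_top_k(dists, 10, batch_size=1024)
print(np.array_equal(dists[single], dists[batched]))
```

The comparison is made on distances rather than indices so that ties, if any, cannot cause a spurious mismatch.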

Failure Modes

Cold codebook problem: If your queries come from a different distribution than training data, distances become unreliable. Always validate on representative query samples.
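
One hedged way to detect such a mismatch is to compare the quantization error of held-out vectors against that of a sample from the distribution you expect in production. The sketch below makes several assumptions: randomly sampled training subvectors stand in for k-means centroids, the shift is simulated by moving the mean, and `quantization_error` is a hypothetical helper. Quantization error is only a proxy here; measuring recall against exact search on real queries is the more direct check.

```python
import numpy as np

def quantization_error(vectors, codebooks):
    """Mean squared distance between each vector and its nearest-centroid reconstruction."""
    m, k, sub_dim = codebooks.shape
    sub = vectors.reshape(len(vectors), m, sub_dim)
    total = np.zeros(len(vectors))
    for i in range(m):
        # (n, k) squared distances from every subvector to every centroid of subspace i
        d = np.sum((sub[:, i, None, :] - codebooks[i][None, :, :]) ** 2, axis=2)
        total += d.min(axis=1)
    return float(total.mean())

rng = np.random.default_rng(3)
m, k, sub_dim = 4, 32, 2                       # illustrative sizes
train = rng.normal(size=(2000, m * sub_dim))

# Crude stand-in for k-means: use k random training vectors, split per subspace
picks = rng.choice(len(train), size=k, replace=False)
codebooks = train[picks].reshape(k, m, sub_dim).transpose(1, 0, 2).copy()

in_dist = rng.normal(size=(500, m * sub_dim))
shifted = rng.normal(loc=3.0, size=(500, m * sub_dim))   # simulated distribution shift

err_in = quantization_error(in_dist, codebooks)
err_shifted = quantization_error(shifted, codebooks)
print(f"in-distribution: {err_in:.3f}  shifted: {err_shifted:.3f}")
```

A large gap between the two numbers suggests the codebooks should be retrained on data closer to what the system will actually see.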

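Overflow in distance accumulation: Each distance is a sum of m table entries, so its magnitude grows with the number of subspaces. float32 accumulation, as used above, has ample range for this. The risk appears when tables or accumulators are narrowed to small integer types for speed; in that case, accumulate in a wider type than the table entries. Partition-based selection (np.argpartition) is a speed optimization over full sorting and does not change the numerics.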


EXERCISE

Implement PQ search and measure query throughput (queries/second) with different batch sizes. Profile cache misses using perf or similar tooling. Find the batch size that maximizes throughput on your hardware.

# Target: find the optimal batch size for your CPU architecture
# Profile with (Linux): perf stat -e cache-misses,cache-references python your_script.py
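
A possible starting point for the measurement is sketched below. The names `scan` and `measure_throughput` are hypothetical, the codes and tables are random stand-ins for a real index, and the sizes are illustrative; absolute numbers will vary by machine, so treat only the relative ordering as meaningful.

```python
import time
import numpy as np

def scan(codes, tables, batch_size):
    """Accumulate ADC distances batch by batch; returns one distance per code."""
    out = np.empty(len(codes), dtype=np.float32)
    m = tables.shape[0]
    for start in range(0, len(codes), batch_size):
        batch = codes[start:start + batch_size]
        acc = np.zeros(len(batch), dtype=np.float32)
        for i in range(m):
            acc += tables[i][batch[:, i]]
        out[start:start + batch_size] = acc
    return out

def measure_throughput(codes, tables, batch_sizes, repeats=3):
    """Return {batch_size: full scans per second}, taking the best of `repeats` runs."""
    results = {}
    for bs in batch_sizes:
        best = float("inf")
        for _ in range(repeats):
            t0 = time.perf_counter()
            scan(codes, tables, bs)
            best = min(best, time.perf_counter() - t0)
        results[bs] = 1.0 / best               # one scan corresponds to one query
    return results

rng = np.random.default_rng(4)
n, m, k = 20_000, 8, 256                        # illustrative sizes
codes = rng.integers(0, k, size=(n, m)).astype(np.uint8)
tables = rng.random((m, k)).astype(np.float32)  # random stand-in for real distance tables

for bs, qps in measure_throughput(codes, tables, [256, 1024, 4096, n]).items():
    print(f"batch_size={bs:>6}: {qps:8.1f} queries/s")
```

Taking the best of several runs reduces noise from other processes; for the cache-miss part of the exercise, run the same script under your profiler of choice.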