20. A/B Testing Retrieval

Chapter 20 of 22 · 20 min

Different retrieval strategies perform differently across query types. A/B testing lets you compare strategies in production with real traffic.

Experiment Configuration

import hashlib
import random
from dataclasses import dataclass
from typing import Callable, List

import numpy as np

@dataclass
class Experiment:
    """Configuration for one A/B experiment.

    Attributes:
        name: Experiment identifier.
        variants: Distinct variant names, e.g. ``["baseline", "reranked", "compressed"]``.
        weights: Traffic fraction per variant, in the same order as
            ``variants``, e.g. ``[0.33, 0.33, 0.34]``. Must be non-negative
            and sum to 1 (within a 0.01 tolerance for hand-rounded values).

    Raises:
        ValueError: if the configuration is inconsistent (see ``__post_init__``).
    """

    name: str
    variants: List[str]  # e.g., ["baseline", "reranked", "compressed"]
    weights: List[float]  # e.g., [0.33, 0.33, 0.34]

    def __post_init__(self):
        """Reject configurations that would silently mis-route traffic."""
        if not self.variants:
            raise ValueError("an experiment needs at least one variant")
        # zip() in the assignment code would silently drop unmatched entries.
        if len(self.variants) != len(self.weights):
            raise ValueError(
                f"got {len(self.variants)} variants but {len(self.weights)} weights"
            )
        # Duplicate names would merge two arms into one results bucket.
        if len(set(self.variants)) != len(self.variants):
            raise ValueError(f"variant names must be unique: {self.variants}")
        if any(w < 0 for w in self.weights):
            raise ValueError(f"weights must be non-negative: {self.weights}")
        total = sum(self.weights)
        # Allow slack so values like [0.333, 0.333, 0.333] are accepted.
        if abs(total - 1.0) > 0.01:
            raise ValueError(f"weights must sum to 1, got {total}")
    
class ABTester:
    """Assign users to experiment variants and aggregate per-variant metrics.

    Assignment is deterministic: a given ``user_id`` always lands in the same
    variant of a given experiment, so users never flip between arms.
    """

    # Hash-bucket count for assignment; 10,000 buckets give 0.01% traffic
    # resolution, so small weights (e.g. 0.005) are honoured.
    _NUM_BUCKETS = 10_000

    def __init__(self, experiment: "Experiment"):
        self.experiment = experiment
        # variant name -> list of metric dicts recorded for that variant
        self.results = {v: [] for v in experiment.variants}

    def get_variant(self, user_id: str) -> str:
        """Return the variant ``user_id`` is assigned to.

        The hash key is salted with the experiment name. Without the salt,
        every experiment would route the same users to its first variant,
        correlating assignments across concurrently running experiments.
        """
        key = f"{self.experiment.name}:{user_id}".encode()
        bucket = int(hashlib.md5(key).hexdigest(), 16) % self._NUM_BUCKETS

        cumulative = 0.0
        for variant, weight in zip(self.experiment.variants,
                                   self.experiment.weights):
            cumulative += weight * self._NUM_BUCKETS
            if bucket < cumulative:
                return variant

        # Weights summing to slightly under 1 (float rounding) leave a few
        # buckets unclaimed; route them to the last variant.
        return self.experiment.variants[-1]

    def record_result(self, variant: str, metrics: dict):
        """Append one observation's ``metrics`` dict to ``variant``'s results.

        Raises:
            KeyError: if ``variant`` is not one of the experiment's variants.
        """
        try:
            bucket = self.results[variant]
        except KeyError:
            raise KeyError(
                f"unknown variant {variant!r}; expected one of {list(self.results)}"
            ) from None
        bucket.append(metrics)

    def get_stats(self) -> dict:
        """Summarise recorded results per variant.

        Returns:
            ``{variant: {"count": 0}}`` for variants with no results, otherwise
            ``{"count", "mean_precision", "std_precision"}``. ``count`` counts all
            recorded results; the precision figures use only results that
            contain a ``"precision"`` key (0 when none do). ``std_precision``
            is the sample standard deviation (ddof=1), reported as 0 for fewer
            than two precision values.
        """
        summary = {}

        for variant, results in self.results.items():
            if not results:
                summary[variant] = {"count": 0}
                continue

            scores = np.asarray(
                [r["precision"] for r in results if "precision" in r],
                dtype=float,
            )

            summary[variant] = {
                "count": len(results),
                "mean_precision": float(scores.mean()) if scores.size else 0,
                # Sample std estimates the population spread; undefined for n < 2.
                "std_precision": float(scores.std(ddof=1)) if scores.size > 1 else 0,
            }

        return summary

Statistical Significance

from scipy import stats

def check_significance(results_a: list, results_b: list,
                       metric: str = "precision",
                       alpha: float = 0.05) -> dict:
    """Test whether two variants differ significantly on ``metric``.

    Runs a two-sample Student's t-test (``scipy.stats.ttest_ind``) and
    computes Cohen's d from the n-weighted pooled sample standard deviation.

    Args:
        results_a: Metric dicts for variant A; dicts lacking ``metric`` are skipped.
        results_b: Metric dicts for variant B; same convention.
        metric: Key to compare.
        alpha: Significance level for the ``"significant"`` flag.

    Returns:
        Dict with ``t_statistic``, ``p_value``, ``effect_size``,
        ``significant``, ``mean_a``, ``mean_b``, ``n_a`` and ``n_b``. With
        fewer than two values on either side, the test statistics are NaN
        and ``significant`` is False.
    """
    import math

    values_a = np.asarray([r[metric] for r in results_a if metric in r], dtype=float)
    values_b = np.asarray([r[metric] for r in results_b if metric in r], dtype=float)
    n_a, n_b = int(values_a.size), int(values_b.size)

    mean_a = float(values_a.mean()) if n_a else math.nan
    mean_b = float(values_b.mean()) if n_b else math.nan

    if n_a < 2 or n_b < 2:
        # A t-test needs at least two observations per group.
        t_stat = p_value = effect_size = math.nan
    else:
        t_stat, p_value = stats.ttest_ind(values_a, values_b)
        t_stat, p_value = float(t_stat), float(p_value)

        # Cohen's d: pooled *sample* variance weighted by degrees of freedom,
        # which stays correct when the groups have different sizes.
        pooled_var = (
            (n_a - 1) * values_a.var(ddof=1) + (n_b - 1) * values_b.var(ddof=1)
        ) / (n_a + n_b - 2)
        pooled_std = math.sqrt(pooled_var)
        diff = mean_a - mean_b
        if pooled_std > 0:
            effect_size = diff / pooled_std
        else:
            # Zero spread: d is undefined for equal means, infinite otherwise.
            effect_size = math.nan if diff == 0 else math.copysign(math.inf, diff)

    return {
        "t_statistic": t_stat,
        "p_value": p_value,
        "effect_size": effect_size,
        # NaN < alpha is False, so undefined tests are never "significant".
        "significant": bool(p_value < alpha),
        "mean_a": mean_a,
        "mean_b": mean_b,
        "n_a": n_a,
        "n_b": n_b,
    }

Gradual Rollout

class GradualRollout:
    """Shift traffic toward better-performing variants behind a rollout gate.

    ``current_weights`` start uniform. Each ``adjust_weights`` call moves them
    part of the way toward a softmax of each variant's mean precision.
    Rollout membership and variant choice both come from hashes of
    ``user_id``, so a user's assignment is stable across requests.
    """

    # Hash-bucket count; fine resolution so small weight shifts take effect.
    _NUM_BUCKETS = 10_000

    def __init__(self, experiment: "Experiment"):
        self.experiment = experiment
        self.current_weights = {v: 1.0 / len(experiment.variants)
                                for v in experiment.variants}

    @classmethod
    def _bucket(cls, key: str) -> int:
        """Map ``key`` to a stable integer in ``[0, _NUM_BUCKETS)``."""
        return int(hashlib.md5(key.encode()).hexdigest(), 16) % cls._NUM_BUCKETS

    def adjust_weights(self, stats: dict, temperature: float = 1.0,
                       learning_rate: float = 0.1):
        """Nudge ``current_weights`` toward better-performing variants.

        Args:
            stats: ``{variant: {"mean_precision": float, ...}}``, e.g. the output
                of ``ABTester.get_stats``. A variant that is missing, or whose
                entry has no ``"mean_precision"``, scores 0.
            temperature: Softmax temperature (> 0). Values below 1 sharpen the
                preference for the best variant. Precision lies in [0, 1], so
                at 1.0 the target weights stay close to uniform.
            learning_rate: Fraction in [0, 1] of the gap to the softmax target
                closed per call (the default moves weights 10% of the way).

        Raises:
            ValueError: if ``temperature`` or ``learning_rate`` is out of range.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        if not 0.0 <= learning_rate <= 1.0:
            raise ValueError(f"learning_rate must be in [0, 1], got {learning_rate}")

        variants = self.experiment.variants
        scores = np.array(
            [stats.get(v, {}).get("mean_precision", 0) for v in variants],
            dtype=float,
        )

        # Numerically stable softmax: subtract the max before exponentiating.
        target = np.exp((scores - scores.max()) / temperature)
        target /= target.sum()

        for variant, goal in zip(variants, target):
            self.current_weights[variant] = (
                (1.0 - learning_rate) * self.current_weights[variant]
                + learning_rate * float(goal)
            )

    def get_variant(self, user_id: str, rollout_percentage: float) -> str:
        """Return ``user_id``'s variant, or ``"baseline"`` if held out.

        Args:
            user_id: Stable user identifier.
            rollout_percentage: Fraction of users (0.0-1.0, not 0-100) who enter
                the weighted experiment; everyone else gets ``"baseline"``.

        Holdout membership comes from a hash of ``user_id``, so it is stable
        across calls, and raising ``rollout_percentage`` only adds users. The
        hash uses a different salt from variant selection so that the rolled-out
        users are spread over all variants rather than a prefix of them.
        """
        if self._bucket(f"rollout:{user_id}") >= rollout_percentage * self._NUM_BUCKETS:
            return "baseline"  # Holdout

        return self._select_by_weight(user_id)

    def _select_by_weight(self, user_id: str) -> str:
        """Pick a variant by ``current_weights`` using a stable hash of ``user_id``."""
        bucket = self._bucket(user_id)

        cumulative = 0.0
        for variant, weight in self.current_weights.items():
            cumulative += weight * self._NUM_BUCKETS
            if bucket < cumulative:
                return variant

        # Float rounding can leave the last few buckets unclaimed; give them to
        # the last variant instead of a label that may not exist.
        return next(reversed(self.current_weights))
EXERCISE

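Set up an A/B test comparing baseline retrieval with reranking. Route at least 100 queries to each variant, then pass the two variants' recorded results to check_significance to test whether their difference in precision is statistically significant.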