20. A/B Testing Retrieval
Chapter 20 of 22 · 20 min
Different retrieval strategies perform differently across query types. A/B testing lets you compare strategies in production with real traffic.
Experiment Configuration
import hashlib
import random
from dataclasses import dataclass
from typing import Callable, List

import numpy as np
@dataclass
class Experiment:
    """Configuration of a single A/B experiment.

    Attributes:
        name: Identifier of the experiment.
        variants: Variant names, e.g. ["baseline", "reranked", "compressed"].
        weights: Traffic share of each variant, parallel to ``variants``,
            e.g. [0.33, 0.33, 0.34].

    Raises:
        ValueError: If ``variants`` and ``weights`` differ in length, or if
            any weight is negative.
    """

    name: str
    variants: List[str]  # e.g., ["baseline", "reranked", "compressed"]
    weights: List[float]  # e.g., [0.33, 0.33, 0.34]

    def __post_init__(self) -> None:
        # Variants and weights are consumed pairwise; a length mismatch would
        # otherwise silently drop the unpaired entries.
        if len(self.variants) != len(self.weights):
            raise ValueError(
                f"Experiment {self.name!r}: got {len(self.variants)} variants "
                f"but {len(self.weights)} weights"
            )
        if any(weight < 0 for weight in self.weights):
            raise ValueError(
                f"Experiment {self.name!r}: weights must be non-negative"
            )
class ABTester:
    """Assign users to experiment variants and collect per-variant metrics.

    Assignment is deterministic: the MD5 hash of the user id is reduced to a
    bucket in ``[0, 100)`` and mapped onto the cumulative variant weights, so
    a given user id always maps to the same variant.
    """

    def __init__(self, experiment: "Experiment"):
        """Create a tester with an empty result list for every variant."""
        self.experiment = experiment
        self.results = {v: [] for v in experiment.variants}

    def get_variant(self, user_id: str) -> str:
        """Determine which variant a user sees.

        Args:
            user_id: Identifier hashed to obtain a stable bucket.

        Returns:
            The name of the assigned variant. If the weights do not cover the
            user's bucket (e.g. they sum to less than 1), the last variant is
            returned.
        """
        # Consistent assignment based on user_id hash
        digest = hashlib.md5(user_id.encode()).hexdigest()
        bucket = int(digest, 16) % 100

        # Assign to variant based on cumulative weights
        cumulative = 0.0
        for variant, weight in zip(self.experiment.variants,
                                   self.experiment.weights):
            cumulative += weight
            # Round the threshold: in floating point 0.07 * 100 evaluates to
            # 7.000000000000001, which would wrongly admit bucket 7 and give
            # the variant 8% of the traffic instead of 7%.
            if bucket < round(cumulative * 100, 6):
                return variant
        return self.experiment.variants[-1]

    def record_result(self, variant: str, metrics: dict):
        """Record outcome for a variant.

        Args:
            variant: Name of a variant of this experiment.
            metrics: Metric name -> value mapping for one observation.

        Raises:
            KeyError: If ``variant`` is not part of the experiment.
        """
        self.results[variant].append(metrics)

    def get_stats(self) -> dict:
        """Get statistics for each variant.

        Returns:
            A dict keyed by variant name. Variants without results map to
            ``{"count": 0}``; others map to ``count`` (all recorded results),
            ``mean_precision`` and ``std_precision`` (computed over the
            results that contain a ``"precision"`` key, 0 if none do).
        """
        stats = {}
        for variant, results in self.results.items():
            if not results:
                stats[variant] = {"count": 0}
                continue
            precision_scores = [r["precision"] for r in results
                                if "precision" in r]
            has_scores = bool(precision_scores)
            stats[variant] = {
                "count": len(results),
                # Plain floats keep the result JSON-friendly.
                "mean_precision": float(np.mean(precision_scores)) if has_scores else 0.0,
                "std_precision": float(np.std(precision_scores)) if has_scores else 0.0,
            }
        return stats
Statistical Significance
from scipy import stats
def check_significance(results_a: list, results_b: list,
                       metric: str = "precision",
                       alpha: float = 0.05) -> dict:
    """Check if difference between variants is statistically significant.

    Runs an independent two-sample t-test on ``metric`` and reports Cohen's d
    as the effect size.

    Args:
        results_a: Result dicts recorded for the first variant.
        results_b: Result dicts recorded for the second variant.
        metric: Key to compare; results lacking the key are ignored.
        alpha: Significance threshold applied to the p-value.

    Returns:
        A dict with ``t_statistic``, ``p_value``, ``effect_size``,
        ``significant``, ``mean_a`` and ``mean_b``. When either group is empty
        or there are fewer than three observations in total, the test
        statistics are NaN and ``significant`` is False.
    """
    values_a = np.asarray([r[metric] for r in results_a if metric in r],
                          dtype=float)
    values_b = np.asarray([r[metric] for r in results_b if metric in r],
                          dtype=float)
    n_a, n_b = values_a.size, values_b.size

    nan = float("nan")
    mean_a = float(values_a.mean()) if n_a else nan
    mean_b = float(values_b.mean()) if n_b else nan

    # Not enough data for a t-test: report NaNs instead of emitting warnings
    # and dividing by zero.
    if n_a == 0 or n_b == 0 or n_a + n_b < 3:
        return {
            "t_statistic": nan,
            "p_value": nan,
            "effect_size": nan,
            "significant": False,
            "mean_a": mean_a,
            "mean_b": mean_b,
        }

    # T-test
    t_stat, p_value = stats.ttest_ind(values_a, values_b)

    # Effect size (Cohen's d) using the pooled *sample* standard deviation;
    # the population variance (ddof=0) overstates d for small samples.
    squared_dev = (((values_a - mean_a) ** 2).sum()
                   + ((values_b - mean_b) ** 2).sum())
    pooled_std = float(np.sqrt(squared_dev / (n_a + n_b - 2)))
    diff = mean_a - mean_b
    if pooled_std > 0:
        effect_size = diff / pooled_std
    elif diff == 0:
        # Identical constant samples: no effect rather than 0/0.
        effect_size = 0.0
    else:
        effect_size = float("inf") if diff > 0 else float("-inf")

    return {
        "t_statistic": float(t_stat),
        "p_value": float(p_value),
        "effect_size": effect_size,
        # Plain bool so the result is JSON-serializable; NaN compares False.
        "significant": bool(p_value < alpha),
        "mean_a": mean_a,
        "mean_b": mean_b,
    }
Gradual Rollout
class GradualRollout:
    """Shift traffic between variants gradually, based on observed precision.

    Weights start out uniform across the experiment's variants and are nudged
    towards a softmax of the variants' mean precision on every update.
    """

    def __init__(self, experiment: "Experiment"):
        """Start with equal weight for every variant of ``experiment``."""
        self.experiment = experiment
        self.current_weights = {v: 1.0 / len(experiment.variants)
                                for v in experiment.variants}

    def adjust_weights(self, stats: dict):
        """Adjust weights based on performance.

        Args:
            stats: Variant name -> stats dict; ``"mean_precision"`` is read
                from each entry. Variants that are missing, or lack that key,
                count as 0.
        """
        variants = self.experiment.variants
        if not variants:
            return

        # Increase weight for better-performing variants
        scores = np.array(
            [stats.get(v, {}).get("mean_precision", 0) for v in variants],
            dtype=float,
        )

        # Softmax to get new weights (shifted by the max for stability)
        exp_scores = np.exp(scores - scores.max())
        new_weights = exp_scores / exp_scores.sum()

        # Gradual adjustment (10% shift per update)
        for variant, target in zip(variants, new_weights):
            self.current_weights[variant] = float(
                0.9 * self.current_weights[variant] + 0.1 * target
            )

    def get_variant(self, user_id: str, rollout_percentage: float) -> str:
        """Get variant with gradual rollout control.

        Args:
            user_id: Identifier of the user being assigned.
            rollout_percentage: Fraction in ``[0, 1]`` of users who take part
                in the experiment; everyone else gets ``"baseline"``.

        Returns:
            ``"baseline"`` for held-out users, otherwise the variant chosen
            by the current weights.
        """
        # The holdout decision is derived from the user id rather than from
        # random.random(), so a user does not flip between holdout and
        # treatment across requests, and raising the percentage only ever
        # adds users to the experiment.
        if self._rollout_fraction(user_id) >= rollout_percentage:
            return "baseline"  # Holdout
        return self._select_by_weight(user_id)

    def _rollout_fraction(self, user_id: str) -> float:
        """Map a user to a stable value in [0, 1) for the holdout decision."""
        # Salted so the holdout bucket is independent of the variant bucket.
        key = f"{self.experiment.name}:rollout:{user_id}"
        digest = hashlib.md5(key.encode()).hexdigest()
        return (int(digest, 16) % 10_000) / 10_000

    def _select_by_weight(self, user_id: str) -> str:
        """Pick a variant for ``user_id`` from the current weights."""
        hash_val = int(hashlib.md5(user_id.encode()).hexdigest(), 16) % 100
        cumulative = 0.0
        for variant, weight in self.current_weights.items():
            cumulative += weight
            # Rounded threshold avoids float artefacts such as
            # 0.07 * 100 == 7.000000000000001 admitting an extra bucket.
            if hash_val < round(cumulative * 100, 6):
                return variant
        # Weights fell short of the bucket: use the last configured variant
        # (as ABTester does) instead of a name that may not be a variant.
        names = list(self.current_weights)
        return names[-1] if names else "baseline"
EXERCISE
Set up an A/B test comparing baseline retrieval with reranking. Run for at least 100 queries per variant, then check for statistical significance.