09. Cross-Modal Retrieval
Given a query in one modality, retrieve relevant content in another. This chapter covers retrieval architectures, evaluation metrics, and practical implementation details.
Retrieval Pipeline
Cross-modal retrieval consists of three steps: encoding the query, computing its similarity against the pre-computed candidate embeddings, and returning the top-k results.
class CrossModalRetriever:
    """Retrieve candidates of one modality given a query of another.

    All candidates are embedded once at construction time; each query is then
    embedded on demand and ranked by cosine similarity (dot product of
    L2-normalized vectors).

    Args:
        embedder: Object exposing ``encode_text``, ``encode_images`` and
            ``encode_videos_from_paths``. Each takes a list of items and must
            return a tensor that ``torch.cat`` / ``F.normalize`` accept.
        candidates: Non-empty sequence of items to search over. String
            candidates ending in an image or video extension are treated as
            file paths; anything else is encoded as text.
        candidate_metadata: Optional sequence aligned with ``candidates``.
            Defaults to ``None`` for every candidate.

    Raises:
        ValueError: If ``candidates`` is empty.
    """

    # File extensions (lower-case) used to guess the modality of a batch.
    _IMAGE_EXTENSIONS = ('.jpg', '.png')
    _VIDEO_EXTENSIONS = ('.mp4',)
    _BATCH_SIZE = 32

    def __init__(self, embedder, candidates, candidate_metadata=None):
        self.embedder = embedder
        self.candidates = candidates
        self.metadata = candidate_metadata or [None] * len(candidates)
        # Pre-compute candidate embeddings
        self.candidate_embeddings = self._compute_candidate_embeddings()

    def _encode_batch(self, batch):
        """Encode one batch, choosing the encoder from the first item.

        The modality is decided from ``batch[0]`` only, so candidates are
        expected to be homogeneous within a batch. The extension check is
        case-insensitive so that e.g. ``photo.JPG`` is not encoded as text.
        """
        first = batch[0]
        if isinstance(first, str):
            lowered = first.lower()
            if lowered.endswith(self._IMAGE_EXTENSIONS):
                return self.embedder.encode_images(batch)
            if lowered.endswith(self._VIDEO_EXTENSIONS):
                return self.embedder.encode_videos_from_paths(batch)
        return self.embedder.encode_text(batch)

    def _compute_candidate_embeddings(self):
        """Encode all candidates upfront for fast retrieval.

        Returns:
            Tensor of L2-normalized embeddings, one row per candidate.

        Raises:
            ValueError: If there are no candidates to encode.
        """
        total = len(self.candidates)
        if total == 0:
            # torch.cat([]) would otherwise fail with an opaque error.
            raise ValueError("CrossModalRetriever requires at least one candidate")

        # Retrieval is inference-only: avoid keeping autograd graphs alive
        # for every cached candidate embedding.
        with torch.no_grad():
            embeddings = [
                self._encode_batch(self.candidates[start:start + self._BATCH_SIZE])
                for start in range(0, total, self._BATCH_SIZE)
            ]
            return F.normalize(torch.cat(embeddings), p=2, dim=-1)

    def retrieve(self, query, modality="text", top_k=10):
        """Retrieve top-k candidates given a query.

        Args:
            query: A single query item (text string, image path or video path).
            modality: One of ``"text"``, ``"image"`` or ``"video"``.
            top_k: Maximum number of results. Clamped to the number of
                candidates so that small candidate pools do not raise.

        Returns:
            List of dicts ordered by decreasing similarity, each with keys
            ``"score"`` (float), ``"index"`` (position in ``candidates``),
            ``"candidate"`` and ``"metadata"``.

        Raises:
            ValueError: If ``modality`` is not supported.
        """
        with torch.no_grad():
            # Encode query
            if modality == "text":
                query_emb = self.embedder.encode_text([query])
            elif modality == "image":
                query_emb = self.embedder.encode_images([query])
            elif modality == "video":
                query_emb = self.embedder.encode_videos_from_paths([query])
            else:
                raise ValueError(
                    f"Unsupported modality {modality!r}; "
                    "expected 'text', 'image' or 'video'"
                )
            query_emb = F.normalize(query_emb, p=2, dim=-1)

            # Compute similarities
            similarities = torch.matmul(query_emb, self.candidate_embeddings.T)

            # Get top-k; topk raises if k exceeds the number of candidates.
            k = min(top_k, self.candidate_embeddings.shape[0])
            top_scores, top_indices = similarities[0].topk(k)

        # Convert to plain Python numbers so list indexing does not depend on
        # tensor-to-index coercion.
        return [
            {
                "score": score,
                "index": idx,
                "candidate": self.candidates[idx],
                "metadata": self.metadata[idx],
            }
            for score, idx in zip(top_scores.tolist(), top_indices.tolist())
        ]
Evaluation Metrics
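Standard retrieval metrics measure recall and ranking quality. The function below computes Recall@K (the fraction of relevant candidates found in the top K results) and Mean Reciprocal Rank (MRR, the average of 1/rank of the first relevant result).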
def evaluate_retrieval(retriever, queries, ground_truth, modality="text"):
    """Compute Recall@{1,5,10} and Mean Reciprocal Rank for a retriever.

    Args:
        retriever: Object with a ``retrieve(query, modality=..., top_k=...)``
            method returning result dicts, and a ``candidates`` sequence.
        queries: list of query items
        ground_truth: dict mapping query index to an iterable of relevant
            candidate indices (positions in ``retriever.candidates``).
        modality: Modality passed through to ``retriever.retrieve``.

    Returns:
        Dict with keys ``"Recall@1"``, ``"Recall@5"``, ``"Recall@10"`` and
        ``"MRR"``. Queries with no ground truth contribute 0 to every metric.
        All metrics are 0.0 when ``queries`` is empty.
    """
    ks = (1, 5, 10)
    recall_at_k = {k: [] for k in ks}
    mrr = []  # Mean Reciprocal Rank

    # Lazily built candidate -> index lookup, used only when a result dict
    # does not carry its own "index". Requires hashable candidates; for
    # duplicated candidates the first position wins.
    candidate_to_index = None

    for i, query in enumerate(queries):
        results = retriever.retrieve(query, modality=modality, top_k=max(ks))

        # Ground truth is expressed in candidate indices, so results must be
        # mapped to indices before comparison (comparing the candidate
        # objects themselves would never match).
        retrieved_indices = []
        for result in results:
            if "index" in result:
                retrieved_indices.append(result["index"])
                continue
            if candidate_to_index is None:
                candidate_to_index = {}
                for position, candidate in enumerate(retriever.candidates):
                    candidate_to_index.setdefault(candidate, position)
            retrieved_indices.append(candidate_to_index.get(result["candidate"]))

        # Accept any iterable (list, tuple, set) of relevant indices.
        relevant = set(ground_truth.get(i, ()))

        # Recall@K -- computed even when fewer than K results came back, so
        # small candidate pools do not leave a metric without samples.
        for k in ks:
            hits = len(set(retrieved_indices[:k]) & relevant)
            recall_at_k[k].append(hits / len(relevant) if relevant else 0)

        # MRR: reciprocal rank of the first relevant result, 0 if none.
        reciprocal_rank = next(
            (
                1 / rank
                for rank, idx in enumerate(retrieved_indices, 1)
                if idx in relevant
            ),
            0,
        )
        mrr.append(reciprocal_rank)

    # np.mean([]) yields nan with a RuntimeWarning; report 0.0 instead.
    metrics = {
        f"Recall@{k}": np.mean(values) if values else 0.0
        for k, values in recall_at_k.items()
    }
    metrics["MRR"] = np.mean(mrr) if mrr else 0.0
    return metrics
Failure Mode: Dataset Bias
Retrieval models exploit superficial correlations. Models trained on LAION-5B learn that photos with certain colors or compositions match common caption patterns, regardless of actual semantic content.
def diagnose_retrieval_bias(retriever, test_pairs):
    """Diagnose whether model relies on spurious correlations.

    For every (image path, caption) pair, the caption is used as a text query
    and the mean RGB colour of the top-1 retrieved image is compared with the
    mean RGB colour of the paired image.

    Args:
        retriever: Object with a ``retrieve(query, modality=..., top_k=...)``
            method whose results carry an image path under ``"candidate"``.
        test_pairs: Iterable of ``(img_path, caption)`` tuples.

    Returns:
        The ``bias_indicators`` dict. Only ``"color_correlation"`` is
        populated (one similarity in [0, 1] per pair that produced a
        retrieval); ``"object_size"`` and ``"aspect_ratio"`` are placeholders
        that stay empty.
    """
    bias_indicators = {
        "color_correlation": [],
        "object_size": [],
        "aspect_ratio": []
    }

    for img_path, caption in test_pairs:
        # Retrieve using text
        text_results = retriever.retrieve(caption, modality="text", top_k=5)
        retrieved_paths = [r["candidate"] for r in text_results]
        if not retrieved_paths:
            continue

        # Mean colour of the query image, computed once per pair. Converting
        # to RGB guarantees three channels regardless of the file's mode
        # (grayscale, palette, RGBA), and the context manager closes the file.
        with Image.open(img_path) as original_img:
            orig_colors = np.asarray(original_img.convert("RGB")).mean(axis=(0, 1))

        # Check if retrieved images share superficial properties with query image
        for ret_path in retrieved_paths[:1]:
            with Image.open(ret_path) as ret_img:
                ret_colors = np.asarray(ret_img.convert("RGB")).mean(axis=(0, 1))

            # Color similarity (should not matter semantically): 1 minus the
            # mean absolute per-channel difference, scaled to [0, 1].
            color_sim = 1 - np.abs(orig_colors - ret_colors).sum() / 255 / 3
            bias_indicators["color_correlation"].append(color_sim)

    if bias_indicators["color_correlation"]:
        avg_color_bias = np.mean(bias_indicators["color_correlation"])
        print(f"Average color similarity in top-1 retrieval: {avg_color_bias:.3f}")
        print("(Higher values indicate model may rely on color matching)")
    else:
        # Avoid np.mean([]) -> nan plus a RuntimeWarning.
        print("No retrievals to analyze; color similarity not computed.")

    return bias_indicators
Exercise: Implement image-to-text retrieval on a small dataset (100 images with captions). Compute Recall@5 and MRR. Then test robustness by corrupting the captions (for example, removing adjectives or swapping objects) and report how retrieval performance degrades.