08. Embedding Across Modalities
Joint embedding spaces enable cross-modal operations. This chapter covers the architecture and training of models that project different modalities into comparable representations.
Contrastive Learning for Alignment
CLIP and similar models learn joint spaces through contrastive learning. The training signal: matching image-text pairs should have high similarity; mismatched pairs should have low similarity.
import math

import torch
import torch.nn.functional as F
from torch import nn
class ContrastiveLoss(nn.Module):
    """Symmetric InfoNCE objective over a batch of paired embeddings."""

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def _scaled_similarity(self, image_embeds, text_embeds):
        """Return the image-by-text dot-product matrix divided by the temperature."""
        return torch.matmul(image_embeds, text_embeds.T) / self.temperature

    def forward(self, image_embeds, text_embeds):
        """Average the image->text and text->image cross-entropy losses.

        image_embeds: (B, D) image embeddings, expected to be normalized
            by the caller (no normalization happens here).
        text_embeds: (B, D) text embeddings, same expectation.

        Row i of each input is treated as a matching pair, so the target
        class for row i of the similarity matrix is column i.
        """
        similarity = self._scaled_similarity(image_embeds, text_embeds)
        targets = torch.arange(len(similarity), device=similarity.device)

        # Rows index images (image->text); the transpose indexes texts.
        image_to_text = F.cross_entropy(similarity, targets)
        text_to_image = F.cross_entropy(similarity.T, targets)
        return (image_to_text + text_to_image) / 2

    def compute_accuracy(self, image_embeds, text_embeds):
        """Return the fraction of images whose best-scoring text is their own pair."""
        similarity = self._scaled_similarity(image_embeds, text_embeds)
        top1 = similarity.argmax(dim=1)
        expected = torch.arange(len(top1), device=top1.device)
        return (top1 == expected).float().mean()
Projection Heads
Raw encoder outputs are not directly comparable. Projection heads transform them into a joint space where alignment is meaningful.
class ProjectionHead(nn.Module):
    """Project modality-specific features to the joint embedding space.

    The projection is a three-layer MLP (Linear -> GELU -> Linear -> GELU ->
    Linear) mapping ``input_dim`` features to ``output_dim`` features through
    two ``hidden_dim``-wide layers.
    """

    def __init__(self, input_dim: int, output_dim: int = 512, hidden_dim: int = 2048):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim),
        )
        # Learnable scalar initialized to log(1 / 0.07). It is not used by
        # forward() or normalize(); presumably the caller exponentiates it to
        # scale similarity logits (as some architectures do) -- confirm at the
        # call site.
        # math.log replaces the previous np.log: numpy was never imported, so
        # constructing this module raised NameError.
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``features``; the output is NOT normalized."""
        return self.projection(features)

    def normalize(self, projected: torch.Tensor) -> torch.Tensor:
        """Return ``projected`` scaled to unit L2 norm along the last dimension."""
        return F.normalize(projected, p=2, dim=-1)
Handling Modality Imbalance
Different modalities carry different amounts of information and have different entropy. Text is discrete and information-dense; images and audio are continuous and highly redundant. This imbalance affects the structure of the joint embedding space.
def analyze_modality_distribution(image_embeds, text_embeds):
    """Summarize and print how the two embedding sets are distributed.

    For each modality the returned dict holds the mean and standard deviation
    of the per-row L2 norms plus the mean pairwise cosine similarity. The
    pairwise mean runs over every (i, j) pair, self-pairs (i == j) included.

    The row-wise cross-modal cosine similarity is printed for comparison but
    is not part of the returned dict.
    """

    def _summarize(embeds):
        # Broadcasting (B, 1, D) against (B, D) yields the full B x B matrix.
        norms = embeds.norm(dim=-1)
        pairwise = F.cosine_similarity(embeds[:, None], embeds, dim=-1)
        return {
            "mean_norm": norms.mean().item(),
            "std_norm": norms.std().item(),
            "cos_sim_mean": pairwise.mean().item(),
        }

    stats = {
        "image": _summarize(image_embeds),
        "text": _summarize(text_embeds),
    }

    # Similarity between row i of the images and row i of the texts, averaged.
    cross_sim = F.cosine_similarity(image_embeds, text_embeds).mean().item()

    print("Within-modality similarity:")
    for label, key in (("Image", "image"), ("Text", "text")):
        print(f" {label}: {stats[key]['cos_sim_mean']:.3f}")
    print(f"Cross-modal similarity: {cross_sim:.3f}")
    return stats
Train a simple CLIP-style model on 1000 image-caption pairs using a pre-trained image encoder and a randomly initialized text encoder. Plot the training loss curve. After training, evaluate zero-shot retrieval accuracy on a held-out test set and report your findings.