07. Multi-Modal RAG
Retrieval-Augmented Generation extends to multi-modal contexts. A query in one modality retrieves relevant content from another, then a generator produces a response grounded in retrieved material.
Architecture Overview
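Multi-modal RAG consists of three stages: multi-modal indexing, cross-modal retrieval, and generation. The index maps each piece of content (text, image, or video) to an embedding in a shared vector space, so that a query in one modality can be matched against content in another.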
import numpy as np
import torch
from langchain.retrievers import MultiModalRetriever
from langchain_community.vectorstores import Chroma
from transformers import AutoModel, AutoProcessor
class _StoreEmbeddings:
    """Adapter exposing the embedding methods a LangChain ``Chroma`` store calls.

    ``Chroma.add_texts`` calls ``embed_documents``, ``Chroma.add_images`` calls
    ``embed_image`` and text searches call ``embed_query`` on the store's
    ``embedding_function``. A bare callable provides none of these, so each
    store is given one of these wrappers instead.
    """

    def __init__(self, embed_items, embed_text):
        # embed_items: list of stored items -> array with one row per item.
        # embed_text: list of query strings -> array with one row per string.
        self._embed_items = embed_items
        self._embed_text = embed_text

    def embed_documents(self, texts):
        """Embed the items being added through ``add_texts``."""
        return self._embed_items(list(texts)).tolist()

    def embed_image(self, uris):
        """Embed the image URIs being added through ``add_images``."""
        return self._embed_items(list(uris)).tolist()

    def embed_query(self, text):
        """Embed a text query so that every store can be searched by text."""
        return self._embed_text([text])[0].tolist()


class MultiModalRAG:
    """Index text, images and videos in per-modality Chroma stores.

    All modalities are embedded with one CLIP-style model so that a query
    embedded from one modality can be compared with content from another.
    """

    def __init__(self, embedder_name="sentence-transformers/clip-ViT-B-32"):
        """Load the embedding model and create one store per modality.

        Args:
            embedder_name: Hugging Face model id passed to
                ``AutoModel`` / ``AutoProcessor``.
        """
        # NOTE(review): the default id looks like a sentence-transformers
        # packaging of CLIP; confirm that AutoModel/AutoProcessor can load it,
        # otherwise pass a transformers-format CLIP checkpoint explicitly.
        self.embedder = AutoModel.from_pretrained(embedder_name)
        self.processor = AutoProcessor.from_pretrained(embedder_name)

        # Separate indexes for each modality. Each store embeds what it is
        # given (texts, image URIs, video paths) through its own adapter.
        self.text_store = Chroma(
            collection_name="text",
            embedding_function=_StoreEmbeddings(
                self._text_embedder, self._text_embedder
            ),
        )
        self.image_store = Chroma(
            collection_name="images",
            embedding_function=_StoreEmbeddings(
                self._embed_image_uris, self._text_embedder
            ),
        )
        self.video_store = Chroma(
            collection_name="videos",
            embedding_function=_StoreEmbeddings(
                self._embed_video_paths, self._text_embedder
            ),
        )

    def _text_embedder(self, texts):
        """Return one embedding row per string in ``texts`` as a numpy array."""
        # truncation=True keeps over-long texts within the tokenizer's limit
        # instead of failing inside the text encoder.
        inputs = self.processor(
            text=texts, return_tensors="pt", padding=True, truncation=True
        )
        with torch.no_grad():
            features = self.embedder.get_text_features(**inputs)
        return features.cpu().numpy()

    def _image_embedder(self, images):
        """Return one embedding row per image in ``images`` as a numpy array.

        ``images`` holds already-loaded images in a form the processor accepts.
        """
        # Only pixel data is needed; no dummy text is sent to the processor.
        inputs = self.processor(images=images, return_tensors="pt")
        with torch.no_grad():
            # The image tower takes its input under the name `pixel_values`.
            features = self.embedder.get_image_features(
                pixel_values=inputs["pixel_values"]
            )
        return features.cpu().numpy()

    def _video_embedder(self, video_frames):
        """Average frame embeddings for a video-level representation.

        Returns:
            Array with a single row, matching the row layout of the other
            embedders.

        Raises:
            ValueError: If ``video_frames`` contains no frames.
        """
        frames = list(video_frames)
        if not frames:
            # Averaging nothing would silently produce NaN values.
            raise ValueError("video_frames is empty; cannot embed a video")
        # Frames are embedded one at a time to keep peak memory bounded.
        frame_embeddings = np.concatenate(
            [self._image_embedder([frame]) for frame in frames], axis=0
        )
        return frame_embeddings.mean(axis=0, keepdims=True)

    def _embed_image_uris(self, uris):
        """Load each image URI/path and return one embedding row per image."""
        # Local import: load_image ships with transformers, which this file
        # already depends on.
        from transformers.image_utils import load_image

        return self._image_embedder([load_image(str(uri)) for uri in uris])

    def _embed_video_paths(self, paths):
        """Return one averaged-frame embedding row per video path."""
        # NOTE(review): `_load_frames` is not defined in the visible part of
        # this class; it must be provided elsewhere and return the frames of
        # the video at `path` in a form `_image_embedder` accepts.
        return np.concatenate(
            [self._video_embedder(self._load_frames(path)) for path in paths],
            axis=0,
        )

    def index_content(self, text_docs=None, image_paths=None, video_paths=None):
        """Index all modalities.

        Args:
            text_docs: Strings to add to the text store.
            image_paths: Image file paths/URIs to add to the image store.
            video_paths: Video file paths to add to the video store.

        Empty or ``None`` arguments are skipped.
        """
        # Index text
        if text_docs:
            self.text_store.add_texts(list(text_docs))

        # Index images; the store's adapter loads and embeds each path.
        if image_paths:
            self.image_store.add_images([str(path) for path in image_paths])

        # Index videos; the path is the stored document and the adapter
        # embeds the video's frames, not the path string.
        if video_paths:
            self.video_store.add_texts([str(path) for path in video_paths])
Retrieval Across Modalities
The power of multi-modal RAG: query with text, retrieve images. Or query with an image, retrieve relevant video segments.
def retrieve_across_modality(
    self,
    query,
    query_modality,
    target_modality,
    top_k=5,
):
    """Retrieve content from ``target_modality`` using a query from ``query_modality``.

    Example: ``query="cat image"``, ``query_modality="text"``,
    ``target_modality="image"``.

    Args:
        query: The query itself (a string for ``"text"``, an image for
            ``"image"``).
        query_modality: ``"text"`` or ``"image"``; selects the embedder.
        target_modality: Prefix of the ``<name>_store`` attribute to search.
        top_k: Number of results requested from the store.

    Returns:
        Whatever the target store's ``similarity_search_by_vector`` returns.

    Raises:
        ValueError: If ``query_modality`` is unsupported or no store exists
            for ``target_modality``.
    """
    # Embed query in its native modality
    if query_modality == "text":
        query_embedding = self._text_embedder([query])
    elif query_modality == "image":
        query_embedding = self._image_embedder([query])
    else:
        raise ValueError(f"Unknown modality: {query_modality}")

    # Retrieve from target modality store. A missing store is reported the
    # same way as an unknown query modality instead of as an AttributeError.
    target_store = getattr(self, f"{target_modality}_store", None)
    if target_store is None:
        raise ValueError(f"Unknown target modality: {target_modality}")

    # The store API takes a plain list of floats, not an array row.
    query_vector = [float(value) for value in query_embedding[0]]
    return target_store.similarity_search_by_vector(
        embedding=query_vector,
        k=top_k,
    )
# Example usage
def rag_image_query(self, text_query):
    """Look up images with a free-text description."""
    # Text in, images out: delegate to the generic cross-modal lookup.
    lookup = {
        "query": text_query,
        "query_modality": "text",
        "target_modality": "image",
    }
    return self.retrieve_across_modality(**lookup)
Failure Mode: Embedding Space Mismatch
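CLIP-style models project images and text into the same embedding space, but they are trained with a contrastive image–caption matching objective, not for fine-grained semantic equivalence. As a result, similarity scores can latch onto superficial correlates of a concept: a query for "happy person" may retrieve images with bright colors rather than smiling faces.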
# Demonstrate embedding mismatch
def analyze_retrieval_bottleneck():
    """Compare how one text query scores against two candidate images.

    Relies on ``model``, ``load_image`` and ``cosine_sim`` being available in
    the enclosing scope; prints which of the two images scores higher.
    """
    # Text side of the comparison: "person running"
    text_vector = model.text_embed(["person running"])

    # Candidate images:
    #   runner_mountains.jpg - has a person, running is not obvious
    #   treadmill_blurry.jpg - motion blur suggests running
    mountain_image, treadmill_image = [
        load_image(filename)
        for filename in ("runner_mountains.jpg", "treadmill_blurry.jpg")
    ]

    # Image side of the comparison
    mountain_vector = model.image_embed(mountain_image)
    treadmill_vector = model.image_embed(treadmill_image)

    # Query-to-image similarities
    mountain_score = cosine_sim(text_vector, mountain_vector)
    treadmill_score = cosine_sim(text_vector, treadmill_vector)

    # CLIP may prefer the mountain image (semantic content), while a model
    # trained on action recognition may prefer the treadmill one (motion cues).
    if mountain_score > treadmill_score:
        preferred = "mountain runner"
    else:
        preferred = "treadmill"
    print(f"CLIP prefers: {preferred}")
Build a simple multi-modal RAG system using CLIP for embeddings and ChromaDB for storage. Index a folder of images. Test queries like "find images of cats" versus "find images similar to photo.jpg". Document differences in retrieval behavior.