07. Multi-Modal RAG

Chapter 7 of 24 · 20 min

Retrieval-Augmented Generation (RAG) extends naturally to multi-modal data. A query in one modality retrieves relevant content from the same or a different modality, and a generator then produces a response grounded in the retrieved material.

Architecture Overview

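Multi-modal RAG has three stages: multi-modal indexing, retrieval, and generation. The index maps content from every modality into a shared embedding space, so a query in one modality can be compared directly with content in another. The code below covers indexing and retrieval; the generation step is not shown.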

import numpy as np
import torch
from langchain.retrievers import MultiModalRetriever
from langchain_community.vectorstores import Chroma
from transformers import AutoModel, AutoProcessor

class MultiModalRAG:
    """Index text, images, and videos in one CLIP embedding space and search across them.

    Each modality gets its own Chroma collection. All vectors come from the
    same CLIP-style model (``get_text_features`` / ``get_image_features``) and
    are L2-normalized, so a text query vector can be compared directly with
    image or video vectors. Chroma's default distance is L2, and on unit
    vectors L2 distance ranks results exactly like cosine similarity.
    """

    class _ChromaEmbeddings:
        """Expose batch embedding callables through LangChain's ``Embeddings`` methods.

        LangChain's ``Chroma`` calls ``embed_documents`` from ``add_texts``,
        ``embed_query`` for text similarity search, and ``embed_image`` from
        ``add_images``. A bare function passed as ``embedding_function`` has
        none of these methods. Each callable maps a list of inputs to a 2-D
        array-like with one embedding row per input. A method whose callable
        was not supplied raises ``NotImplementedError``.
        """

        def __init__(self, *, documents=None, query=None, images=None):
            self._documents = documents
            self._query = query
            self._images = images

        @staticmethod
        def _run(embed, inputs, what):
            if embed is None:
                raise NotImplementedError(f"this store cannot embed {what}")
            # The Embeddings interface returns plain lists of floats, not numpy arrays.
            return np.asarray(embed(list(inputs)), dtype=float).tolist()

        def embed_documents(self, texts):
            return self._run(self._documents, texts, "documents")

        def embed_query(self, text):
            return self._run(self._query, [text], "queries")[0]

        def embed_image(self, uris):
            return self._run(self._images, uris, "images")

    def __init__(self, embedder_name="sentence-transformers/clip-ViT-B-32"):
        # NOTE(review): AutoModel/AutoProcessor expect a plain transformers
        # checkpoint (e.g. "openai/clip-vit-base-patch32"); confirm that the
        # sentence-transformers default loads with your transformers version.
        self.embedder = AutoModel.from_pretrained(embedder_name)
        self.processor = AutoProcessor.from_pretrained(embedder_name)

        # Separate indexes for each modality. Every store embeds text queries
        # with the text encoder, which is what lets Chroma's own
        # similarity_search do text -> image and text -> video retrieval.
        self.text_store = Chroma(
            collection_name="text",
            embedding_function=self._ChromaEmbeddings(
                documents=self._text_embedder, query=self._text_embedder
            ),
        )
        self.image_store = Chroma(
            collection_name="images",
            embedding_function=self._ChromaEmbeddings(
                query=self._text_embedder, images=self._embed_image_paths
            ),
        )
        self.video_store = Chroma(
            collection_name="videos",
            embedding_function=self._ChromaEmbeddings(
                documents=self._embed_video_paths, query=self._text_embedder
            ),
        )

    @staticmethod
    def _l2_normalize(vectors):
        """Return *vectors* with each row scaled to unit L2 norm (all-zero rows stay zero)."""
        vectors = np.asarray(vectors, dtype=np.float32)
        norms = np.linalg.norm(vectors, axis=-1, keepdims=True)
        return vectors / np.maximum(norms, 1e-12)

    def _text_embedder(self, texts):
        """Embed a list of strings; returns an (n, d) array of unit vectors."""
        # CLIP's text encoder has a fixed context length: truncate long
        # documents instead of failing on them.
        inputs = self.processor(
            text=texts, return_tensors="pt", padding=True, truncation=True
        )
        with torch.no_grad():
            features = self.embedder.get_text_features(**inputs)
        return self._l2_normalize(features.numpy())

    def _image_embedder(self, images):
        """Embed a list of images the processor accepts (e.g. PIL images); returns (n, d) unit vectors."""
        inputs = self.processor(images=images, return_tensors="pt")
        with torch.no_grad():
            # get_image_features takes the preprocessed tensor as ``pixel_values``.
            features = self.embedder.get_image_features(
                pixel_values=inputs["pixel_values"]
            )
        return self._l2_normalize(features.numpy())

    def _video_embedder(self, video_frames, batch_size=32):
        """Embed one video as the normalized mean of its frame embeddings.

        Args:
            video_frames: Iterable of frames accepted by ``_image_embedder``.
            batch_size: Frames per forward pass; bounds memory on long videos.

        Returns:
            A (1, d) array holding the video-level unit vector.

        Raises:
            ValueError: If ``video_frames`` is empty.
        """
        frames = list(video_frames)
        if not frames:
            raise ValueError("cannot embed a video with no frames")
        frame_embeddings = np.concatenate(
            [
                self._image_embedder(frames[start:start + batch_size])
                for start in range(0, len(frames), batch_size)
            ],
            axis=0,
        )
        # The mean of unit vectors is shorter than unit length; renormalize so
        # videos are comparable with single images and text.
        return self._l2_normalize(frame_embeddings.mean(axis=0, keepdims=True))

    def _embed_image_paths(self, paths):
        """Load images from file paths (or URLs) and embed them; returns (n, d) unit vectors."""
        from transformers.image_utils import load_image

        return self._image_embedder([load_image(path) for path in paths])

    def _embed_video_paths(self, paths):
        """Embed each video path as one row via ``_load_frames`` and ``_video_embedder``."""
        rows = [self._video_embedder(self._load_frames(path)) for path in paths]
        if not rows:
            return np.empty((0, 0), dtype=np.float32)
        return np.concatenate(rows, axis=0)

    def _load_frames(self, path):
        """Return the frames of the video at *path* as a list of images.

        Decoding video needs a library (e.g. OpenCV or PyAV) that this example
        does not depend on, so override this method to supply one.
        """
        raise NotImplementedError(
            f"override _load_frames to decode video frames from {path!r}"
        )

    def index_content(self, text_docs, image_paths, video_paths):
        """Embed and store every supplied item in the store for its modality.

        Args:
            text_docs: Text strings, embedded with the text encoder.
            image_paths: Image file paths, embedded with the image encoder.
            video_paths: Video file paths. Each one is stored with the
                normalized mean of its frame embeddings, using frames from
                ``_load_frames``.

        Empty or ``None`` arguments are skipped.
        """
        if text_docs:
            self.text_store.add_texts(list(text_docs))

        if image_paths:
            # The image store's adapter loads and embeds each path itself.
            self.image_store.add_images(uris=list(image_paths))

        if video_paths:
            # The stored document is the path. The video store's adapter turns
            # each path into a video-level embedding.
            self.video_store.add_texts(list(video_paths))

Retrieval Across Modalities

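This is where multi-modal RAG shows its power: query with text and retrieve images, or query with an image and retrieve relevant videos. The functions below take `self` and are meant to be added as methods of `MultiModalRAG`.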

def retrieve_across_modality(
    self,
    query,
    query_modality,
    target_modality,
    top_k=5
):
    """
    Retrieve content from target_modality using query from query_modality.

    Args:
        query: A text string ("text"), a single image ("image"), or the
            frames of one video ("video").
        query_modality: Modality of ``query``: "text", "image", or "video".
        target_modality: Store to search: "text", "image", or "video".
        top_k: Maximum number of results to return.

    Returns:
        The result of ``similarity_search_by_vector`` on the target store
        (a list of LangChain ``Document`` objects for Chroma).

    Raises:
        ValueError: If either modality is not one of the supported names.

    Example: query="cat image", query_modality="text", target_modality="image"
    """
    # Each embedder returns a 2-D array; row 0 is the query's embedding.
    embedders = {
        "text": lambda q: self._text_embedder([q]),
        "image": lambda q: self._image_embedder([q]),
        "video": lambda q: self._video_embedder(q),
    }
    if query_modality not in embedders:
        raise ValueError(f"Unknown modality: {query_modality}")
    if target_modality not in embedders:
        raise ValueError(f"Unknown target modality: {target_modality}")

    # Embed query in its native modality; all encoders share one space.
    query_embedding = embedders[query_modality](query)

    # Retrieve from target modality store
    target_store = getattr(self, f"{target_modality}_store")

    # similarity_search_by_vector takes a plain list of floats, not a numpy row.
    return target_store.similarity_search_by_vector(
        embedding=[float(x) for x in query_embedding[0]],
        k=top_k,
    )

# Example usage
def rag_image_query(self, text_query, top_k=5):
    """Return the images whose embeddings best match a text description.

    This is a thin wrapper over ``retrieve_across_modality`` that uses a text
    query and the image store as the target.

    Args:
        text_query: Natural-language description of the wanted images.
        top_k: Maximum number of results to return. The default of 5 matches
            ``retrieve_across_modality``.
    """
    return self.retrieve_across_modality(
        query=text_query,
        query_modality="text",
        target_modality="image",
        top_k=top_k,
    )

Failure Mode: Embedding Space Mismatch

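CLIP-style models project images and text into the same embedding space, but they are trained with a contrastive objective that matches images to their captions, not to capture semantic equivalence. Similarity in that space can therefore key on superficial cues: a query for "happy person" may retrieve images with bright colors rather than smiling faces.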

# Demonstrate embedding mismatch
def analyze_retrieval_bottleneck():
    """Print which of two candidate images is scored closer to "person running".

    Uses ``model``, ``load_image`` and ``cosine_sim``, which this snippet
    expects to be provided elsewhere.
    """
    text_vec = model.text_embed(["person running"])

    # Candidate A shows a person, but the running is not obvious.
    # Candidate B has motion blur that suggests running.
    mountain_img = load_image("runner_mountains.jpg")
    treadmill_img = load_image("treadmill_blurry.jpg")

    mountain_vec = model.image_embed(mountain_img)
    treadmill_vec = model.image_embed(treadmill_img)

    mountain_score = cosine_sim(text_vec, mountain_vec)
    treadmill_score = cosine_sim(text_vec, treadmill_vec)

    # CLIP may favour candidate A (semantic content), whereas a model trained
    # on action recognition may favour candidate B (motion cues).
    if mountain_score > treadmill_score:
        winner = "mountain runner"
    else:
        winner = "treadmill"
    print(f"CLIP prefers: {winner}")
EXERCISE

Build a simple multi-modal RAG system using CLIP for embeddings and ChromaDB for storage. Index a folder of images. Compare a text query such as "find images of cats" with an image query such as "find images similar to photo.jpg", and document how the retrieval behavior differs.