11. Two-Stage Retrieval

Chapter 11 of 24 · 15 min

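Two-stage retrieval separates coarse candidate selection from fine-grained ranking. It is built for document collections too large to score exhaustively with an expensive model: a cheap first stage narrows the collection to a short candidate list, and only that list is scored by the costly model. With first_stage_k=100 over a one-million-document collection, for example, the expensive model scores 100 query-document pairs per query instead of 1,000,000.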

Stage 1: Retrieval. A fast method produces the candidate set. Dense retrieval over an approximate HNSW index, BM25 keyword search, or a hybrid fusion of the two returns the top N candidates (typically 50-200).
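Hybrid fusion needs a rule for merging the dense and BM25 rankings. One common choice is Reciprocal Rank Fusion (RRF), which uses only ranks and so sidesteps the fact that BM25 and cosine scores live on different scales. The sketch below assumes each retriever has already returned a ranked list of document IDs; the function name and the toy rankings are illustrative, not part of any library.

```python
from collections import defaultdict

def reciprocal_rank_fusion(ranked_lists, k=60, top_n=100):
    """Fuse several ranked lists of doc IDs with Reciprocal Rank Fusion.

    Each document scores sum(1 / (k + rank)) over the lists it appears in,
    with 1-based ranks. k=60 is the value from the original RRF paper and a
    common default.
    """
    scores = defaultdict(float)
    for ranking in ranked_lists:
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    # Highest fused score first; ties broken by doc ID for determinism.
    fused = sorted(scores.items(), key=lambda item: (-item[1], item[0]))
    return fused[:top_n]

# Hypothetical rankings from a dense retriever and a BM25 retriever.
dense_ranking = ["d3", "d1", "d7", "d2"]
bm25_ranking = ["d1", "d5", "d3", "d9"]
print(reciprocal_rank_fusion([dense_ranking, bm25_ranking], top_n=3))
```

Here d1 edges out d3 because it sits near the top of both lists, which is the behavior RRF is designed to reward.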

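Stage 2: Reranking. An expensive relevance model rescores only the stage-1 candidates. Cross-encoders, other learned ranking models, or interaction-based features (which look at the query and document together) refine the ordering. Because stage 2 only reorders what stage 1 returns, a relevant document missed in stage 1 cannot be recovered, so N trades first-stage recall against reranking cost.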
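The TwoStageRetriever code in this chapter expects the stage-2 ranker to expose rerank(query, candidates) and to attach a 'rerank_score' to each candidate dict. The toy ranker below follows that contract so the flow can be tested without downloading a model. It is a deliberate stand-in: it scores each (query, document) pair by term overlap, whereas a real system would call a cross-encoder at that step. The class name is hypothetical.

```python
import re

def _tokens(text):
    # Lowercased word tokens; a deliberately crude tokenizer.
    return set(re.findall(r"\w+", text.lower()))

class TermOverlapReranker:
    """Toy stage-2 ranker: scores each (query, document) pair jointly.

    Term overlap is a cheap stand-in for a cross-encoder score.
    """

    def rerank(self, query, candidates):
        query_terms = _tokens(query)
        reranked = []
        for cand in candidates:
            doc_terms = _tokens(cand.get("text", ""))
            # Fraction of query terms that appear in the document.
            overlap = len(query_terms & doc_terms) / max(len(query_terms), 1)
            # Copy the dict so the caller's stage-1 results are not mutated.
            reranked.append({**cand, "rerank_score": overlap})
        # Best first; Python's sort is stable, so ties keep their stage-1 order.
        reranked.sort(key=lambda c: c["rerank_score"], reverse=True)
        return reranked

candidates = [
    {"doc_id": "a", "text": "Rotating API keys", "score": 0.9},
    {"doc_id": "b", "text": "Configure OAuth2 authentication for the API", "score": 0.7},
]
for c in TermOverlapReranker().rerank("How do I configure OAuth2 authentication?", candidates):
    print(c["doc_id"], round(c["rerank_score"], 2))
```

Document b had the lower stage-1 score but moves to the top once the query and document are compared term by term, which is the kind of correction stage 2 exists to make.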

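```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Document:
    id: str
    content: str
    metadata: dict

@dataclass
class RetrievalResult:
    doc_id: str
    content: str
    first_stage_score: float
    second_stage_score: Optional[float] = None
    metadata: Optional[dict] = None

class TwoStageRetriever:
    """Fetch first_stage_k candidates cheaply, optionally rerank them, return final_k."""

    def __init__(
        self,
        first_stage_retriever,
        second_stage_ranker=None,
        first_stage_k: int = 100,
        final_k: int = 10
    ):
        self.first_stage = first_stage_retriever
        self.second_stage = second_stage_ranker
        self.first_stage_k = first_stage_k
        self.final_k = final_k

    def retrieve(self, query: str) -> List[RetrievalResult]:
        # Stage 1: fast candidate retrieval. Candidates are expected to be dicts with
        # 'doc_id', 'text' (or 'content'), 'score', and optionally 'metadata'.
        candidates = self.first_stage.search(query, top_k=self.first_stage_k)

        # Stage 2: precise reranking, if a ranker was supplied. The ranker is
        # expected to attach a 'rerank_score' to each candidate.
        if self.second_stage is not None:
            candidates = self.second_stage.rerank(query, candidates)
            # Sort defensively so the cut below keeps the best reranked documents
            # even if the ranker returns them unordered.
            candidates = sorted(
                candidates,
                key=lambda c: c.get('rerank_score', float('-inf')),
                reverse=True
            )

        # Keep the final_k best candidates and convert them to RetrievalResult objects.
        results = []
        for rank, candidate in enumerate(candidates[:self.final_k]):
            results.append(RetrievalResult(
                doc_id=candidate.get('doc_id', str(rank)),
                content=candidate.get('text', candidate.get('content', '')),
                first_stage_score=candidate.get('score', 0.0),
                second_stage_score=candidate.get('rerank_score'),
                metadata=candidate.get('metadata', {})
            ))

        return results

# Example usage. HybridRetriever, LocalCrossEncoder, dense_index, and bm25_index are
# not defined in this chapter; substitute implementations that provide
# search(query, top_k) and rerank(query, candidates) respectively.
retriever = TwoStageRetriever(
    first_stage_retriever=HybridRetriever(dense_index, bm25_index),
    second_stage_ranker=LocalCrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2'),
    first_stage_k=100,
    final_k=10
)

results = retriever.retrieve("How do I configure OAuth2 authentication?")
for result in results:
    print(result.doc_id, result.first_stage_score, result.second_stage_score)
```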
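TwoStageRetriever only assumes that its first stage exposes search(query, top_k) and returns dicts with 'doc_id', 'text', and 'score'. A minimal in-memory stand-in, useful for checking the wiring before real dense and BM25 indexes exist, might look like the sketch below. The class is hypothetical and scores documents by raw query-term counts, which is far cruder than BM25.

```python
import re

class InMemoryKeywordRetriever:
    """Hypothetical stage-1 stand-in: scores documents by query-term counts."""

    def __init__(self, docs):
        # docs: mapping of doc_id -> text
        self.docs = docs

    def search(self, query, top_k=100):
        terms = re.findall(r"\w+", query.lower())
        hits = []
        for doc_id, text in self.docs.items():
            words = re.findall(r"\w+", text.lower())
            score = float(sum(words.count(t) for t in terms))
            if score > 0:  # skip documents with no matching terms
                hits.append({"doc_id": doc_id, "text": text, "score": score, "metadata": {}})
        # Highest score first, then cut to the candidate budget.
        hits.sort(key=lambda h: h["score"], reverse=True)
        return hits[:top_k]

docs = {
    "d1": "OAuth2 setup: configure the OAuth2 client",
    "d2": "Billing FAQ",
    "d3": "Configure logging",
}
retriever = InMemoryKeywordRetriever(docs)
for hit in retriever.search("configure OAuth2", top_k=10):
    print(hit["doc_id"], hit["score"])
```

Passing an object like this as first_stage_retriever, with second_stage_ranker=None, exercises the result formatting in TwoStageRetriever on its own before any model is involved.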

Local verification checkpoint

Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
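Part of this record can be captured programmatically. The sketch below collects the Python version, platform, installed package versions, and data path using only the standard library; the package names and the data path in the usage line are assumptions to replace with whatever your run actually uses. Observed output and hardware constraints still have to be noted by hand.

```python
import json
import platform
import sys
from importlib import metadata

def environment_record(packages, data_path):
    """Collect reproducibility facts for a local verification run."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
        "packages": versions,
        "data_path": data_path,
    }

# Hypothetical package list and data path; adjust to your setup.
record = environment_record(["sentence-transformers", "rank-bm25"], "data/docs.jsonl")
print(json.dumps(record, indent=2))
```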

EXERCISE

Implement two-stage retrieval with logging at each stage. Measure what percentage of the final top-10 documents come from each first-stage score range.
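One possible starting point for the measurement is sketched below. It assumes you have logged stage 1 as (doc_id, first_stage_score) pairs and stage 2 as the final list of doc IDs; the function names, bucket edges, and toy data (four final results rather than ten, for brevity) are all hypothetical. Choose edges that suit your first-stage score scale, since cosine similarities, BM25 scores, and fused scores fall in very different ranges.

```python
import bisect
import logging
from collections import Counter

logging.basicConfig(level=logging.INFO, format="%(name)s: %(message)s")
log = logging.getLogger("two_stage")

def score_range(score, edges):
    """Label a score with the interval of the ascending `edges` list that contains it."""
    i = bisect.bisect_right(edges, score)  # edges[i-1] <= score < edges[i]
    if i == 0:
        return f"(-inf, {edges[0]:g})"
    if i == len(edges):
        return f"[{edges[-1]:g}, inf)"
    return f"[{edges[i - 1]:g}, {edges[i]:g})"

def final_share_by_score_range(first_stage, final_ids, edges):
    """Percentage of the final results that fall in each first-stage score range.

    first_stage: (doc_id, first_stage_score) pairs as returned by stage 1.
    final_ids:   doc IDs of the final top-k after reranking.
    edges:       ascending bucket boundaries on the stage-1 score scale.
    """
    if not final_ids:
        return {}
    stage1_score = dict(first_stage)
    log.info("stage 1: %d candidates", len(first_stage))
    log.info("stage 2: kept %s", final_ids)
    counts = Counter(score_range(stage1_score[d], edges) for d in final_ids)
    return {bucket: 100.0 * n / len(final_ids) for bucket, n in counts.items()}

first_stage = [("d1", 0.92), ("d2", 0.85), ("d3", 0.61), ("d4", 0.44), ("d5", 0.30)]
final_ids = ["d3", "d1", "d5", "d2"]  # hypothetical order after reranking
print(final_share_by_score_range(first_stage, final_ids, edges=[0.4, 0.6, 0.8]))
```

A large share of final results coming from low first-stage score ranges suggests the reranker is doing real work, and also hints that raising first_stage_k might surface more relevant documents.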