18. Advanced NLP Pipeline Project

Chapter 18 of 18 · 25 min

This chapter integrates the course concepts into an end-to-end question answering system. The pipeline combines dense retrieval, optional cross-encoder reranking, and answer generation, and is followed by an evaluation harness and a deployment configuration.

Project: Multi-Hop Question Answering System

from typing import Optional
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
import torch
import torch.nn.functional as F
from tqdm import tqdm

@dataclass
class RetrievedChunk:
    """One passage selected from the corpus for a given query."""

    # Raw passage text that is later placed into the generator prompt.
    text: str
    # Relevance score assigned when the chunk was retrieved.
    score: float
    # Label for where the passage came from.
    source: str
    # Free-form extra information carried along with the passage.
    metadata: dict

@dataclass
class QAResult:
    """Final output of the QA pipeline for a single question."""

    # The question exactly as it was asked.
    question: str
    # Generated answer text.
    answer: str
    # Heuristic confidence value attached to the answer.
    confidence: float
    # Chunks that were supplied to the generator as evidence.
    supporting_chunks: list[RetrievedChunk]
    # Sentence-level steps extracted from the generated answer.
    reasoning_chain: list[str]

class MultiHopQASystem:
    """Retrieve -> (optionally) rerank -> generate question answering pipeline.

    Configuration keys read from ``config``:

    * ``embedding_model`` (required): model name/path for the retrieval encoder.
    * ``generator_model`` (required): model name/path for the causal LM.
    * ``device``: torch device; defaults to CUDA when available, else CPU.
    * ``reranker_model``: optional cross-encoder name/path; no reranking if unset.
    * ``retrieval_top_k`` (default 20): candidates fetched by ``retrieve``.
    * ``final_top_k`` (default 5): chunks handed to the generator.
    * ``embedding_pooling`` (default ``"cls"``): ``"cls"`` or ``"mean"``.
      NOTE(review): models trained with mean pooling (many
      sentence-transformers checkpoints) should set ``"mean"`` -- confirm
      against the model card of the encoder you use.
    * ``embedding_batch_size`` (default 32): texts encoded per forward pass.
    * ``max_input_tokens`` (default 2048): prompt token budget.
    """

    def __init__(self, config: dict):
        self.config = config

        # Embedding model for retrieval
        self.embedding_tokenizer = AutoTokenizer.from_pretrained(
            config["embedding_model"]
        )
        self.embedding_model = AutoModel.from_pretrained(
            config["embedding_model"]
        )

        # Reader/generator model
        self.generator_tokenizer = AutoTokenizer.from_pretrained(
            config["generator_model"]
        )
        self.generator_model = AutoModelForCausalLM.from_pretrained(
            config["generator_model"]
        )

        self.device = config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        self.embedding_model.to(self.device)
        self.generator_model.to(self.device)

        # Optional reranker (created after self.device is known so that it
        # can be placed on the same device as the other models).
        if config.get("reranker_model"):
            self.reranker = self._init_reranker(config["reranker_model"])
        else:
            self.reranker = None

    def _init_reranker(self, model_name: str):
        """Load the cross-encoder reranker on the configured device."""
        # Imported lazily so sentence-transformers is only required when a
        # reranker is actually configured.
        from sentence_transformers import CrossEncoder
        return CrossEncoder(model_name, device=self.device)

    def _embed(self, texts: list[str], batch_size: Optional[int] = None) -> torch.Tensor:
        """Return L2-normalised embeddings, one row per input text.

        Texts are encoded in batches so that a large corpus does not have to
        fit into a single forward pass.

        Args:
            texts: Non-empty list of strings to encode.
            batch_size: Texts per forward pass; falls back to the
                ``embedding_batch_size`` config key, then to 32.

        Raises:
            ValueError: If ``texts`` is empty or the configured pooling mode
                is not ``"cls"`` or ``"mean"``.
        """
        if not texts:
            raise ValueError("texts must contain at least one string")

        pooling = self.config.get("embedding_pooling", "cls")
        if pooling not in ("cls", "mean"):
            raise ValueError(f"Unsupported embedding_pooling: {pooling!r}")

        if batch_size is None:
            batch_size = self.config.get("embedding_batch_size", 32)
        batch_size = max(1, int(batch_size))

        pooled_batches = []
        for start in range(0, len(texts), batch_size):
            inputs = self.embedding_tokenizer(
                texts[start:start + batch_size],
                padding=True, truncation=True,
                max_length=512, return_tensors="pt"
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                hidden = self.embedding_model(**inputs).last_hidden_state

            if pooling == "mean":
                # Average only over real tokens; padding positions are masked out.
                mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
                pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
            else:
                pooled = hidden[:, 0, :]  # first-token (CLS) representation

            # Unit-length rows make the dot product in retrieve() a cosine similarity.
            pooled_batches.append(F.normalize(pooled, p=2, dim=-1))

        return torch.cat(pooled_batches, dim=0)

    def retrieve(self, query: str, corpus: list[dict], top_k: int = 10) -> list[RetrievedChunk]:
        """Return the ``top_k`` corpus entries most similar to ``query``.

        Each corpus entry must have a ``"text"`` key; ``"source"`` and
        ``"metadata"`` are optional. Results are ordered by descending score.
        At most ``top_k`` chunks are returned (fewer if the corpus is
        smaller); an empty corpus or non-positive ``top_k`` yields ``[]``.
        """
        if not corpus or top_k <= 0:
            return []

        query_embedding = self._embed([query])
        doc_embeddings = self._embed([doc["text"] for doc in corpus])

        # Row 0 of the (1, n_docs) similarity matrix is the per-document score.
        scores = torch.mm(query_embedding, doc_embeddings.T)[0]
        top_scores, top_indices = scores.topk(min(top_k, len(corpus)))

        # Convert to plain Python numbers once instead of indexing tensors per item.
        return [
            RetrievedChunk(
                text=corpus[idx]["text"],
                score=score,
                source=corpus[idx].get("source", "unknown"),
                metadata=corpus[idx].get("metadata", {}),
            )
            for score, idx in zip(top_scores.tolist(), top_indices.tolist())
        ]

    def rerank(self, query: str, chunks: list[RetrievedChunk], top_k: int = 5) -> list[RetrievedChunk]:
        """Reorder ``chunks`` with the cross-encoder and keep the best ``top_k``.

        Without a configured reranker (or with no chunks) the input order is
        kept and the list is simply cut to ``top_k``.
        """
        if self.reranker is None or not chunks:
            return chunks[:top_k]

        pairs = [(query, chunk.text) for chunk in chunks]
        scores = self.reranker.predict(pairs)

        ranked = sorted(zip(chunks, scores), key=lambda pair: pair[1], reverse=True)
        return [chunk for chunk, _ in ranked[:top_k]]

    def _build_prompt(self, question: str, context_chunks: list[RetrievedChunk]) -> str:
        """Assemble the generator prompt from numbered context chunks."""
        context = "\n\n".join(
            f"[Source {i+1}] {chunk.text}"
            for i, chunk in enumerate(context_chunks)
        )
        return (
            "Based on the following context, answer the question. \n"
            'If the answer cannot be determined from the context, say '
            '"I cannot answer this question based on the provided information."\n'
            "\n"
            "Context:\n"
            f"{context}\n"
            "\n"
            f"Question: {question}\n"
            "\n"
            "Answer:"
        )

    def generate_answer(self, question: str, context_chunks: list[RetrievedChunk]) -> QAResult:
        """Generate an answer to ``question`` grounded in ``context_chunks``.

        If the prompt exceeds the ``max_input_tokens`` budget, trailing
        (lowest-ranked) chunks are dropped until it fits. Plain right-side
        truncation would otherwise cut off the "Question ... Answer:" tail of
        the prompt. Tokenizer truncation remains as a last-resort safety net
        when even a single chunk is too long.
        """
        max_input_tokens = self.config.get("max_input_tokens", 2048)

        used_chunks = list(context_chunks)
        prompt = self._build_prompt(question, used_chunks)
        while (
            len(used_chunks) > 1
            and len(self.generator_tokenizer(prompt)["input_ids"]) > max_input_tokens
        ):
            used_chunks.pop()
            prompt = self._build_prompt(question, used_chunks)

        inputs = self.generator_tokenizer(
            prompt, return_tensors="pt", max_length=max_input_tokens, truncation=True
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.generator_model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.3,
                do_sample=True,
                top_p=0.9,
                repetition_penalty=1.1,
                # Needed so per-step scores are available for the confidence
                # estimate; by default generate() returns only token ids.
                return_dict_in_generate=True,
                output_scores=True,
            )

        # The returned sequence starts with the prompt; keep only new tokens.
        prompt_length = inputs["input_ids"].shape[1]
        answer = self.generator_tokenizer.decode(
            outputs.sequences[0][prompt_length:],
            skip_special_tokens=True
        )

        confidence = self._estimate_confidence(outputs, inputs)

        return QAResult(
            question=question,
            answer=answer.strip(),
            confidence=confidence,
            supporting_chunks=used_chunks[:3],
            reasoning_chain=self._extract_reasoning(answer)
        )

    def _estimate_confidence(self, outputs, inputs) -> float:
        """Map the mean log-probability of the generated tokens to [0, 1].

        Args:
            outputs: Result of ``generate(..., return_dict_in_generate=True,
                output_scores=True)``; must expose ``sequences`` and ``scores``.
            inputs: Tokenized prompt; only ``inputs["input_ids"]`` is read, to
                find where the generated tokens start.

        Returns:
            ``(mean_log_prob + 2) / 2`` clamped to ``[0, 1]``, or ``0.0`` when
            no scores or no generated tokens are available.
        """
        scores = getattr(outputs, "scores", None)
        if not scores:
            return 0.0

        prompt_length = inputs["input_ids"].shape[1]
        generated = outputs.sequences[:, prompt_length:]

        # One score tensor per generation step -> (batch, steps, vocab).
        # NOTE(review): these are the processed scores (after temperature /
        # top-p), so this is confidence under the sampling distribution.
        step_scores = torch.stack(tuple(scores), dim=1).float()
        steps = min(generated.shape[1], step_scores.shape[1])
        if steps == 0:
            return 0.0

        log_probs = F.log_softmax(step_scores[:, :steps], dim=-1)
        token_log_probs = torch.gather(
            log_probs, 2, generated[:, :steps].unsqueeze(-1)
        ).squeeze(-1)

        avg_log_prob = token_log_probs.mean().item()
        # Linear heuristic: avg log-prob of -2 or lower -> 0, 0 -> 1.
        return min(1.0, max(0.0, (avg_log_prob + 2) / 2))

    def _extract_reasoning(self, answer: str) -> list[str]:
        """Split the answer on ". " and keep fragments longer than 10 characters."""
        # Simple extraction - could be enhanced with NLI or structured output
        return [s.strip() for s in answer.split(". ") if len(s) > 10]

    def answer(self, question: str, corpus: list[dict]) -> QAResult:
        """Run the complete pipeline: retrieve, rerank, then generate."""
        candidates = self.retrieve(
            question, corpus, top_k=self.config.get("retrieval_top_k", 20)
        )

        # rerank() itself falls back to a plain cut when no reranker is set.
        top_chunks = self.rerank(
            question, candidates, top_k=self.config.get("final_top_k", 5)
        )

        return self.generate_answer(question, top_chunks)


# Pipeline Evaluation
class QAEvaluator:
    """Runs a QA system over a labelled test set and reports match metrics."""

    def __init__(self, system: MultiHopQASystem):
        self.system = system

    def evaluate(self, test_set: list[dict]) -> dict:
        """Evaluate the system on ``test_set``.

        Each example must provide ``"question"``, ``"corpus"`` and
        ``"answer"`` keys.

        Returns:
            A dict with ``exact_match`` (case-insensitive, whitespace-trimmed
            string equality rate), ``partial_match`` (rate of predictions
            passing ``_fuzzy_match``), ``avg_confidence``, and the raw
            ``predictions``, ``references`` and ``confidences`` lists.
            For an empty test set all rates are ``0.0`` and the lists are empty.
        """
        if not test_set:
            # Nothing to score; avoid dividing by zero below.
            return {
                "exact_match": 0.0,
                "partial_match": 0.0,
                "avg_confidence": 0.0,
                "predictions": [],
                "references": [],
                "confidences": [],
            }

        predictions = []
        references = []
        confidences = []

        for example in tqdm(test_set, desc="Evaluating"):
            result = self.system.answer(example["question"], example["corpus"])

            predictions.append(result.answer)
            references.append(example["answer"])
            confidences.append(result.confidence)

        total = len(predictions)

        exact_match = sum(
            p.strip().lower() == r.strip().lower()
            for p, r in zip(predictions, references)
        ) / total

        partial_match = sum(
            self._fuzzy_match(p, r)
            for p, r in zip(predictions, references)
        ) / total

        return {
            "exact_match": exact_match,
            "partial_match": partial_match,
            "avg_confidence": sum(confidences) / total,
            "predictions": predictions,
            "references": references,
            "confidences": confidences
        }

    def _fuzzy_match(self, pred: str, ref: str, threshold: float = 0.5) -> bool:
        """Return True when the token-set Jaccard similarity reaches ``threshold``.

        Tokens are lower-cased whitespace-separated words. An empty reference
        never matches.
        """
        pred_tokens = set(pred.lower().split())
        ref_tokens = set(ref.lower().split())

        if not ref_tokens:
            return False

        # The union is non-empty here because ref_tokens is non-empty.
        overlap = len(pred_tokens & ref_tokens)
        jaccard = overlap / len(pred_tokens | ref_tokens)
        return jaccard >= threshold

Deployment Configuration

# requirements.txt
transformers>=4.30.0
torch>=2.0.0
sentence-transformers>=2.2.0
scikit-learn>=1.3.0
tqdm>=4.65.0      # Progress bars during evaluation
faiss-cpu>=1.7.4  # For large-scale retrieval
gradio>=3.40.0    # Optional web interface

# run_pipeline.py
import argparse
import json

def main():
    """Build the QA system from CLI flags / a JSON config file and run a demo.

    Settings are resolved in this order (later wins): built-in defaults,
    values from the JSON config file, explicitly passed command-line flags.
    The default ``config.json`` is optional; a path passed explicitly via
    ``--config`` must exist.
    """
    default_config_path = "config.json"
    defaults = {
        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
        "generator_model": "./models/llama-2-13b",
        "reranker_model": None,
        "retrieval_top_k": 20,
        "final_top_k": 5,
    }

    # Flag defaults are None so that "not passed" can be told apart from an
    # explicit value and does not override the config file.
    parser = argparse.ArgumentParser(description="Multi-Hop QA System")
    parser.add_argument(
        "--embedding-model", default=None,
        help=f"default: {defaults['embedding_model']}",
    )
    parser.add_argument(
        "--generator-model", default=None,
        help=f"default: {defaults['generator_model']}",
    )
    parser.add_argument("--reranker-model", default=None)
    parser.add_argument(
        "--config", type=str, default=None,
        help=f"JSON file with settings (default: {default_config_path}, if present)",
    )
    args = parser.parse_args()

    config = dict(defaults)

    # Layer the JSON file on top of the built-in defaults.
    config_path = args.config or default_config_path
    try:
        with open(config_path, encoding="utf-8") as fh:
            file_config = json.load(fh)
    except FileNotFoundError:
        if args.config is not None:
            parser.error(f"config file not found: {config_path}")
        file_config = {}
    except json.JSONDecodeError as err:
        parser.error(f"invalid JSON in {config_path}: {err}")
    if not isinstance(file_config, dict):
        parser.error(f"{config_path} must contain a JSON object")
    config.update(file_config)

    # Explicit command-line flags take precedence over the file.
    for key in ("embedding_model", "generator_model", "reranker_model"):
        value = getattr(args, key)
        if value is not None:
            config[key] = value

    system = MultiHopQASystem(config)

    # Example usage
    corpus = [
        {"text": "Python was created by Guido van Rossum in 1991.", "source": "wiki"},
        {"text": "Guido van Rossum worked at Google before Dropbox.", "source": "bio"},
    ]

    result = system.answer("Who created Python and where did they work?", corpus)
    print(f"Answer: {result.answer}")
    print(f"Confidence: {result.confidence:.2f}")

if __name__ == "__main__":
    main()

EXERCISE

Extend the multi-hop QA system to handle multi-document reasoning where the answer requires synthesizing information across documents. Implement a reasoning tracker that explains which chunks contributed to each part of the generated answer.