10. Query-Focused Summarization

Chapter 10 of 18 · 15 min

Query-focused summarization generates summaries conditioned on a specific question or information need. Unlike generic summarization, this task requires identifying and extracting content relevant to the query while filtering out tangential information.

The Core Architecture

Query-focused summarization typically combines extractive and abstractive approaches. The extractive phase selects relevant sentences or passages using relevance scoring against the query. The abstractive phase then synthesizes these fragments into coherent, query-aligned output.

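import re

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class QueryFocusedSummarizer:
    def __init__(self, model_name="meta-llama/Llama-2-7b-chat-hf"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Load with a language-modeling head so the same model can embed and generate.
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)
        self.model.eval()
    
    def compute_relevance_scores(self, query: str, sentences: list[str]) -> list[float]:
        """Score each sentence's relevance to the query."""
        query_embedding = self._embed_text(query)
        scores = []
        for sentence in sentences:
            sentence_embedding = self._embed_text(sentence)
            similarity = torch.nn.functional.cosine_similarity(
                query_embedding, sentence_embedding, dim=-1
            )
            scores.append(similarity.item())
        return scores
    
    def _embed_text(self, text: str) -> torch.Tensor:
        """Mean-pool the final hidden states into one vector of shape (1, hidden_size)."""
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        hidden = outputs.hidden_states[-1]  # (1, seq_len, hidden_size)
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        # A decoder-only model has no CLS token, so average over all tokens instead.
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1)
    
    def _generate_text(self, prompt: str, max_new_tokens: int = 128) -> str:
        """Generate a continuation of the prompt and return only the new text."""
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            output_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Drop the prompt tokens so only the generated answer is decoded.
        new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    
    def generate_summary(self, query: str, document: str, max_sentences: int = 3) -> str:
        # Split after sentence-ending punctuation so each sentence keeps its period.
        sentences = [s for s in re.split(r"(?<=[.!?])\s+", document.strip()) if s]
        scores = self.compute_relevance_scores(query, sentences)
        
        # Select top-scoring sentences while preserving order
        top_indices = sorted(range(len(sentences)), key=lambda i: scores[i], reverse=True)[:max_sentences]
        relevant_sentences = [sentences[i] for i in sorted(top_indices)]
        
        # Refine with query-conditioned generation
        prompt = f"Query: {query}\n\nRelevant information: {' '.join(relevant_sentences)}\n\nGenerate a concise answer:"
        return self._generate_text(prompt)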
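
The selection step in generate_summary can be tried out without loading a large model. The following is a minimal sketch, assuming a bag-of-words cosine as a stand-in for the embedding similarity. The names split_sentences, bow_cosine, and select_top_sentences are hypothetical helpers introduced for illustration, not part of any library. It shows the two moves that matter: rank by relevance, then restore document order.

```python
import math
import re
from collections import Counter


def split_sentences(document: str) -> list[str]:
    """Naive splitter: break after ., ! or ? followed by whitespace."""
    return [s for s in re.split(r"(?<=[.!?])\s+", document.strip()) if s]


def bow_cosine(a: str, b: str) -> float:
    """Cosine similarity between bag-of-words count vectors (stand-in scorer)."""
    va = Counter(re.findall(r"\w+", a.lower()))
    vb = Counter(re.findall(r"\w+", b.lower()))
    dot = sum(count * vb[token] for token, count in va.items())
    norm = math.sqrt(sum(c * c for c in va.values())) * math.sqrt(sum(c * c for c in vb.values()))
    return dot / norm if norm else 0.0


def select_top_sentences(query: str, document: str, max_sentences: int = 3) -> list[str]:
    """Keep the highest-scoring sentences, returned in document order."""
    sentences = split_sentences(document)
    scores = [bow_cosine(query, s) for s in sentences]
    # Rank sentence indices by score, keep the best few, then sort the indices again.
    top = sorted(range(len(sentences)), key=lambda i: scores[i], reverse=True)[:max_sentences]
    return [sentences[i] for i in sorted(top)]


doc = (
    "The reactor was commissioned in 1986. "
    "Its cooling system uses seawater drawn from the bay. "
    "Local fishing has declined. "
    "The cooling system was upgraded in 2010."
)
selected = select_top_sentences("cooling system", doc, max_sentences=2)
print(selected)
```

Here the last sentence scores highest, yet the output still lists the seawater sentence first because selected indices are re-sorted before the sentences are joined. Lexical overlap is only a placeholder: it misses paraphrases that an embedding model would catch.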

Reranking for Precision

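Bi-encoder similarity embeds the query and each sentence independently, which is fast but ignores token-level interactions between the two. A cross-encoder reads each (query, sentence) pair jointly and generally judges relevance more accurately, at the cost of one forward pass per pair. For that reason it is usually applied as a second stage, after initial retrieval has narrowed the candidates: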

from sentence_transformers import CrossEncoder

class RerankedSummarizer(QueryFocusedSummarizer):
    def __init__(self, model_name="meta-llama/Llama-2-7b-chat-hf"):
        super().__init__(model_name)
        self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    
    def rerank_sentences(self, query: str, sentences: list[str]) -> list[tuple[int, str, float]]:
        pairs = [(query, sentence) for sentence in sentences]
        scores = self.reranker.predict(pairs)
        return sorted(zip(range(len(sentences)), sentences, scores),
                      key=lambda x: x[2], reverse=True)
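
The class above ranks sentences but does not yet feed that ranking into selection. One way to wire the two stages together is sketched below, under stated assumptions: retrieve_then_rerank, its parameters first_pass_k and final_k, and toy_pair_scorer are hypothetical names chosen for illustration. The pairwise scorer is passed in as a function, so a real cross-encoder's predict method could be supplied in place of the toy scorer used here.

```python
from typing import Callable, Sequence

PairScorer = Callable[[list[tuple[str, str]]], Sequence[float]]


def retrieve_then_rerank(
    query: str,
    sentences: list[str],
    first_pass_scores: Sequence[float],
    rerank_fn: PairScorer,
    first_pass_k: int = 10,
    final_k: int = 3,
) -> list[str]:
    """Two-stage selection: cheap scores shortlist, a pairwise scorer picks the final set."""
    # Stage 1: keep the indices of the first_pass_k best sentences by the cheap score.
    shortlist = sorted(
        range(len(sentences)), key=lambda i: first_pass_scores[i], reverse=True
    )[:first_pass_k]
    # Stage 2: score only the shortlisted (query, sentence) pairs with the costly scorer.
    pair_scores = rerank_fn([(query, sentences[i]) for i in shortlist])
    ranked = sorted(zip(shortlist, pair_scores), key=lambda x: x[1], reverse=True)[:final_k]
    # Restore document order so the selected sentences read naturally.
    return [sentences[i] for i in sorted(i for i, _ in ranked)]


def toy_pair_scorer(pairs: list[tuple[str, str]]) -> list[float]:
    """Stand-in for a cross-encoder: fraction of query words present in the sentence."""
    scores = []
    for query, sentence in pairs:
        q_words = set(query.lower().split())
        s_words = set(sentence.lower().rstrip(".").split())
        scores.append(len(q_words & s_words) / len(q_words) if q_words else 0.0)
    return scores


sentences = [
    "Solar output peaks at noon.",
    "Battery storage smooths solar output overnight.",
    "Wind farms need steady wind.",
    "Storage costs fell sharply.",
]
first_pass = [0.9, 0.6, 0.1, 0.5]  # made-up first-pass scores for the example
picked = retrieve_then_rerank(
    "solar storage costs", sentences, first_pass, toy_pair_scorer, first_pass_k=3, final_k=2
)
print(picked)
```

In this example the first sentence has the best first-pass score but is dropped after reranking, while the wind sentence never reaches the second stage at all. Keeping first_pass_k small bounds the number of expensive pairwise calls.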

EXERCISE

Implement a multi-document query-focused summarizer that aggregates information across sources. Handle conflicting information by presenting competing claims distinctly rather than averaging them.
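
One possible starting point for the presentation side of the exercise is sketched below. It assumes relevant sentences have already been selected per source. The Claim dataclass and present_claims function are hypothetical scaffolding, not a reference solution, and detecting which claims actually conflict is left open.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Claim:
    source: str  # hypothetical source identifier, e.g. a filename
    text: str    # a sentence already selected as relevant to the query


def present_claims(query: str, claims: list[Claim]) -> str:
    """Render claims grouped by source so disagreements stay visible."""
    by_source: dict[str, list[str]] = {}
    for claim in claims:
        by_source.setdefault(claim.source, []).append(claim.text)
    lines = [f"Query: {query}"]
    for source, texts in by_source.items():  # dicts preserve insertion order
        lines.append(f"{source}: {' '.join(texts)}")
    return "\n".join(lines)


claims = [
    Claim("report_a", "The outage lasted two hours."),
    Claim("report_b", "The outage lasted five hours."),
    Claim("report_a", "No data was lost."),
]
rendered = present_claims("How long was the outage?", claims)
print(rendered)
```

Because each source keeps its own line, the two durations appear side by side instead of being merged into a single figure.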