10. Query-Focused Summarization
Query-focused summarization generates summaries conditioned on a specific question or information need. Unlike generic summarization, this task requires identifying and extracting content relevant to the query while filtering out tangential information.
The Core Architecture
Query-focused summarization typically combines extractive and abstractive approaches. The extractive phase selects relevant sentences or passages using relevance scoring against the query. The abstractive phase then synthesizes these fragments into coherent, query-aligned output.
from transformers import AutoTokenizer, AutoModel
import torch
class QueryFocusedSummarizer:
    """Extract-then-generate summarizer conditioned on a query.

    Sentences of a document are scored against the query by cosine
    similarity of pooled hidden states, the best ones are kept in document
    order, and a causal language model turns them into an answer.

    Attributes set by ``__init__``:
        tokenizer: tokenizer loaded for ``model_name``.
        generator: causal LM (with LM head) used for text generation.
        model: the generator's base transformer, used for embeddings.
        device: ``"cuda"`` when available, otherwise ``"cpu"``.
    """

    def __init__(self, model_name="meta-llama/Llama-2-7b-chat-hf"):
        """Load the tokenizer and model weights for ``model_name``."""
        # Function-scope import: generation needs a model with an LM head,
        # which the module-level ``AutoModel`` import does not provide.
        from transformers import AutoModelForCausalLM

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.generator = AutoModelForCausalLM.from_pretrained(model_name)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.generator.to(self.device)
        # ``base_model`` is the headless transformer inside the generator,
        # so embeddings and generation share a single copy of the weights.
        self.model = self.generator.base_model

    def compute_relevance_scores(self, query: str, sentences: list[str]) -> list[float]:
        """Score each sentence's relevance to the query.

        Returns one cosine similarity per sentence, in input order. An
        empty ``sentences`` list yields an empty list.
        """
        query_embedding = self._embed_text(query)
        return [
            torch.nn.functional.cosine_similarity(
                query_embedding, self._embed_text(sentence), dim=-1
            ).item()
            for sentence in sentences
        ]

    def _embed_text(self, text: str) -> torch.Tensor:
        """Return a ``(1, hidden_size)`` embedding of ``text``.

        The embedding is the attention-mask-weighted mean of the last
        hidden states. The first position is deliberately not used on its
        own: the default checkpoint is a decoder-only model with no CLS
        token, and under causal attention the first token cannot see the
        rest of the text, so its state would not reflect the input.
        """
        encoded = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}
        with torch.no_grad():
            outputs = self.model(**encoded)
        hidden = outputs.last_hidden_state
        mask = encoded["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        # clamp guards against division by zero if no token is attended to.
        token_count = mask.sum(dim=1).clamp(min=1)
        return (hidden * mask).sum(dim=1) / token_count

    def _generate_text(self, prompt: str, max_new_tokens: int = 256) -> str:
        """Greedily generate a continuation of ``prompt`` and return only the new text."""
        encoded = self.tokenizer(prompt, return_tensors="pt")
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}
        with torch.no_grad():
            output_ids = self.generator.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                # Some tokenizers (e.g. Llama) define no pad token; reuse EOS.
                pad_token_id=self.tokenizer.eos_token_id,
            )
        # generate() returns prompt + continuation for decoder-only models;
        # slice off the prompt so the caller gets just the answer.
        prompt_length = encoded["input_ids"].shape[1]
        new_tokens = output_ids[0, prompt_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    def generate_summary(self, query: str, document: str, max_sentences: int = 3) -> str:
        """Summarize ``document`` with respect to ``query``.

        Args:
            query: the question or information need.
            document: raw text; sentences are split on ``". "``.
            max_sentences: upper bound on extracted sentences (values
                below zero are treated as zero).

        Returns:
            The generated answer, or ``""`` when nothing can be extracted
            (blank document or ``max_sentences`` of zero).
        """
        # Splitting on ". " strips the period from every piece but the
        # last; drop blank pieces and restore the terminator so the prompt
        # contains well-formed sentences.
        pieces = (piece.strip() for piece in document.split(". "))
        sentences = [
            piece if piece.endswith((".", "!", "?")) else piece + "."
            for piece in pieces
            if piece
        ]
        limit = max(max_sentences, 0)
        if not sentences or limit == 0:
            return ""

        scores = self.compute_relevance_scores(query, sentences)

        # Pick the top-scoring sentences, then restore document order.
        ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
        selected_indices = sorted(index for index, _ in ranked[:limit])
        relevant_sentences = [sentences[index] for index in selected_indices]

        # Refine with query-conditioned generation
        prompt = f"Query: {query}\n\nRelevant information: {' '.join(relevant_sentences)}\n\nGenerate a concise answer:"
        return self._generate_text(prompt)
Reranking for Precision
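After initial retrieval, a cross-encoder can rerank the candidates. Because it encodes the query and each candidate sentence jointly rather than embedding them independently, it typically provides a more accurate relevance assessment than bi-encoder similarity, at a higher compute cost per pair: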
from sentence_transformers import CrossEncoder
class RerankedSummarizer(QueryFocusedSummarizer):
    """Query-focused summarizer with a cross-encoder for sentence reranking."""

    def __init__(self, model_name="meta-llama/Llama-2-7b-chat-hf"):
        """Initialize the base summarizer and load the MS MARCO cross-encoder."""
        super().__init__(model_name)
        self.reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

    def rerank_sentences(self, query: str, sentences: list[str]) -> list[tuple[int, str, float]]:
        """Rank ``sentences`` by cross-encoder relevance to ``query``.

        Returns:
            ``(original_index, sentence, score)`` tuples ordered from most
            to least relevant; ties keep their original order. An empty
            input returns an empty list without calling the reranker.
        """
        if not sentences:
            # NOTE(review): predict() may not accept an empty batch in every
            # sentence-transformers version, so skip the call entirely.
            return []
        pairs = [(query, sentence) for sentence in sentences]
        scores = self.reranker.predict(pairs)
        # Cast to built-in float so the result matches the annotation.
        ranked = [
            (index, sentence, float(score))
            for index, (sentence, score) in enumerate(zip(sentences, scores))
        ]
        ranked.sort(key=lambda item: item[2], reverse=True)
        return ranked
Exercise: Implement a multi-document query-focused summarizer that aggregates information across sources. Handle conflicting information by presenting competing claims distinctly rather than averaging them.