14. Attention Sink
Long-context inference reveals a surprising behavior: models allocate disproportionate attention to early tokens, even when those tokens contain no task-relevant information. This "attention sink" phenomenon is a side effect of softmax normalization in autoregressive models: every query must place its attention somewhere, even when no earlier token is relevant to it.
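Attention sinks form because softmax attention weights must sum to 1 across all visible positions. When a query has no strong match, the surplus weight has to land somewhere, and the first tokens (often special tokens like <|endoftext|>) are the natural target: under causal masking they are visible to every later position, so the model learns during training to use them as a default destination. Evicting these tokens from the KV cache, as a plain sliding window does, removes that destination and degrades generation sharply.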
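The normalization constraint can be seen in a toy calculation. The sketch below uses made-up attention scores (not taken from any model) for one query whose content tokens are all equally weak matches; the softmax still hands out a total weight of 1, and the slightly preferred first position absorbs most of it.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits: position 0 is the initial token, positions 1..8 are
# content tokens that are all weakly relevant to the current query.
scores = [4.0] + [0.0] * 8
weights = softmax(scores)

sink_share = weights[0]            # weight absorbed by the first token
content_share = sum(weights[1:])   # weight spread over the content tokens

print(f"sink: {sink_share:.3f}, content: {content_share:.3f}")
```

The specific logit values are illustrative only; the point is that the weights cannot all be small at once.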
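StreamingLLM identifies this pattern and proposes a fix: always keep the keys and values of a small number of initial sink tokens in the cache, alongside a sliding window of the most recent tokens. This keeps generation stable over very long streams while holding the cache at a fixed size. It does not extend the model's usable context: tokens that fall out of the window can no longer be attended to.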
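In cache terms, the approach amounts to an eviction policy: the first few tokens are never evicted, and everything else lives in a rolling window. The following is a minimal sketch of that policy, assuming a toy cache that stores token ids in place of real key/value tensors; the class name `SinkKVCache` and its methods are hypothetical and not part of any library.

```python
from collections import deque

class SinkKVCache:
    """Toy cache: keeps the first `num_sinks` entries plus a rolling window."""

    def __init__(self, num_sinks=4, window=511):
        self.num_sinks = num_sinks
        self.sinks = []                     # first tokens, never evicted
        self.recent = deque(maxlen=window)  # oldest entry dropped when full

    def append(self, token_id):
        # Fill the sink slots first; afterwards everything goes to the window.
        if len(self.sinks) < self.num_sinks:
            self.sinks.append(token_id)
        else:
            self.recent.append(token_id)

    def contents(self):
        return self.sinks + list(self.recent)

# Stream 20 tokens through a cache with 4 sinks and a window of 6.
cache = SinkKVCache(num_sinks=4, window=6)
for t in range(20):
    cache.append(t)
kept = cache.contents()
print(kept)
```

A real implementation would store per-layer key and value tensors in each slot, but the retention rule is the same.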
# StreamingLLM attention sink implementation (single-head, decode-step sketch)
import math
import torch
import torch.nn.functional as F

def streaming_attention(
    Q, K, V,
    sink_positions,        # Positions to retain as sinks, e.g. [0, 1, 2, 3]
    local_window_size,     # Recent tokens for local attention
    past_key_values=None   # Unused in this sketch; kept so the signature is unchanged
):
    # Shapes: Q is (batch, q_len, d); K and V are (batch, kv_len, d).
    # Assumes kv_len >= len(sink_positions) + local_window_size, so sinks and the
    # local window do not overlap, and that Q holds only the newest token(s),
    # so no causal mask is applied here.
    num_sinks = len(sink_positions)
    scale = math.sqrt(Q.size(-1))

    # Sink attention (always included)
    sink_keys = K[:, sink_positions]
    sink_values = V[:, sink_positions]
    sink_scores = torch.matmul(Q, sink_keys.transpose(-2, -1)) / scale

    # Local window attention (recent tokens only)
    local_K = K[:, -local_window_size:]
    local_V = V[:, -local_window_size:]
    local_scores = torch.matmul(Q, local_K.transpose(-2, -1)) / scale

    # Combine scores and normalize jointly over sinks + window
    all_scores = torch.cat([sink_scores, local_scores], dim=-1)
    all_weights = F.softmax(all_scores, dim=-1)

    # Split weights back along the key axis (last dimension)
    sink_weights = all_weights[..., :num_sinks]
    local_weights = all_weights[..., num_sinks:]
    sink_output = torch.matmul(sink_weights, sink_values)
    local_output = torch.matmul(local_weights, local_V)
    return sink_output + local_output
Typical configuration:
# StreamingLLM configuration for 7B model
config = {
    "sink_tokens": 4,      # Keep 4 sink tokens
    "local_window": 511,   # Attention over last 511 tokens
    "total_context": 515,  # 4 sinks + 511 local = 515 effective context
}
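Memory savings are significant, because the KV cache grows linearly with the number of cached tokens. Capping the cache at 515 tokens (4 sinks + 511 window) instead of 32K shrinks it by roughly 64x. The absolute sizes depend on the number of layers, the number of key/value heads, the head dimension, and the storage precision, so they should be computed for the specific model. The saving applies only to the cache: the weights of a 70B model still have to fit in memory, so a small cache alone does not make such a model fit on a consumer GPU.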
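The cache size for a particular model can be estimated directly. The sketch below assumes a 70B-class architecture with 80 layers, 8 grouped-query KV heads of dimension 128, and fp16 storage; these numbers are assumptions for illustration, and the helper `kv_cache_bytes` is hypothetical. A model with full multi-head attention would have proportionally more KV heads and a larger cache.

```python
def kv_cache_bytes(num_tokens, num_layers, num_kv_heads, head_dim, bytes_per_value=2):
    """Per-sequence KV cache size in bytes."""
    # Factor of 2: one key vector and one value vector per token, layer, and KV head.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_value * num_tokens

GIB = 1024 ** 3

# Assumed architecture: 80 layers, 8 KV heads, head_dim 128, fp16 (2 bytes).
full = kv_cache_bytes(32_768, num_layers=80, num_kv_heads=8, head_dim=128)
streaming = kv_cache_bytes(4 + 511, num_layers=80, num_kv_heads=8, head_dim=128)

print(f"full attention: {full / GIB:.2f} GiB")
print(f"streaming:      {streaming / GIB:.2f} GiB")
print(f"reduction:      {full / streaming:.1f}x")
```

Under these assumptions the full cache is 10 GiB per sequence and the streaming cache is about 0.16 GiB; multiply by the batch size for concurrent sequences.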
Accuracy implications:
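# Evaluate StreamingLLM accuracy vs full attention
def evaluate_streaming_accuracy(
    model,
    test_dataset,
    sink_tokens=4,
    local_window=511
):
    # NOTE: `model.generate` and the keyword arguments `attention_impl`,
    # `sink_tokens`, and `local_window` are placeholders; substitute whatever
    # your inference stack exposes. Use deterministic (greedy) decoding,
    # otherwise exact-match comparison is dominated by sampling noise.
    results = []
    for item in test_dataset:
        # Run with full attention
        full_output = model.generate(item["prompt"], max_new_tokens=128)
        # Run with StreamingLLM
        streaming_output = model.generate(
            item["prompt"],
            max_new_tokens=128,
            attention_impl="streaming_llm",
            sink_tokens=sink_tokens,
            local_window=local_window
        )
        results.append({
            "full": full_output,
            "streaming": streaming_output,
            "match": full_output == streaming_output,
        })
    if not results:
        raise ValueError("test_dataset is empty")
    # Fraction of prompts where both runs produced identical output
    accuracy = sum(r["match"] for r in results) / len(results)
    return accuracy
# Typical results: 85-95% exact match, near-perfect semantic similarity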
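Exact match is a strict criterion: a single differing token counts as a complete miss. A softer complement is to measure how far the two outputs agree before they first diverge. The sketch below assumes whitespace tokenization for simplicity, and the helper `common_prefix_fraction` is hypothetical; in practice the model's own token ids would be compared.

```python
def common_prefix_fraction(reference, candidate):
    """Fraction of `reference` tokens reproduced before the first divergence."""
    if not reference:
        return 1.0 if not candidate else 0.0
    matched = 0
    for ref_tok, cand_tok in zip(reference, candidate):
        if ref_tok != cand_tok:
            break
        matched += 1
    return matched / len(reference)

# Illustrative strings, not model outputs.
full_tokens = "the cache keeps four sink tokens".split()
streaming_tokens = "the cache keeps four initial tokens".split()

score = common_prefix_fraction(full_tokens, streaming_tokens)
print(f"agreement before divergence: {score:.2f}")
```

Reporting this alongside the exact-match rate shows whether mismatches are late, minor divergences or early failures.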
Implement StreamingLLM for a model you're using. Compare generation quality on long-document tasks (10K+ tokens) between full attention and StreamingLLM.