02. Attention Mechanisms
The attention mechanism computes contextual relationships between all positions in a sequence. For each output position, attention weighs every input position by a relevance score computed from learned projections of the inputs.
Scaled dot-product attention takes queries (Q), keys (K), and values (V) as inputs. It computes the attention weights from the similarity between queries and keys, then uses those weights to form a weighted sum of the values:
Attention(Q, K, V) = softmax(QK^T / √d_k) V
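To make the formula concrete, here is a hand-sized example in PyTorch: one query and two key/value pairs, chosen so the arithmetic is easy to follow. The tensors are made up for illustration.

```python
import math

import torch

# One query and two key/value pairs, d_k = 2 (shapes: (seq_len, d_k))
Q = torch.tensor([[1.0, 0.0]])
K = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0]])
V = torch.tensor([[10.0, 0.0],
                  [0.0, 10.0]])
d_k = Q.shape[-1]

# QK^T / sqrt(d_k): the query matches key 0 exactly and is orthogonal to key 1
scores = Q @ K.T / math.sqrt(d_k)         # about [[0.7071, 0.0000]]
weights = torch.softmax(scores, dim=-1)   # about [[0.6698, 0.3302]]
output = weights @ V                      # about [[6.6976, 3.3024]]
print(scores, weights, output)
```

Because key 0 is more similar to the query, it receives about two thirds of the weight, and the output leans toward its value vector.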
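```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q, K, V: (batch, heads, seq_len, d_k)
    mask: optional tensor broadcastable to (batch, heads, seq_len, seq_len);
          positions where mask == 0 are blocked
    Returns: output (batch, heads, seq_len, d_k) and
             attention weights (batch, heads, seq_len, seq_len)
    """
    d_k = Q.shape[-1]
    # Compute attention scores: (batch, heads, seq_len, seq_len)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # Apply mask (e.g. the causal mask in decoder self-attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    # Softmax over the key dimension to get weights
    attn_weights = F.softmax(scores, dim=-1)
    # Apply weights to values
    output = torch.matmul(attn_weights, V)
    return output, attn_weights
```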
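Failure mode: Large attention scores make softmax numerically fragile. Exponentiating a raw score overflows once the score exceeds the natural log of the largest representable value: about 88.7 in float32, about 11.1 in float16, and about 709.8 in float64. Library softmax implementations such as F.softmax subtract the row maximum before exponentiating, which avoids the overflow, but very large scores still push the weights toward a one-hot distribution whose gradients are close to zero. Dividing by √d_k counters this: if the components of q and k have zero mean and unit variance, q · k has variance d_k, and the scaling brings it back to 1. Note that d_k is the per-head dimension, not d_model. Scaling does not bound outlier activations, so large models (e.g. d_model = 4096) can still produce extreme scores, especially in float16.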
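The overflow and scaling arguments can be checked numerically. This is a small sketch assuming PyTorch; naive_softmax and stable_softmax are hypothetical helper names, and d_k = 512 and the sample count are arbitrary choices.

```python
import math

import torch

# Overflow threshold for exp(x): ln(largest finite value) per dtype
for dtype in (torch.float16, torch.float32, torch.float64):
    print(dtype, round(math.log(torch.finfo(dtype).max), 2))
# torch.float16 11.09, torch.float32 88.72, torch.float64 709.78

print(torch.exp(torch.tensor(88.0)))   # finite, about 1.65e38
print(torch.exp(torch.tensor(89.0)))   # inf (float32 overflow)


def naive_softmax(x):
    e = torch.exp(x)
    return e / e.sum(dim=-1, keepdim=True)


def stable_softmax(x):
    # Subtracting the row max leaves softmax unchanged but keeps every exp() <= 1
    e = torch.exp(x - x.max(dim=-1, keepdim=True).values)
    return e / e.sum(dim=-1, keepdim=True)


big = torch.tensor([100.0, 99.0, 90.0])
print(naive_softmax(big))    # all nan: inf / inf
print(stable_softmax(big))   # finite; matches torch.softmax(big, dim=-1)

# Why divide by sqrt(d_k): unit-variance q, k give dot products with variance ~d_k
torch.manual_seed(0)
d_k = 512
q = torch.randn(10_000, d_k)
k = torch.randn(10_000, d_k)
dots = (q * k).sum(dim=-1)
print(dots.var().item())                      # roughly 512
print((dots / math.sqrt(d_k)).var().item())   # roughly 1
```

The max subtraction is safe because softmax(x) = softmax(x - c) for any constant c; choosing c = max(x) makes every exponent non-positive.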
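Failure mode: A naive implementation materializes the full seq_len × seq_len score matrix, so memory grows as O(n²) in sequence length. For a batch of 8 sequences of 4096 tokens, a single head's float32 scores take 8 × 4096² × 4 bytes = 2^29 bytes (0.5 GiB), and the softmax output (attn_weights) is a second tensor of the same size. Multiply by the number of heads for a full layer; longer contexts or more heads quickly make this prohibitive.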
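The same arithmetic as a one-line estimate. attention_score_bytes is a hypothetical helper that counts only one score tensor (not the softmax output, values, or autograd buffers), and the 32-head case is an assumed configuration, not a figure from the text.

```python
def attention_score_bytes(batch, heads, seq_len, bytes_per_elem=4):
    """Bytes for one (batch, heads, seq_len, seq_len) score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


GiB = 1024 ** 3
print(attention_score_bytes(8, 1, 4096) / GiB)                     # 0.5 (the example above)
print(attention_score_bytes(8, 32, 4096) / GiB)                    # 16.0 with 32 heads
print(attention_score_bytes(8, 1, 8192) / GiB)                     # 2.0: 2x context, 4x memory
print(attention_score_bytes(8, 1, 4096, bytes_per_elem=2) / GiB)   # 0.25 in float16
```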
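The mask parameter implements causal masking: each position may attend only to itself and earlier positions, so position i sees indices 0 through i. This prevents the model from "looking ahead" at future tokens during training. Causal masks are used in decoder self-attention; encoder self-attention and encoder-decoder (cross) attention normally attend to all source positions. In the function above, the mask holds 1 for allowed positions and 0 for blocked ones, so a lower-triangular matrix of ones (torch.tril) serves as a causal mask.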
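One way to see what the mask buys: with causal masking, replacing the keys and values at positions t and later must leave the outputs at positions before t unchanged. causal_attention is a hypothetical helper that inlines the same scaled dot-product computation with a torch.tril mask (1 = allowed, 0 = blocked); the sizes and seed are arbitrary.

```python
import math

import torch
import torch.nn.functional as F


def causal_attention(Q, K, V):
    # Lower-triangular mask: 1 = allowed (j <= i), 0 = blocked (j > i)
    n = Q.shape[-2]
    mask = torch.tril(torch.ones(n, n))
    scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.shape[-1])
    scores = scores.masked_fill(mask == 0, float('-inf'))
    return F.softmax(scores, dim=-1) @ V


torch.manual_seed(0)
n, d_k, t = 8, 16, 5
Q, K, V = (torch.randn(n, d_k) for _ in range(3))
out = causal_attention(Q, K, V)

# Overwrite the keys and values at positions t and later ("the future")
K2, V2 = K.clone(), V.clone()
K2[t:] = torch.randn(n - t, d_k)
V2[t:] = torch.randn(n - t, d_k)
out2 = causal_attention(Q, K2, V2)

print(torch.allclose(out[:t], out2[:t]))   # True: earlier outputs never saw the future
print(torch.allclose(out[t:], out2[t:]))   # False: later outputs depend on the changed rows
```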
Local verification checkpoint
Run the smallest example from this chapter in a local workspace and record the package version, runtime, data path, and observed output. If the result depends on model size, vector count, CPU/GPU backend, or available memory, note that constraint beside the exercise so the lesson remains reproducible.
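A minimal sketch for capturing the runtime details, assuming a PyTorch setup. The fields chosen here are only a suggestion; the data path, model size, and observed output still need to be noted by hand.

```python
import platform
import sys

import torch

record = {
    "python": sys.version.split()[0],
    "platform": platform.platform(),
    "torch": torch.__version__,
    "cuda_available": torch.cuda.is_available(),
    # GPU name if one is visible, otherwise just "cpu"
    "device": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu",
}
for key, value in record.items():
    print(f"{key}: {value}")
```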
Implement causal masking manually and verify that the position at index 100 attends only to indices 0 through 100 (inclusive). Start from a random score matrix, apply your mask, take the softmax, and check which positions receive non-zero weight.
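One possible solution sketch, assuming PyTorch. It builds the mask with an explicit loop rather than torch.tril, uses random scores in place of QK^T / √d_k, and then confirms that the loop-built mask matches torch.tril. The sequence length of 128 and the seed are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 128      # sequence length (must exceed 100 for this check)
pos = 100

# Random scores standing in for QK^T / sqrt(d_k)
scores = torch.randn(n, n)

# Causal mask built manually: mask[i, j] = 1 if j <= i else 0
mask = torch.zeros(n, n)
for i in range(n):
    mask[i, : i + 1] = 1

weights = F.softmax(scores.masked_fill(mask == 0, float('-inf')), dim=-1)

# Indices that receive non-zero weight from position 100
attended = torch.nonzero(weights[pos]).flatten()
print(attended.min().item(), attended.max().item())      # 0 100
print(torch.equal(mask, torch.tril(torch.ones(n, n))))   # True
print(weights[pos].sum().item())                         # ~1.0
```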