03. Multi-Head Attention

Chapter 3 of 24 · 15 min

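Multi-head attention runs n_heads attention operations in parallel, each on its own d_k = d_model / n_heads dimensional projection of the input, so the model can attend to different representation subspaces simultaneously. Because each head works on its own slice of the Q, K, and V projections, heads can learn different attention patterns.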

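import torch.nn as nn

# scaled_dot_product_attention(Q, K, V, mask) is assumed to be defined elsewhere
# and to return a tuple (output, attention_weights).

class MultiHeadAttention(nn.Module):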
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.d_model = d_model
        
        # Project input to Q, K, V
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        
        # Output projection
        self.W_o = nn.Linear(d_model, d_model, bias=False)
        
        self.dropout = nn.Dropout(dropout)
    
    def split_heads(self, x, batch_size):
        # (batch, seq_len, d_model) -> (batch, heads, seq_len, d_k)
        x = x.view(batch_size, -1, self.n_heads, self.d_k)
        return x.transpose(1, 2)
    
    def forward(self, x, mask=None):
        batch_size = x.shape[0]
        
        # Linear projections and split heads
        Q = self.split_heads(self.W_q(x), batch_size)
        K = self.split_heads(self.W_k(x), batch_size)
        V = self.split_heads(self.W_v(x), batch_size)
        
        # Attention
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        
        # Concatenate heads: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        
        # Final linear projection, then dropout on the output
        # (self.dropout is a no-op in eval mode)
        output = self.dropout(self.W_o(attn_output))
        
        return output
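
The split and merge steps are easiest to trust once the shapes have been traced by hand. The following is a minimal sketch on a raw tensor, with small sizes chosen only for illustration; it mirrors split_heads and the concatenation step but is not part of the module itself.

```python
import torch

# Illustrative sizes (assumptions, not values from the chapter)
batch, seq_len, d_model, n_heads = 2, 5, 32, 4
d_k = d_model // n_heads

x = torch.randn(batch, seq_len, d_model)

# Split: (batch, seq_len, d_model) -> (batch, heads, seq_len, d_k)
heads = x.view(batch, -1, n_heads, d_k).transpose(1, 2)

# Merge: (batch, heads, seq_len, d_k) -> (batch, seq_len, d_model)
merged = heads.transpose(1, 2).contiguous().view(batch, -1, d_model)

print(heads.shape)             # torch.Size([2, 4, 5, 8])
print(merged.shape)            # torch.Size([2, 5, 32])
print(torch.equal(merged, x))  # True: merging undoes splitting

# Head h sees features h*d_k through (h+1)*d_k - 1 of the projected input
print(torch.equal(heads[:, 1], x[:, :, d_k:2 * d_k]))  # True
```

Merging is the exact inverse of splitting, so the only learned mixing across heads happens in the W_o projection.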

Common implementation bug: Forgetting .contiguous() after transpose. The transpose returns a view whose memory layout is no longer contiguous, and .view() cannot merge the heads and d_k dimensions over that layout, so it raises a RuntimeError. Call .contiguous() before .view(), or use .reshape(), which copies only when it has to.
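
The behavior is easy to reproduce in isolation. This sketch uses a random tensor in place of a real attention output, with sizes chosen only for illustration.

```python
import torch

# Stand-in for an attention output: (batch, heads, seq_len, d_k)
batch, n_heads, seq_len, d_k = 2, 4, 5, 8
attn_output = torch.randn(batch, n_heads, seq_len, d_k)

transposed = attn_output.transpose(1, 2)  # (batch, seq_len, heads, d_k)
print(transposed.is_contiguous())         # False

# .view() cannot merge heads and d_k over the transposed strides
try:
    transposed.view(batch, -1, n_heads * d_k)
    view_failed = False
except RuntimeError:
    view_failed = True
print(view_failed)  # True

# Both of these work and give the same result
via_contiguous = transposed.contiguous().view(batch, -1, n_heads * d_k)
via_reshape = transposed.reshape(batch, -1, n_heads * d_k)
print(torch.equal(via_contiguous, via_reshape))  # True
```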

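Failure mode: d_model must be divisible by n_heads. If it is not, d_model // n_heads rounds down and the view in split_heads fails with a shape error that never mentions the head count; the assert in __init__ catches the problem at construction time instead. Llama 3.1 8B uses 32 query heads for d_model=4096 (head dimension 128, with 8 key/value heads under grouped-query attention), which divides cleanly, but custom configurations often don't.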
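
One way to make this failure readable is to validate the configuration before building any layers. The helper below is a hypothetical illustration (head_dim is not a function from the chapter or from PyTorch).

```python
def head_dim(d_model: int, n_heads: int) -> int:
    """Return the per-head dimension, or raise a readable error.

    Hypothetical helper for illustration only.
    """
    if d_model % n_heads != 0:
        raise ValueError(
            f"d_model={d_model} is not divisible by n_heads={n_heads} "
            f"(remainder {d_model % n_heads})"
        )
    return d_model // n_heads


print(head_dim(4096, 32))  # 128

try:
    head_dim(512, 6)
except ValueError as err:
    print(err)  # d_model=512 is not divisible by n_heads=6 (remainder 2)
```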

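Failure mode: Inattention to dropout placement. Dropout can be applied to the attention weights after the softmax (attention dropout, which is what the dropout argument of PyTorch's nn.MultiheadAttention does) or to the output of the attention sublayer. Both appear in practice, so pick one deliberately rather than applying both by accident. In either case dropout is active only in training mode: call model.eval() before inference, or outputs will be randomly perturbed. A related bug is creating nn.Dropout in __init__ and never calling it in forward, which silently disables regularization.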
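
The difference between training and evaluation mode can be checked directly on an nn.Dropout layer. This is a small sketch with an arbitrary p=0.5 and an all-ones input so the scaling is easy to see.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dropout = nn.Dropout(p=0.5)
x = torch.ones(1000)

# Training mode: each entry is zeroed with probability p, and the
# survivors are scaled by 1 / (1 - p) = 2 to keep the expected value.
dropout.train()
train_out = dropout(x)
print(set(train_out.unique().tolist()) <= {0.0, 2.0})  # True

# Eval mode: dropout is the identity.
dropout.eval()
eval_out = dropout(x)
print(torch.equal(eval_out, x))  # True
```

The same switch applies to a whole model: calling model.eval() puts every nn.Dropout submodule into evaluation mode.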

EXERCISE

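Add a causal mask to the MultiHeadAttention implementation. Create a lower-triangular mask of shape (seq_len, seq_len) and pass it to scaled_dot_product_attention. Then verify the causal constraint: the attention weights for position i should be zero for every position j > i. Note that forward returns only the output, so inspect attn_weights inside forward or return it as well.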
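
As a starting point, here is a sketch of building the mask and checking the constraint on raw scores, outside the module. It assumes a convention where True means "may attend" and masked positions are filled with -inf before the softmax; the mask convention inside scaled_dot_product_attention may differ, so adapt accordingly.

```python
import torch

seq_len = 4

# Lower-triangular mask: position i may attend to positions j <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Random stand-in for Q @ K^T / sqrt(d_k): (batch, heads, seq_len, seq_len)
scores = torch.randn(2, 3, seq_len, seq_len)

# The (seq_len, seq_len) mask broadcasts over batch and heads
scores = scores.masked_fill(~causal_mask, float("-inf"))
weights = torch.softmax(scores, dim=-1)

# Check: every weight on a future position (j > i) is exactly zero
future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
print(bool((weights[..., future] == 0).all()))  # True

# Each row still sums to 1 over the allowed positions
print(torch.allclose(weights.sum(dim=-1), torch.ones(2, 3, seq_len)))  # True
```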