14. Custom Attention Patterns
Chapter 14 of 24 · 25 min
Standard attention patterns rarely match real-world use cases. Custom patterns enable domain-specific optimization.
Sparse Attention Patterns
class SparseAttention(nn.Module):
    """
    Band + global sparse self-attention.

    Each token attends to:
    - Nearby tokens within ``band_width`` positions on either side
    - The first ``n_global`` tokens (e.g., CLS, special markers)

    The first ``n_global`` tokens additionally attend to every position.

    Args:
        d_model: Model width; expected to be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        band_width: Half-width of the local band; token ``i`` sees positions
            ``i - band_width`` through ``i + band_width`` inclusive.
        n_global: Number of leading tokens treated as global.
    """

    def __init__(self, d_model, n_heads, band_width=128, n_global=2):
        super().__init__()
        self.band_width = band_width
        self.n_global = n_global
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def _build_mask(self, T, device):
        """Return a (T, T) bool tensor; True where query row may see key column."""
        idx = torch.arange(T, device=device)
        # Band: |i - j| <= band_width. The diagonal is always allowed, so no
        # softmax row ends up fully masked.
        allowed = (idx[:, None] - idx[None, :]).abs() <= self.band_width
        # Global columns: every token attends to the first n_global tokens.
        allowed[:, :self.n_global] = True
        # Global rows: the first n_global tokens attend to every position.
        allowed[:self.n_global, :] = True
        return allowed

    def forward(self, x):
        """
        Apply band + global masked multi-head self-attention.

        Args:
            x: Tensor of shape (B, T, d_model).

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        # Split heads and move the head axis ahead of the sequence axis so the
        # matmul below produces (B, H, T, T) scores rather than (B, T, H, H).
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        scale = self.head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale  # (B, H, T, T)

        # The (T, T) mask broadcasts over batch and heads, so every head gets
        # the same sparsity pattern. Disallowed scores become -inf pre-softmax.
        allowed = self._build_mask(T, x.device)
        attn = attn.masked_fill(~allowed, float('-inf'))
        attn = attn.softmax(dim=-1)

        out = attn @ v  # (B, H, T, head_dim)
        out = out.transpose(1, 2).reshape(B, T, C)  # merge heads
        return self.proj(out)
Block-Sparse Attention for Long Sequences
class BlockSparseAttention(nn.Module):
    """
    Chunk-based local attention: each token attends only to tokens in its
    own block of ``block_size`` consecutive positions.

    Attention-score memory is O(T * block_size) per head instead of O(T²).

    NOTE: cross-block communication via per-block summaries is NOT implemented
    here. ``global_proj`` and ``n_global_blocks`` are kept so the constructor
    signature and parameter set are unchanged, but ``forward`` does not use
    them.

    Args:
        d_model: Model width; expected to be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        block_size: Number of consecutive tokens per attention block.
        n_global_blocks: Reserved for cross-block attention (currently unused).
    """

    def __init__(self, d_model, n_heads, block_size=128, n_global_blocks=2):
        super().__init__()
        self.block_size = block_size
        self.n_global_blocks = n_global_blocks
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        # Reserved for global token aggregation (cross-block communication).
        self.global_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """
        Apply multi-head attention independently inside each block.

        Sequences whose length is not a multiple of ``block_size`` are
        zero-padded internally; padded keys are masked out and padded rows are
        dropped before returning.

        Args:
            x: Tensor of shape (B, T, d_model).

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, C = x.shape
        H, D, bs = self.n_heads, self.head_dim, self.block_size
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        n_blocks = (T + bs - 1) // bs
        pad_T = n_blocks * bs
        if pad_T != T:
            # Pad the sequence axis of q/k/v (not x) so the reshape into
            # whole blocks below is valid for any T.
            pad = (0, 0, 0, pad_T - T)
            q = torch.nn.functional.pad(q, pad)
            k = torch.nn.functional.pad(k, pad)
            v = torch.nn.functional.pad(v, pad)

        # Reshape into blocks and heads: (B, n_blocks, block_size, H, D)
        q = q.reshape(B, n_blocks, bs, H, D)
        k = k.reshape(B, n_blocks, bs, H, D)
        v = v.reshape(B, n_blocks, bs, H, D)

        # Local attention within blocks, per head: (B, n_blocks, H, bs, bs)
        attn = torch.einsum('bnihd,bnjhd->bnhij', q, k) / (D ** 0.5)

        # Padded key slots must not receive attention weight. Every block
        # holds at least one real token, so no row is fully masked.
        key_is_real = (torch.arange(pad_T, device=x.device) < T).view(n_blocks, bs)
        attn = attn.masked_fill(~key_is_real[None, :, None, None, :], float('-inf'))
        attn = attn.softmax(dim=-1)

        local_out = torch.einsum('bnhij,bnjhd->bnihd', attn, v)
        # Merge blocks and heads, then drop the padded tail rows.
        local_out = local_out.reshape(B, pad_T, C)[:, :T]
        return self.proj(local_out)
Stride Patterns for Vision-Language
class StridedAttention(nn.Module):
    """
    Stride-based sparse self-attention over a flattened (1-D) token sequence.

    Token ``i`` attends to token ``j`` when ``|i - j|`` is a multiple of
    ``stride`` (this always includes ``i`` itself). For image-like data
    flattened row-major, setting ``stride`` to the row width makes each token
    attend to the tokens in its own column.

    Args:
        d_model: Model width; expected to be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        stride: Positive distance between attended positions.
    """

    def __init__(self, d_model, n_heads, stride=2):
        super().__init__()
        self.stride = stride
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

    def create_stride_mask(self, T, device=None):
        """
        Create the (T, T) mask for stride-based sparse attention.

        Args:
            T: Sequence length.
            device: Optional device to build the mask on (default: CPU).

        Returns:
            Float tensor with 1.0 where ``|i - j| % stride == 0``, else 0.0.
        """
        idx = torch.arange(T, device=device)
        offsets = (idx[:, None] - idx[None, :]).abs()
        return (offsets % self.stride == 0).to(torch.get_default_dtype())

    def forward(self, x):
        """
        Apply stride-masked multi-head self-attention.

        Args:
            x: Tensor of shape (B, T, d_model).

        Returns:
            Tensor of shape (B, T, d_model).
        """
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        # Split heads and move the head axis ahead of the sequence axis so the
        # scores come out as (B, H, T, T) and line up with the (T, T) mask.
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        # Build the mask directly on the input's device (no host->device copy).
        mask = self.create_stride_mask(T, device=x.device)

        scale = self.head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale  # (B, H, T, T)
        # The diagonal is always allowed, so no softmax row is fully masked.
        attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = attn.softmax(dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B, T, C)  # merge heads
        return self.proj(out)
Failure Mode: Mask Applied Inconsistently Across Heads
# BUG (intentional anti-example): mask not applied consistently across heads.
# This class illustrates a failure mode; do not copy it.
# NOTE(review): no __init__ is shown; forward() reads self.qkv and
# self.block_size, which must be defined elsewhere -- TODO confirm.
class BrokenSparseAttn(nn.Module):
    def forward(self, x):
        # chunk() splits the last dim into three pieces and keeps the rank,
        # so q, k, v have the same number of dims as self.qkv(x).
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # NOTE(review): this unpack needs q to be 4-D. If self.qkv returns a
        # 3-D (B, T, C) tensor, this line raises ValueError -- a head-split
        # step appears to be missing. TODO confirm.
        B, T, H, D = q.shape
        # Intent per the original author: apply mask only to first head -
        # silently incorrect.
        attn = (q @ k.transpose(-2, -1)) / (D ** 0.5)
        # NOTE(review): the index below takes full slices on the first three
        # dims, so it writes -inf for every index of those dims, not just one
        # head. A head-0-only mask would need an integer index on the head
        # axis. As written, the code does not match the "only first head"
        # description -- TODO confirm which behavior the example should show.
        attn[:, :, :, self.block_size:] = float('-inf')
        attn = attn.softmax(dim=-1)
        return (attn @ v).reshape(B, T, -1)
# Stated consequence of the intended bug: other heads see full attention,
# defeating sparsity purpose.
EXERCISE
Implement a "dilated" attention pattern where token i attends to tokens at positions i, i±s, i±2s, i±4s, etc. (doubling gaps). Verify attention patterns using torch.allclose against a reference implementation.