16. Sliding Window Attention
Chapter 16 of 24 · 25 min
Sliding window attention restricts each token's attention span to a local window. This matches the observation that most linguistic dependencies are local, and it cuts the cost of attention from O(T²) to O(T · window_size), which makes it efficient for long documents.
Standard Sliding Window Implementation
class SlidingWindowAttention(nn.Module):
    """Causal sliding-window self-attention with a configurable window size.

    The query at position ``i`` attends to itself and to at most
    ``window_size`` preceding tokens, i.e. to key positions
    ``[max(0, i - window_size), i]``.

    Queries are processed in blocks of ``window_size`` positions and each
    block is scored only against the keys it can reach, so time and memory
    are O(window_size * T) instead of O(T^2).

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        window_size: How many preceding tokens each query may attend to.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads`` or
            ``window_size`` is smaller than 1.
    """

    def __init__(self, d_model, n_heads, window_size=512):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        if window_size < 1:
            raise ValueError("window_size must be >= 1")
        self.window_size = window_size
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.o_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """Apply causal sliding-window attention.

        Args:
            x: Input tensor of shape ``(B, T, d_model)``.
            mask: Accepted for interface compatibility; currently ignored.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim)
        k = k.view(B, T, self.n_heads, self.head_dim)
        v = v.view(B, T, self.n_heads, self.head_dim)

        scale = self.head_dim ** -0.5
        output = torch.zeros_like(q)

        # Process queries block by block to avoid O(T^2) memory.
        for start in range(0, T, self.window_size):
            end = min(start + self.window_size, T)
            q_chunk = q[:, start:end]  # (B, w, H, D)

            # Causal attention never looks past the block's last query, and
            # its first query reaches back at most window_size positions.
            kv_start = max(0, start - self.window_size)
            k_window = k[:, kv_start:end]
            v_window = v[:, kv_start:end]

            attn = torch.einsum('bqhd,bkhd->bhqk', q_chunk, k_window) * scale

            # Build the mask from ABSOLUTE positions: the key slice starts at
            # kv_start, not at start, so a plain tril over the local block
            # would be shifted and would mask the wrong keys.
            q_pos = torch.arange(start, end, device=x.device).unsqueeze(1)
            k_pos = torch.arange(kv_start, end, device=x.device).unsqueeze(0)
            distance = q_pos - k_pos  # > 0 means the key lies in the past
            allowed = (distance >= 0) & (distance <= self.window_size)

            # Every query can see itself, so no row is fully masked (no NaNs).
            attn = attn.masked_fill(~allowed, float('-inf'))
            attn = F.softmax(attn, dim=-1)

            chunk_out = torch.einsum('bhqk,bkhd->bqhd', attn, v_window)
            output[:, start:end] = chunk_out

        # Merge the heads and mix them with the output projection.
        return self.o_proj(output.reshape(B, T, C))
Dilated Sliding Window (Longformer-style)
class DilatedSlidingWindowAttention(nn.Module):
    """Causal dilated sliding-window self-attention.

    Like a dilated convolution, but for attention: the query at position
    ``i`` attends to key positions ``i, i - dilation, i - 2 * dilation, ...,
    i - window_size * dilation``. Positions before the start of the sequence
    are masked out. With ``dilation=1`` this is a plain causal sliding window.

    Stacking layers with growing dilation widens the receptive field while
    each layer still costs only O(T * window_size):

        Layer 0: window_size=16, dilation=1
        Layer 1: window_size=16, dilation=2 (every 2nd position)
        Layer 2: window_size=16, dilation=4
        Layer 3: window_size=16, dilation=8

    Each layer reaches back ``window_size * dilation`` positions, so these
    four layers together reach back 16 * (1 + 2 + 4 + 8) = 240 positions.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        window_size: Number of dilated steps each query looks back.
        dilation: Distance between consecutive attended positions.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``, or if
            ``window_size`` or ``dilation`` is smaller than 1.
    """

    def __init__(self, d_model, n_heads, window_size=16, dilation=1):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        if window_size < 1:
            raise ValueError("window_size must be >= 1")
        if dilation < 1:
            raise ValueError("dilation must be >= 1")
        self.window_size = window_size
        self.dilation = dilation
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.o_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply causal dilated sliding-window attention.

        Args:
            x: Input tensor of shape ``(B, T, d_model)``.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim)
        k = k.view(B, T, self.n_heads, self.head_dim)
        v = v.view(B, T, self.n_heads, self.head_dim)

        output = torch.zeros_like(q)
        scale = self.head_dim ** -0.5

        # Look-back distances shared by every query: 0, d, 2d, ..., w*d.
        offsets = torch.arange(self.window_size + 1, device=x.device) * self.dilation

        # Process queries in blocks so the gathered keys stay small.
        for start in range(0, T, self.window_size):
            end = min(start + self.window_size, T)
            q_chunk = q[:, start:end]  # (B, w, H, D)

            # Per-query key positions, relative to each query's own position.
            q_pos = torch.arange(start, end, device=x.device)
            indices = q_pos.unsqueeze(1) - offsets.unsqueeze(0)  # (w, window_size + 1)

            # Positions before the sequence start do not exist: remember them
            # for masking, then clamp so the gather itself stays in range.
            valid = indices >= 0
            indices = indices.clamp(min=0)

            k_window = k[:, indices]  # (B, w, window_size + 1, H, D)
            v_window = v[:, indices]

            attn = torch.einsum('bqhd,bqjhd->bhqj', q_chunk, k_window) * scale
            # Offset 0 (the query itself) is always valid, so no row is
            # fully masked and softmax cannot produce NaNs.
            attn = attn.masked_fill(~valid, float('-inf'))
            attn = F.softmax(attn, dim=-1)

            chunk_out = torch.einsum('bhqj,bqjhd->bqhd', attn, v_window)
            output[:, start:end] = chunk_out

        # Merge the heads and mix them with the output projection.
        return self.o_proj(output.reshape(B, T, C))
Flash Attention with Sliding Window
from flash_attn.flash_attn_interface import flash_attn_func
class FlashSlidingWindowAttention(nn.Module):
    """Causal sliding-window self-attention backed by FlashAttention-2.

    ``flash_attn_func`` supports local attention natively through its
    ``window_size=(left, right)`` argument, so no explicit mask is built.

    NOTE: per the flash-attn documentation the kernels expect CUDA tensors
    in fp16/bf16; this module does not cast its input.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        window_size: How many preceding tokens each query may attend to.
    """

    def __init__(self, d_model, n_heads, window_size=512):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        assert d_model == n_heads * self.head_dim, "d_model must be divisible by n_heads"
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.o_proj = nn.Linear(d_model, d_model)
        self.window_size = window_size

    def forward(self, x, cu_seqlens=None):
        """Apply causal sliding-window attention.

        Args:
            x: Input tensor of shape ``(B, T, d_model)``.
            cu_seqlens: Sequence boundaries for variable-length (packed)
                batches. Currently ignored: ``flash_attn_func`` has no such
                argument; packed batches need ``flash_attn_varlen_func``.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.split(C, dim=-1)

        # flash_attn_func expects (batch, seqlen, n_heads, head_dim). Do NOT
        # transpose to (B, H, T, D) as for torch SDPA, or the kernel would
        # treat heads as sequence positions.
        q = q.view(B, T, self.n_heads, self.head_dim)
        k = k.view(B, T, self.n_heads, self.head_dim)
        v = v.view(B, T, self.n_heads, self.head_dim)

        # Causal attention: look back window_size tokens, never ahead.
        out = flash_attn_func(
            q, k, v,
            causal=True,
            window_size=(self.window_size, 0),
        )

        # out is (B, T, n_heads, head_dim); merge heads, then project.
        return self.o_proj(out.reshape(B, T, C))
Failure Mode: Incorrect Window Boundary
# BUG: Window extends beyond sequence or goes negative
def compute_sliding_attn_broken(q, k, v, window_size):
    """Deliberately broken example: the causal window start is not clamped at 0.

    Only the boundary bug is shown; the attention computation itself is
    elided, so the function returns nothing as written.

    Args:
        q: Query tensor; its shape is unpacked as (B, T, H, D).
        k: Key tensor, sliced along dim 1 (presumably the same layout as
            ``q`` - TODO confirm).
        v: Value tensor; not used in the visible part of the example.
        window_size: Number of previous positions each query should see.
    """
    B, T, H, D = q.shape
    scale = D ** -0.5  # 1/sqrt(D) score scaling; not used in the visible part
    for i in range(T):
        # WRONG: window_start is negative whenever i < window_size
        window_start = i - window_size  # e.g. -512 for i=0 with window_size=512
        # Should be: window_start = max(0, i - window_size)
        # Otherwise the slice below does not cover positions [0, i]
        window_end = i + 1  # Causal: only look back, up to and including i
        q_i = q[:, i:i+1]  # (B, 1, H, D)
        k_window = k[:, window_start:window_end]
        # A negative slice start counts from the END of the sequence, so this
        # slice can silently come out empty or truncated (dropping the
        # earliest tokens) instead of raising an error.
        # ... (rest of the attention computation elided)
EXERCISE
Implement a multi-layer sliding window attention stack where each layer has doubling dilation (16, 32, 64, 128, ...). Visualize the effective attention pattern for a 2048-token sequence to confirm global coverage.