16. Sliding Window Attention

Chapter 16 of 24 · 25 min

Sliding window attention restricts each token's attention span to a local window. This matches the observation that most linguistic dependencies are local, and it reduces the cost of attention from quadratic to linear in the sequence length, which makes it efficient for long documents.

Standard Sliding Window Implementation

class SlidingWindowAttention(nn.Module):
    """
    Causal sliding window self-attention with a configurable window size.

    Each token attends to itself and to at most ``window_size`` tokens
    before it; later tokens are never visible.
    Complexity: O(window_size * T) instead of O(T²)

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        window_size: Number of previous tokens each token may attend to.
            Also used as the query chunk length in ``forward``.

    Raises:
        ValueError: If ``window_size`` is not positive or ``d_model`` is not
            divisible by ``n_heads``.
    """
    def __init__(self, d_model, n_heads, window_size=512):
        super().__init__()
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.window_size = window_size
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.o_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        """
        Apply causal sliding window attention.

        Args:
            x: Input of shape (B, T, d_model).
            mask: Accepted for interface compatibility; currently ignored.

        Returns:
            Tensor of shape (B, T, d_model) after the output projection.
        """
        B, T, C = x.shape

        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim)
        k = k.view(B, T, self.n_heads, self.head_dim)
        v = v.view(B, T, self.n_heads, self.head_dim)

        scale = self.head_dim ** -0.5
        positions = torch.arange(T, device=x.device)

        # Process queries in chunks of window_size to avoid O(T²) memory.
        chunk_outputs = []
        for start in range(0, T, self.window_size):
            end = min(start + self.window_size, T)

            # The earliest key any query in this chunk can see is
            # start - window_size; causality means no key at or after `end`
            # is ever needed.
            kv_start = max(0, start - self.window_size)

            q_chunk = q[:, start:end]        # (B, q_len, H, D)
            k_window = k[:, kv_start:end]    # (B, kv_len, H, D)
            v_window = v[:, kv_start:end]

            attn = torch.einsum('bqhd,bkhd->bhqk', q_chunk, k_window) * scale

            # Build the mask from ABSOLUTE positions: the key slice starts at
            # kv_start, not at `start`, so a plain tril on local indices
            # would be misaligned for every chunk after the first.
            distance = positions[start:end].unsqueeze(1) - positions[kv_start:end].unsqueeze(0)
            allowed = (distance >= 0) & (distance <= self.window_size)
            # Every query can always see itself (distance 0), so no row is
            # fully masked and softmax never produces NaN.
            attn = attn.masked_fill(~allowed, float('-inf'))

            attn = F.softmax(attn, dim=-1)

            chunk_outputs.append(torch.einsum('bhqk,bkhd->bqhd', attn, v_window))

        # An empty sequence yields no chunks; fall back to an empty tensor.
        output = torch.cat(chunk_outputs, dim=1) if chunk_outputs else torch.zeros_like(q)

        return self.o_proj(output.reshape(B, T, C))

Dilated Sliding Window (Longformer-style)

class DilatedSlidingWindowAttention(nn.Module):
    """
    Causal dilated sliding window - like dilated convolution but for attention.

    A token at position i attends to positions i, i - dilation,
    i - 2*dilation, ..., i - window_size*dilation (those that are >= 0).
    With dilation=1 this is ordinary causal sliding window attention.

    Stacking layers with growing dilation widens the receptive field:
    Layer 0: window_size=16, dilation=1
    Layer 1: window_size=16, dilation=2 (looks back 0, 2, 4, 6, ...)
    Layer 2: window_size=16, dilation=4
    Layer 3: window_size=16, dilation=8
    Result: a look-back of 16 * (1 + 2 + 4 + 8) = 240 tokens with only
    O(T * window_size) compute per layer.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        window_size: Number of dilated look-back steps per token.
        dilation: Spacing between attended positions.

    Raises:
        ValueError: If ``window_size`` or ``dilation`` is not positive, or
            ``d_model`` is not divisible by ``n_heads``.
    """
    def __init__(self, d_model, n_heads, window_size=16, dilation=1):
        super().__init__()
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        if dilation < 1:
            raise ValueError(f"dilation must be >= 1, got {dilation}")
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.window_size = window_size
        self.dilation = dilation
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.o_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """
        Apply causal dilated sliding window attention.

        Args:
            x: Input of shape (B, T, d_model).

        Returns:
            Tensor of shape (B, T, d_model) after the output projection.
        """
        B, T, C = x.shape

        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.view(B, T, self.n_heads, self.head_dim)
        k = k.view(B, T, self.n_heads, self.head_dim)
        v = v.view(B, T, self.n_heads, self.head_dim)

        scale = self.head_dim ** -0.5

        # Look-back distance of each window slot: 0, d, 2d, ..., window_size*d.
        offsets = torch.arange(self.window_size + 1, device=x.device) * self.dilation

        # Queries are processed in chunks so the gathered key/value tensors
        # stay at O(window_size²) entries per chunk.
        chunk_outputs = []
        for start in range(0, T, self.window_size):
            end = min(start + self.window_size, T)

            q_chunk = q[:, start:end]  # (B, q_len, H, D)

            # Anchor the dilated grid to each query individually so the
            # relative pattern is identical for every position.
            q_pos = torch.arange(start, end, device=x.device)
            key_idx = q_pos.unsqueeze(1) - offsets.unsqueeze(0)  # (q_len, W+1)

            # Slots that fall before the sequence start are masked out; the
            # clamp only keeps the gather in bounds for those masked slots.
            valid = key_idx >= 0
            key_idx = key_idx.clamp(min=0)

            k_window = k[:, key_idx]  # (B, q_len, W+1, H, D)
            v_window = v[:, key_idx]

            attn = torch.einsum('bqhd,bqwhd->bhqw', q_chunk, k_window) * scale
            # Offset 0 (the token itself) is always valid, so no row is
            # fully masked.
            attn = attn.masked_fill(~valid, float('-inf'))
            attn = F.softmax(attn, dim=-1)

            chunk_outputs.append(torch.einsum('bhqw,bqwhd->bqhd', attn, v_window))

        # An empty sequence yields no chunks; fall back to an empty tensor.
        output = torch.cat(chunk_outputs, dim=1) if chunk_outputs else torch.zeros_like(q)

        return self.o_proj(output.reshape(B, T, C))

Flash Attention with Sliding Window

from flash_attn.flash_attn_interface import flash_attn_func

class FlashSlidingWindowAttention(nn.Module):
    """
    Causal sliding window attention backed by FlashAttention 2.

    FlashAttention supports local attention natively through the
    ``window_size=(left, right)`` argument of ``flash_attn_func``, so no
    explicit mask tensor is built.

    NOTE(review): per the flash-attn documentation the kernels require CUDA
    tensors in fp16/bf16 — confirm the caller runs under autocast or casts
    the inputs accordingly.

    Args:
        d_model: Model width; must equal ``n_heads * head_dim``.
        n_heads: Number of attention heads.
        window_size: Number of previous tokens each token may attend to.
    """
    def __init__(self, d_model, n_heads, window_size=512):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        assert d_model == n_heads * self.head_dim, "d_model must be divisible by n_heads"

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.o_proj = nn.Linear(d_model, d_model)

        self.window_size = window_size

    def forward(self, x, cu_seqlens=None):
        """
        Apply FlashAttention with a causal sliding window.

        Args:
            x: Input of shape (B, T, d_model).
            cu_seqlens: Accepted for interface compatibility; currently
                unused. ``flash_attn_func`` works on padded batches and does
                not take sequence boundaries.

        Returns:
            Tensor of shape (B, T, d_model) after the output projection.
        """
        B, T, C = x.shape

        qkv = self.qkv(x)
        q, k, v = qkv.split(C, dim=-1)

        # flash_attn_func expects (batch, seqlen, n_heads, head_dim); do NOT
        # transpose to (batch, n_heads, seqlen, head_dim) or the sequence and
        # head axes are swapped inside the kernel.
        q = q.view(B, T, self.n_heads, self.head_dim)
        k = k.view(B, T, self.n_heads, self.head_dim)
        v = v.view(B, T, self.n_heads, self.head_dim)

        # Causal attention never looks right, so the right-hand window is 0.
        out = flash_attn_func(
            q, k, v,
            causal=True,
            window_size=(self.window_size, 0)
        )

        # The output keeps the (batch, seqlen, n_heads, head_dim) layout, so
        # merging the last two axes restores (B, T, d_model).
        return self.o_proj(out.reshape(B, T, C))

Failure Mode: Incorrect Window Boundary

# BUG: Window extends beyond sequence or goes negative
def compute_sliding_attn_broken(q, k, v, window_size):
    """
    Deliberately broken example of per-token causal sliding window slicing.

    Illustrates what goes wrong when the window start is not clamped to 0.
    The rest of the computation is elided in this snippet.

    Args:
        q, k, v: Tensors whose shape unpacks as (B, T, H, D).
        window_size: Number of previous tokens each token should see.
    """
    B, T, H, D = q.shape
    scale = D ** -0.5
    
    for i in range(T):
        # WRONG: window_start is negative whenever i < window_size
        window_start = i - window_size  # Can be -512 for i=0
        
        # Should be: window_start = max(0, i - window_size)
        # A negative slice start is counted from the END of the sequence
        # (it means T + window_start), so the wrong tokens are selected
        
        window_end = i + 1  # Causal, only look back
        
        q_i = q[:, i:i+1]  # (B, 1, H, D)
        k_window = k[:, window_start:window_end]
        # With a negative start this slice begins near the end of the
        # sequence: it is empty when T + window_start >= window_end, and
        # otherwise silently drops the earliest tokens of the window!
        
        # ...
EXERCISE

Implement a multi-layer sliding window attention stack where each layer has doubling dilation (16, 32, 64, 128, ...). Visualize the effective attention pattern for a 2048-token sequence to confirm global coverage.