RUNLOCALAIv38

Independently operated catalog for local-AI hardware and software. Hand-written verdicts. Source-cited claims. Reproducible commands when we have them.

OP·Fredoline Eruo
© 2026 runlocalai.coIndependently operated
Custom LLM Architecture Design

14. Custom Attention Patterns

Chapter 14 of 24 · 25 min
KEY INSIGHT

Custom attention patterns trade generality for efficiency. The key is matching the pattern to the structure of the data: band-plus-global sparsity for text with long-range dependencies, stride-based patterns for spatial data, and block-based patterns for very long sequences.

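Full attention scores every pair of tokens, so it costs O(T²) time and memory for sequence length T no matter how the data is structured. Custom patterns skip the pairs a given domain rarely needs, which is what enables domain-specific optimization.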
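
To make the trade-off concrete, the sketch below counts how many of the T × T query-key pairs each pattern keeps. The helper names (`band_pairs`, `stride_pairs`, `block_pairs`) and the sizes are illustrative assumptions, not part of the chapter's classes.

```python
def band_pairs(T, w):
    """Pairs (i, j) kept by a band of half-width w: |i - j| <= w."""
    return sum(min(T, i + w + 1) - max(0, i - w) for i in range(T))


def stride_pairs(T, s):
    """Pairs kept when |i - j| is a multiple of s."""
    return sum(1 for i in range(T) for j in range(T) if abs(i - j) % s == 0)


def block_pairs(T, b):
    """Pairs kept when i and j fall in the same block of size b."""
    full, rest = divmod(T, b)
    return full * b * b + rest * rest


T = 512  # hypothetical sequence length
for name, kept in [
    ("band w=64", band_pairs(T, 64)),
    ("stride s=2", stride_pairs(T, 2)),
    ("block b=64", block_pairs(T, 64)),
]:
    print(f"{name:12s} keeps {kept / (T * T):.1%} of the full score matrix")
```

A pure stride mask keeps roughly 1/s of all pairs, so it lowers the constant but not the O(T²) growth. Band and block patterns with a fixed width grow linearly in T.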

Sparse Attention Patterns

import torch
import torch.nn as nn

class SparseAttention(nn.Module):
    """
    Implement band + global attention pattern.
    Each token attends to:
    - Nearby tokens within a band width
    - Specific global tokens (e.g., CLS, special markers)
    """
    def __init__(self, d_model, n_heads, band_width=128, n_global=2):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.band_width = band_width
        self.n_global = n_global
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Split heads: (B, T, C) -> (B, H, T, D), so the matmul below
        # produces (T, T) scores per head rather than (H, H) per token
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        
        # Create sparse attention mask (True = may attend)
        idx = torch.arange(T, device=x.device)
        
        # Band: attend to positions within band_width on either side
        mask = (idx[:, None] - idx[None, :]).abs() <= self.band_width
        
        # Global: every token sees the first n_global tokens,
        # and those tokens see every position
        mask[:, :self.n_global] = True
        mask[:self.n_global, :] = True
        
        # Compute attention scores: (B, H, T, T)
        scale = self.head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale
        
        # Apply mask (set masked positions to -inf before softmax);
        # the (T, T) mask broadcasts over batch and heads
        attn = attn.masked_fill(~mask, float('-inf'))
        
        attn = attn.softmax(dim=-1)
        out = attn @ v  # (B, H, T, D)
        
        # Merge heads and apply the output projection
        out = out.transpose(1, 2).reshape(B, T, C)
        return self.proj(out)
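
As a cross-check on the band + global pattern, the sketch below builds the mask without a Python loop and compares PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later) against an explicit masked softmax. The helper name `band_global_mask` and the tensor sizes are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def band_global_mask(T, band_width, n_global, device=None):
    """Boolean (T, T) mask; True marks query-key pairs that may attend."""
    idx = torch.arange(T, device=device)
    mask = (idx[:, None] - idx[None, :]).abs() <= band_width  # band
    mask[:, :n_global] = True  # every token sees the global tokens
    mask[:n_global, :] = True  # global tokens see every token
    return mask


B, H, T, D = 2, 4, 16, 8  # hypothetical sizes
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
mask = band_global_mask(T, band_width=2, n_global=1)

# A boolean attn_mask broadcasts over the batch and head dimensions.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Reference: explicit scores, -inf fill, softmax.
scores = (q @ k.transpose(-2, -1)) * D ** -0.5
ref = scores.masked_fill(~mask, float("-inf")).softmax(dim=-1) @ v
print(torch.allclose(out, ref, atol=1e-5))
```

A dense mask still stores T × T entries, so this checks correctness rather than saving memory. Real savings need a kernel that skips the masked pairs.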

Block-Sparse Attention for Long Sequences

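class BlockSparseAttention(nn.Module):
    """
    Chunk-based attention: attend within blocks, cross-block via reduction.
    Score memory per head: O(T*b + (T/b)²) instead of O(T²), b = block_size
    """
    def __init__(self, d_model, n_heads, block_size=128, n_global_blocks=2):
        super().__init__()
        self.block_size = block_size
        self.n_global_blocks = n_global_blocks  # not used by this simplified version
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        
        # Global token aggregation for cross-block communication
        self.global_proj = nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, T, C = x.shape
        H, D, bs = self.n_heads, self.head_dim, self.block_size
        
        # Pad to a multiple of block_size BEFORE projecting, so q/k/v reshape cleanly
        n_blocks = (T + bs - 1) // bs
        pad_T = n_blocks * bs
        x = torch.nn.functional.pad(x, (0, 0, 0, pad_T - T))
        valid = (torch.arange(pad_T, device=x.device) < T).view(n_blocks, bs)
        
        # Reshape into blocks and heads: (B, H, n_blocks, block_size, D)
        q, k, v = (t.view(B, n_blocks, bs, H, D).permute(0, 3, 1, 2, 4)
                   for t in self.qkv(x).chunk(3, dim=-1))
        
        # Local attention within blocks; padded keys are masked out
        attn = torch.einsum('bhnid,bhnjd->bhnij', q, k) / (D ** 0.5)
        attn = attn.masked_fill(~valid[None, None, :, None, :], float('-inf'))
        attn = attn.softmax(dim=-1)
        local_out = torch.einsum('bhnij,bhnjd->bhnid', attn, v)
        
        # Global summary per block (mean over real tokens): (B, H, n_blocks, D)
        w = valid[None, None, :, :, None].to(x.dtype)
        count = w.sum(dim=3)
        global_q = (q * w).sum(dim=3) / count  # pooled query: an assumption of this sketch
        global_k = (k * w).sum(dim=3) / count
        global_v = (v * w).sum(dim=3) / count
        
        # Cross-block attention via global summaries
        # (Implementation simplified - production code needs iterative refinement)
        g_attn = (global_q @ global_k.transpose(-2, -1)) / (D ** 0.5)
        global_out = g_attn.softmax(dim=-1) @ global_v  # (B, H, n_blocks, D)
        
        # Merge heads, then add each block's global context to all of its tokens
        local_out = local_out.permute(0, 2, 3, 1, 4).reshape(B, n_blocks, bs, C)
        global_out = global_out.permute(0, 2, 1, 3).reshape(B, n_blocks, 1, C)
        out = local_out + self.global_proj(global_out)
        
        # Drop the padding and apply the output projection
        return self.proj(out.reshape(B, pad_T, C)[:, :T])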
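
The memory claim in the docstring can be checked with plain arithmetic. The sketch below counts attention-score entries per head, assuming one score matrix per block for local attention plus one matrix over the mean-pooled block summaries. The function name `score_entries` is hypothetical.

```python
def score_entries(T, block_size):
    """Score entries per head: (full attention, block-local + block-summary)."""
    n_blocks = -(-T // block_size)      # ceiling division
    full = T * T
    local = n_blocks * block_size ** 2  # one (b x b) matrix per block
    cross = n_blocks ** 2               # block summaries attending to each other
    return full, local + cross


for T in (4_096, 32_768):
    full, sparse = score_entries(T, block_size=128)
    print(f"T={T}: full={full:,}  block={sparse:,}  ratio={full / sparse:.1f}x")
```

The local term grows linearly in T for a fixed block size, so at large T the summary term (T/b)² eventually dominates. That is one reason to increase the block size as sequences get longer.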

Stride Patterns for Vision-Language

class StridedAttention(nn.Module):
    """
    Stride-based attention pattern useful for image-like 2D data.
    In 2D, the token at (i,j) would attend to (i±s, j±s) where s is stride.
    This version applies the pattern to the flattened sequence:
    position i attends to every j where |i - j| is a multiple of s.
    """
    def __init__(self, d_model, n_heads, stride=2):
        super().__init__()
        self.stride = stride
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
    
    def create_stride_mask(self, T, device=None):
        """Create boolean mask for stride-based sparse attention (True = may attend)"""
        idx = torch.arange(T, device=device)
        # Stride pattern: |i - j| is multiple of stride (vectorized, no Python loops)
        return (idx[:, None] - idx[None, :]).abs() % self.stride == 0
    
    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Split heads: (B, T, C) -> (B, H, T, D)
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        
        mask = self.create_stride_mask(T, device=x.device)
        
        scale = self.head_dim ** -0.5
        attn = (q @ k.transpose(-2, -1)) * scale  # (B, H, T, T)
        attn = attn.masked_fill(~mask, float('-inf'))
        attn = attn.softmax(dim=-1)
        
        # Merge heads and apply the output projection
        out = (attn @ v).transpose(1, 2).reshape(B, T, C)
        return self.proj(out)
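
The class above strides over the flattened sequence. For grid-shaped inputs, a mask that respects rows and columns may be closer to the 2D pattern in the docstring. The sketch below is one possible reading, assuming row-major flattening and that token (r, c) attends to (r2, c2) when both offsets are multiples of the stride. The function name `stride_mask_2d` is hypothetical.

```python
import torch


def stride_mask_2d(height, width, stride):
    """Boolean (H*W, H*W) mask for row-major flattened grid tokens.

    Assumption: (r, c) may attend to (r2, c2) when |r - r2| and |c - c2|
    are both multiples of `stride`.
    """
    rows = torch.arange(height).repeat_interleave(width)  # row index of each token
    cols = torch.arange(width).repeat(height)             # column index of each token
    dr = (rows[:, None] - rows[None, :]).abs()
    dc = (cols[:, None] - cols[None, :]).abs()
    return (dr % stride == 0) & (dc % stride == 0)


mask = stride_mask_2d(height=4, width=4, stride=2)
# Grid cells that token (0, 0) attends to, shown as a 4 x 4 grid of 0/1.
print(mask[0].view(4, 4).int())
```

The result is a `(T, T)` boolean mask with `T = height * width`, so it can be passed to `masked_fill(~mask, float('-inf'))` in the same way as the 1D mask.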

Failure Mode: Mask Applied Inconsistently Across Heads

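# BUG: Mask not applied consistently across heads
class BrokenSparseAttn(nn.Module):
    # __init__ (not shown) defines self.qkv, self.n_heads, self.block_size
    def forward(self, x):
        B, T, C = x.shape
        H, D = self.n_heads, C // self.n_heads
        # Split heads: (B, T, C) -> (B, H, T, D)
        q, k, v = (t.view(B, T, H, D).transpose(1, 2)
                   for t in self.qkv(x).chunk(3, dim=-1))
        
        # Apply mask only to first head - silently incorrect
        attn = (q @ k.transpose(-2, -1)) / (D ** 0.5)  # (B, H, T, T)
        attn[:, 0, :, self.block_size:] = float('-inf')  # Only first head
        # Fix: index every head (attn[:, :, :, self.block_size:]) or use
        # masked_fill with a (T, T) mask that broadcasts over heads
        
        attn = attn.softmax(dim=-1)
        return (attn @ v).transpose(1, 2).reshape(B, T, -1)
        # Other heads see full attention, defeating sparsity purpose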
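
This kind of bug raises no error, so it helps to test for it directly. The sketch below measures how much attention weight each head places on positions the mask should forbid. A correct implementation gives exactly zero for every head. The helper name `masked_weight_per_head` and the tensor sizes are assumptions made for illustration.

```python
import torch


def masked_weight_per_head(attn, mask):
    """Largest attention weight each head places on disallowed positions.

    attn: (B, H, T, T) post-softmax weights; mask: (T, T) bool, True = allowed.
    """
    leaked = attn.masked_fill(mask, 0.0)  # keep only the disallowed entries
    return leaked.amax(dim=(0, 2, 3))     # one value per head


B, H, T, block = 2, 4, 8, 4  # hypothetical sizes
mask = torch.zeros(T, T, dtype=torch.bool)
mask[:, :block] = True       # only the first `block` keys are allowed
scores = torch.randn(B, H, T, T)

good = scores.masked_fill(~mask, float("-inf")).softmax(dim=-1)

bad = scores.clone()
bad[:, 0, :, block:] = float("-inf")  # the bug: only head 0 is masked
bad = bad.softmax(dim=-1)

print(masked_weight_per_head(good, mask))  # zero for every head
print(masked_weight_per_head(bad, mask))   # zero for head 0 only
```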
EXERCISE

Implement a "dilated" attention pattern where token i attends to tokens at positions i, i±s, i±2s, i±4s, etc. (doubling gaps). Verify attention patterns using torch.allclose against a reference implementation.
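
For the verification step, a deliberately slow loop-based mask can serve as the reference to compare against. The sketch below is one possible reading of the pattern, assuming offsets of s · 2^k for k = 0, 1, 2, ... and s ≥ 1. The function name `dilated_mask_reference` is hypothetical.

```python
import torch


def dilated_mask_reference(T, s):
    """Loop-based reference: token i attends to i and i ± s * 2**k (requires s >= 1)."""
    mask = torch.zeros(T, T, dtype=torch.bool)
    for i in range(T):
        mask[i, i] = True
        gap = s
        while gap < T:
            if i - gap >= 0:
                mask[i, i - gap] = True
            if i + gap < T:
                mask[i, i + gap] = True
            gap *= 2  # doubling gaps: s, 2s, 4s, ...
    return mask


ref = dilated_mask_reference(T=16, s=1)
print(ref[0].nonzero().flatten().tolist())  # [0, 1, 2, 4, 8]
```

Boolean masks can be compared exactly with `torch.equal`. Use `torch.allclose` for the floating-point attention outputs computed with each mask.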
