RUNLOCALAI · v38

Independently operated catalog for local-AI hardware and software. Hand-written verdicts. Source-cited claims. Reproducible commands when we have them.

OP·Fredoline Eruo
© 2026 runlocalai.co · Independently operated
Custom LLM Architecture Design

04. FlashAttention Implementation

Chapter 4 of 24 · 15 min
KEY INSIGHT

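FlashAttention solves attention's memory bottleneck by computing in tiles that fit in fast on-chip GPU SRAM. This reduces the extra memory for attention from O(N²) to O(N) in sequence length N. The backward pass recomputes tiles instead of storing them, which costs additional FLOPs, but because standard attention is limited by memory bandwidth rather than arithmetic, the tiled version is usually faster in wall-clock time as well as smaller.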
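To make the O(N²) versus O(N) gap concrete, here is a small arithmetic sketch. The batch size of 1, 8 heads, fp16 scores, and fp32 per-row statistics are illustrative assumptions, and both function names are hypothetical helpers for this example only.

```python
def attention_scores_bytes(batch, n_heads, seq_len, bytes_per_element=2):
    """Bytes needed to materialize the full (seq_len x seq_len) score matrix."""
    return batch * n_heads * seq_len * seq_len * bytes_per_element


def tiled_state_bytes(batch, n_heads, seq_len, bytes_per_element=4):
    """Bytes for a running max and a running sum per query row (two values per row)."""
    return batch * n_heads * seq_len * 2 * bytes_per_element


for n in (512, 1024, 2048, 4096):
    full = attention_scores_bytes(1, 8, n)
    tiled = tiled_state_bytes(1, 8, n)
    print(f"N={n:5d}  full scores: {full / 2**20:7.1f} MiB   row statistics: {tiled / 2**10:6.1f} KiB")
```

Under these assumptions, N = 4096 needs 256 MiB for the score matrix but only 256 KiB of per-row statistics. Doubling N quadruples the first number and doubles the second.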

FlashAttention computes exact attention in O(N²) time but O(N) memory. It processes Q, K, and V in tiles that fit in GPU SRAM and keeps a running softmax maximum and normalizer per query row, so the full N×N attention matrix is never materialized in GPU main memory (HBM).
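The tiling idea can be sketched in plain PyTorch. This is a teaching sketch, not the fused kernel: it tiles only the keys and values (a real kernel also tiles the queries and keeps each tile in SRAM), and the names `naive_attention` and `tiled_attention` are hypothetical. It shows how a running maximum and running sum let softmax be computed one block at a time.

```python
import torch


def naive_attention(q, k, v):
    """Reference: materializes the full (N x N) score matrix."""
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    return torch.softmax(scores, dim=-1) @ v


def tiled_attention(q, k, v, block_size=64):
    """Same result, but only one (N x block_size) slice of scores exists at a time."""
    scale = q.shape[-1] ** -0.5
    out = torch.zeros_like(q)
    row_max = q.new_full((*q.shape[:-1], 1), float("-inf"))  # running max per query row
    row_sum = q.new_zeros((*q.shape[:-1], 1))                # running softmax denominator

    for start in range(0, k.shape[-2], block_size):
        k_blk = k[..., start:start + block_size, :]
        v_blk = v[..., start:start + block_size, :]
        scores = (q @ k_blk.transpose(-2, -1)) * scale

        new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
        correction = torch.exp(row_max - new_max)  # rescales results from earlier blocks
        p = torch.exp(scores - new_max)

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ v_blk
        row_max = new_max

    return out / row_sum


torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 200, 32) for _ in range(3))
diff = (naive_attention(q, k, v) - tiled_attention(q, k, v)).abs().max()
print("max abs difference:", diff.item())
```

The sequence length of 200 is deliberately not a multiple of the block size, so the last block is shorter than the others. Slicing handles that here; a hand-written kernel needs an explicit boundary check.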

import torch
import torch.nn as nn
import torch.nn.functional as F

# PyTorch 2.0+ exposes fused attention, including a FlashAttention kernel,
# through the function F.scaled_dot_product_attention (SDPA).
class FlashAttentionWrapper(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.dropout = dropout

        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None, is_causal=True):
        batch_size, seq_len, d_model = x.shape

        # Project, then reshape to (batch, heads, seq_len, d_k)
        Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)

        # SDPA uses FlashAttention when available and scales by d_k ** -0.5 by default.
        # It rejects attn_mask combined with is_causal=True, so an explicit mask wins.
        output = F.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=mask,
            dropout_p=self.dropout if self.training else 0.0,  # no dropout in eval mode
            is_causal=is_causal and mask is None,
        )

        # Merge heads back to (batch, seq_len, d_model)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
        return self.W_o(output)

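Failure mode: Incompatible CUDA version. FlashAttention requires CUDA 11.6+ for the original implementation, or PyTorch 2.0+ with compatible CUDA and cuDNN versions. On older GPUs, mismatched CUDA versions, or unsupported inputs (the flash kernel expects float16 or bfloat16 tensors on a CUDA device), PyTorch silently falls back to another backend. If that backend is the plain math implementation, the full N×N matrix is materialized and the memory savings are lost.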
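When diagnosing this failure mode, it helps to record the environment first. The sketch below gathers the relevant facts in one dictionary; `sdpa_environment_report` is a hypothetical helper name. Note that `torch.backends.cuda.flash_sdp_enabled()` only reports whether the flash backend is permitted, not whether it will actually be selected for a given input.

```python
import torch


def sdpa_environment_report():
    """Collect the facts that usually decide whether the flash kernel can run."""
    report = {
        "torch_version": torch.__version__,
        "cuda_build": torch.version.cuda,  # None on CPU-only builds
        "cuda_available": torch.cuda.is_available(),
        "flash_sdp_enabled": torch.backends.cuda.flash_sdp_enabled(),
        "device_name": None,
        "compute_capability": None,
    }
    if report["cuda_available"]:
        report["device_name"] = torch.cuda.get_device_name(0)
        report["compute_capability"] = torch.cuda.get_device_capability(0)
    return report


for key, value in sdpa_environment_report().items():
    print(f"{key:>20}: {value}")
```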

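Failure mode: Block size mismatches. This applies when writing or tuning a custom kernel; PyTorch's built-in kernels choose their tiling internally. The tile size (typically 64 or 128) must match the thread block layout the kernel was written for, and sequence lengths that are not a multiple of the tile size need explicit boundary masking. A missing boundary check causes subtle correctness issues; a poorly chosen tile size causes performance degradation.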
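The correctness side of this failure mode can be reproduced without writing a kernel. The sketch below pads keys and values up to a multiple of an assumed tile size of 64, then compares attention with and without a mask over the padded positions. It uses SDPA as a stand-in for a tiled kernel, and `pad_keys_to_tile` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def pad_keys_to_tile(k, v, tile=64):
    """Pad K and V along the sequence axis to a multiple of `tile`; also return a validity mask."""
    seq_len = k.shape[-2]
    pad = (-seq_len) % tile
    # F.pad works from the last dimension backwards: (d_left, d_right, seq_left, seq_right)
    k_pad = F.pad(k, (0, 0, 0, pad))
    v_pad = F.pad(v, (0, 0, 0, pad))
    valid = torch.arange(seq_len + pad, device=k.device) < seq_len  # True = real key
    return k_pad, v_pad, valid.view(1, 1, 1, -1)


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 100, 16) for _ in range(3))
k_pad, v_pad, valid = pad_keys_to_tile(k, v, tile=64)

reference = F.scaled_dot_product_attention(q, k, v)
masked = F.scaled_dot_product_attention(q, k_pad, v_pad, attn_mask=valid)
unmasked = F.scaled_dot_product_attention(q, k_pad, v_pad)

print("padded length:", k_pad.shape[-2])
print("masked matches reference:", torch.allclose(reference, masked, atol=1e-5))
print("unmasked matches reference:", torch.allclose(reference, unmasked, atol=1e-5))
```

The padded keys are all zeros, so each one still receives a softmax weight of exp(0) unless it is masked out. That dilutes every output row without raising an error, which is why this class of bug is easy to miss.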

Verification: Check if FlashAttention is active:

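import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel  # available in PyTorch 2.3+

def check_flash_attention_available():
    if not torch.cuda.is_available():
        return False
    # Shape is (batch, heads, seq_len, head_dim); the flash kernel needs fp16 or bf16
    q = torch.randn(2, 8, 128, 64, device='cuda', dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    try:
        # Restrict SDPA to the flash backend; it raises if that kernel cannot run
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
            F.scaled_dot_product_attention(q, k, v)
        return True
    except RuntimeError:
        return False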
EXERCISE

Benchmark naive attention vs. FlashAttention for sequence lengths 512, 1024, 2048, and 4096. Measure both peak memory usage and inference time. Verify correctness by comparing outputs.
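As a starting point for the measurements, here is a minimal timing and peak-memory helper. It is a sketch only: `measure` is a hypothetical name, the warmup and iteration counts are arbitrary, and the reported peak includes tensors that were already allocated before the call, such as the inputs.

```python
import time

import torch


def measure(fn, *args, warmup=2, iters=5):
    """Return (mean seconds per call, peak bytes allocated on CUDA or None on CPU)."""
    on_cuda = any(isinstance(a, torch.Tensor) and a.is_cuda for a in args)
    for _ in range(warmup):
        fn(*args)
    if on_cuda:
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if on_cuda:
        torch.cuda.synchronize()  # kernels launch asynchronously; wait before stopping the clock
    elapsed = (time.perf_counter() - start) / iters
    peak = torch.cuda.max_memory_allocated() if on_cuda else None
    return elapsed, peak


x = torch.randn(256, 256)
seconds, peak = measure(torch.matmul, x, x)
print(f"{seconds * 1e6:.1f} us per call, peak memory: {peak}")
```

Pass the naive and fused attention functions in place of `torch.matmul`, with CUDA tensors as arguments, to get numbers that can be compared across sequence lengths.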
