04. FlashAttention Implementation
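FlashAttention computes exact attention in O(N²) time but only O(N) extra memory. It processes Q, K and V in tiles small enough to fit in on-chip GPU SRAM and combines the per-tile results with an online (running) softmax. As a result, the full N×N attention matrix is never materialized in GPU main memory (HBM).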
import torch
import torch.nn as nn
import torch.nn.functional as F

# PyTorch 2.0+ has FlashAttention built in, exposed through
# F.scaled_dot_product_attention (SDPA). torch.nn.attention does NOT export a
# `ScaledDotProductAttention` class, so this legacy import is guarded instead
# of crashing the whole module with ImportError.
try:
    from torch.nn.attention import ScaledDotProductAttention
except ImportError:
    ScaledDotProductAttention = None
class FlashAttentionWrapper(torch.nn.Module):
def __init__(self, d_model, n_heads, dropout=0.0):
super().__init__()
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_q = nn.Linear(d_model, d_model, bias=False)
self.W_k = nn.Linear(d_model, d_model, bias=False)
self.W_v = nn.Linear(d_model, d_model, bias=False)
self.W_o = nn.Linear(d_model, d_model, bias=False)
# Use PyTorch's built-in FlashAttention
self.attn = ScaledDotProductAttention(
dropout_p=dropout,
scale=self.d_k ** -0.5,
)
def forward(self, x, mask=None, is_causal=True):
batch_size, seq_len, d_model = x.shape
Q = self.W_q(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(x).view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
# PyTorch's SDPA uses FlashAttention when available
output = self.attn(Q, K, V, attn_mask=mask, is_causal=is_causal)
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)
return self.W_o(output)
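Failure mode: incompatible hardware or software stack. The standalone flash-attn package requires CUDA 11.6+. PyTorch's built-in kernel (PyTorch 2.0+) needs compatible CUDA (and cuDNN, for the cuDNN attention backend) versions, a supported GPU architecture (the FlashAttention-2 kernel used since PyTorch 2.2 targets Ampere/sm80 and newer), and fp16/bf16 inputs. When these conditions are not met, SDPA silently falls back to the memory-efficient kernel or to the plain math implementation. The math fallback materializes the full N×N attention matrix and loses the memory efficiency.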
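Failure mode: unsupported shapes. FlashAttention processes Q/K/V in fixed-size tiles (typically 64 or 128 rows). Sequence lengths do not need to be multiples of the tile size, because partial tiles are masked inside the kernel. The constraints that matter are on the head dimension and memory layout: PyTorch's flash kernel accepts only head dimensions up to 256 and has alignment requirements on the last dimension. Inputs that violate these constraints are not computed incorrectly; SDPA simply does not select the flash kernel and falls back to a slower, more memory-hungry backend. The symptom is performance degradation, not wrong results.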
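Verification: check whether the FlashAttention backend can run on this machine. This confirms that the kernel is available for the probe's dtype and head dimension; it does not prove that a particular model call actually used it.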
def check_flash_attention_available(*, head_dim=64, dtype=torch.float16):
    """Return True if PyTorch's FlashAttention SDPA backend can run here.

    Runs one tiny ``scaled_dot_product_attention`` call restricted to the
    FLASH_ATTENTION backend.

    Args:
        head_dim: Per-head dimension to probe with. Use the model's ``d_k``
            to check that configuration specifically.
        dtype: Tensor dtype to probe with. The flash kernel only accepts
            half-precision inputs such as fp16 or bf16.

    Returns:
        False, instead of raising, in any of these cases:
        - CUDA is unavailable;
        - the installed PyTorch lacks ``torch.nn.attention.sdpa_kernel``;
        - the flash kernel rejects the inputs (unsupported GPU, dtype or
          head dimension).
        True otherwise.
    """
    if not torch.cuda.is_available():
        return False
    try:
        # Local import: not present in older PyTorch releases.
        from torch.nn.attention import SDPBackend, sdpa_kernel
    except ImportError:
        return False

    q, k, v = (
        torch.randn(2, 8, 128, head_dim, device="cuda", dtype=dtype)
        for _ in range(3)
    )
    try:
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
            torch.nn.functional.scaled_dot_product_attention(q, k, v)
    except RuntimeError:
        # Raised when no enabled backend (here: only flash) accepts the inputs.
        return False
    return True
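Exercise: benchmark naive attention (explicit softmax(QKᵀ/√d_k)·V) against FlashAttention for sequence lengths 512, 1024, 2048 and 4096. Measure both peak memory usage and inference time:
- For memory, call torch.cuda.reset_peak_memory_stats() before each run and read torch.cuda.max_memory_allocated() afterwards.
- For time, run a few warm-up iterations first and call torch.cuda.synchronize() before reading the timer.

Verify correctness by comparing the two outputs with a tolerance suited to fp16/bf16 (for example, torch.testing.assert_close with loose atol/rtol), since the two implementations round differently.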