09. FlashAttention
Standard attention writes its intermediate results to HBM (High Bandwidth Memory), which makes the computation memory-bandwidth bound. For a 4096-token context, it materializes a 4096×4096 score matrix for every attention head (about 32 MB per head in fp16, 64 MB in fp32), writes it to HBM, and reads it back for the softmax and again for the multiplication with V.
FlashAttention avoids this by computing attention in tiles that fit in SRAM (on-chip memory), sharply reducing HBM accesses. It relies on the fact that softmax can be computed incrementally across tiles by tracking a running maximum and a running sum per row, so the full matrix is never materialized and the result is still exact.
# Standard attention: O(n²) memory
import math
import torch
import torch.nn.functional as F
def standard_attention(Q, K, V):
    # Q, K, V: (batch, nheads, seqlen, headdim)
    d_k = Q.size(-1)
    # scores: (batch, nheads, seqlen, seqlen), fully materialized
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, V)
# Memory: O(n²) for the attention score matrix
# FlashAttention: O(n) memory
from flash_attn import flash_attn_func
def flash_attention(Q, K, V, causal=True):
    # Q, K, V: (batch, seqlen, nheads, headdim), fp16 or bf16, on a CUDA device
    return flash_attn_func(
        Q, K, V,
        dropout_p=0.0,         # Dropout probability; keep 0.0 at inference
        softmax_scale=None,    # Defaults to 1/sqrt(headdim) when None
        causal=causal,         # Causal masking for autoregressive decoding
        window_size=(-1, -1),  # Sliding window; (-1, -1) = full attention
    )
# Memory: O(n) extra (output plus per-row softmax statistics); score tiles live only in SRAM
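The incremental softmax that makes tiling possible can be illustrated without a GPU. The following is a minimal pure-Python sketch for a single query vector, not the FlashAttention kernel itself: the function names and `tile_size` are hypothetical, and the real implementation processes blocks of queries in SRAM with fused CUDA code. It keeps a running maximum, a running sum, and an unnormalized output, and rescales them whenever a new tile raises the maximum.

```python
import math

def naive_attention_row(q, keys, values):
    """Reference: softmax(q.K^T / sqrt(d)) . V for one query, all scores held at once."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    dim = len(values[0])
    return [sum(e * v[j] for e, v in zip(exps, values)) / total for j in range(dim)]

def tiled_attention_row(q, keys, values, tile_size=2):
    """Same result, visiting keys/values tile by tile with running statistics."""
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    running_max = float("-inf")  # largest score seen so far
    running_sum = 0.0            # sum of exp(score - running_max) so far
    acc = [0.0] * dim            # unnormalized weighted sum of values
    for start in range(0, len(keys), tile_size):
        k_tile = keys[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated under the old maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_tile):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]
```

Only one tile of scores exists at any time, yet the output matches the reference up to floating-point rounding, which is why the tiled algorithm is exact rather than an approximation.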
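FlashAttention-2 improves on FlashAttention by parallelizing over the sequence length dimension in addition to batch and attention heads, and by partitioning work better between warps. Its CUDA kernels are reported to reach up to about 230 TFLOP/s on A100 GPUs, roughly 50-73% of the 312 TFLOP/s FP16/BF16 peak depending on configuration.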
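Throughput figures like these come from dividing a FLOP count by a measured time. The sketch below assumes the common convention of counting 4·n²·d FLOPs per head for the forward pass (two matrix multiplications at 2·n²·d each), halved under causal masking; the function names are hypothetical, and 312 TFLOP/s is the A100's dense FP16/BF16 peak.

```python
def attention_forward_flops(batch, seqlen, nheads, headdim, causal=False):
    """FLOPs for QK^T and (weights)V in the forward pass, under the 4*n^2*d convention."""
    flops = 4 * batch * nheads * seqlen**2 * headdim
    # Causal masking skips roughly half of the score matrix.
    return flops // 2 if causal else flops

def achieved_tflops(flops, seconds):
    """Convert a FLOP count and a wall-clock time into TFLOP/s."""
    return flops / seconds / 1e12

def utilization(tflops, peak_tflops=312.0):
    """Fraction of the assumed hardware peak (default: A100 dense FP16/BF16)."""
    return tflops / peak_tflops
```

For example, `utilization(230.0)` is about 0.74, which puts a 230 TFLOP/s kernel at roughly three quarters of the A100 peak.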
Installation and verification:
# Install flash-attn
pip install flash-attn --no-build-isolation
# Verify installation
python -c "from flash_attn import flash_attn_func; print('FlashAttention available')"
# Check the installed flash-attn version and the CUDA version PyTorch was built with
python -c "import torch, flash_attn; print(f'flash-attn {flash_attn.__version__}, CUDA {torch.version.cuda}')"
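Memory savings translate directly into longer usable context. With standard attention, the score matrix for a single layer of a 70B-class model with 64 attention heads at 32K context would take about 128 GiB in fp16 (32768² × 2 bytes × 64 heads). FlashAttention never materializes that matrix, so attention memory grows linearly rather than quadratically with sequence length. The KV cache itself is not reduced by FlashAttention: its size depends on the number of layers, KV heads, head dimension, and context length, and it still has to fit in the same VRAM budget.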
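To keep the two kinds of attention memory apart, the following sketch estimates the size of a materialized score matrix and of the KV cache separately. The layer count, KV-head count, and head dimension used in the usage lines are assumed values chosen to resemble a 70B-class model with grouped-query attention; they are illustrative, not measurements.

```python
GIB = 1024**3

def score_matrix_bytes(seqlen, nheads, batch=1, bytes_per_elem=2):
    """Size of the n x n score matrix for one layer if it were fully materialized."""
    return batch * nheads * seqlen * seqlen * bytes_per_elem

def kv_cache_bytes(seqlen, nlayers, n_kv_heads, headdim, batch=1, bytes_per_elem=2):
    """Size of the KV cache across all layers: one K and one V tensor per layer."""
    return 2 * batch * nlayers * n_kv_heads * headdim * seqlen * bytes_per_elem

# Assumed shapes (hypothetical, 70B-class): 64 query heads, 80 layers, 8 KV heads, head dim 128.
scores_gib = score_matrix_bytes(32768, nheads=64) / GIB
cache_gib = kv_cache_bytes(32768, nlayers=80, n_kv_heads=8, headdim=128) / GIB
print(f"score matrix, one layer: {scores_gib:.0f} GiB; KV cache, all layers: {cache_gib:.0f} GiB")
```

The score matrix grows with the square of the context length and is what FlashAttention eliminates; the KV cache grows linearly and is the same with or without it.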
Integration with HuggingFace transformers:
import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    attn_implementation="flash_attention_2",  # Enable FlashAttention-2
    torch_dtype=torch.float16,                # FlashAttention requires fp16 or bf16
    device_map="auto",
)
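Loading with `attn_implementation="flash_attention_2"` fails when the `flash_attn` package is not installed, so a script that must run on mixed machines may want a fallback. The helper below is a hypothetical sketch, not part of transformers: it only checks whether the package can be found and otherwise returns `"sdpa"`, the PyTorch scaled-dot-product-attention backend that recent transformers versions accept.

```python
import importlib.util

def pick_attn_implementation(prefer_flash=True):
    """Return a value for attn_implementation based on whether flash_attn is installed."""
    if prefer_flash and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    # Fallback backend; assumes a transformers version that supports "sdpa".
    return "sdpa"
```

The result can be passed straight through, for example `attn_implementation=pick_attn_implementation()`. Finding the package does not guarantee that the GPU supports it, so a hardware check is still needed.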
Hardware compatibility:
- NVIDIA A100/A800: Full support, maximum performance
- NVIDIA RTX 3090/4090: Supported, slightly reduced throughput
- NVIDIA RTX 3080: Supported (Ampere, same architecture as the RTX 3090)
- NVIDIA RTX 2080: Not supported by FlashAttention-2 (Turing; Ampere or newer required)
- AMD GPUs: Not covered by the CUDA build; ROCm support targets Instinct MI200/MI300 accelerators (check the flash-attn README for current status)
# Check GPU compute capability
python -c "import torch; print(torch.cuda.get_device_capability())"
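# (8, 0) = compute capability 8.0 (A100)
# (8, 6) = compute capability 8.6 (RTX 30 series, Ampere)
# (8, 9) = compute capability 8.9 (RTX 40 series, Ada)
# (7, 5) = compute capability 7.5 (RTX 20 series, Turing; below the Ampere requirement)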
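The "Ampere or newer" rule can be turned into a small guard. This is a sketch that encodes only that rule (compute capability major version 8 or higher); the function name is hypothetical, and it does not account for which architectures a particular flash-attn release was actually built for.

```python
def supports_flash_attention_2(capability):
    """True if a (major, minor) compute capability is Ampere (8.x) or newer."""
    major, _minor = capability
    return major >= 8

# Examples with capabilities typed in by hand:
for cap in [(7, 5), (8, 0), (8, 6), (8, 9)]:
    print(cap, supports_flash_attention_2(cap))
```

On a machine with a GPU, the tuple would come from `torch.cuda.get_device_capability()`.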
Measure inference latency with FlashAttention enabled vs disabled at increasing context lengths (1K, 4K, 16K, 32K). Calculate the context length where standard attention becomes infeasible due to memory constraints.
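One way to structure this measurement is sketched below. Both helpers are hypothetical: `median_latency` times any zero-argument callable, and `max_context_for_scores` estimates the longest context whose score matrix for a single layer fits in a given byte budget, ignoring weights, activations, and the KV cache. The budget is an input you choose, not a measured value.

```python
import math
import time

def median_latency(fn, warmup=2, repeats=5):
    """Median wall-clock seconds for fn(). On a GPU, fn should end with
    torch.cuda.synchronize(), because CUDA kernels launch asynchronously."""
    for _ in range(warmup):
        fn()
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        samples.append(time.perf_counter() - start)
    samples.sort()
    return samples[len(samples) // 2]

def max_context_for_scores(budget_bytes, nheads, batch=1, bytes_per_elem=2):
    """Largest n with batch * nheads * n^2 * bytes_per_elem <= budget_bytes."""
    return math.isqrt(budget_bytes // (batch * nheads * bytes_per_elem))
```

Running `median_latency` over each context length for both attention implementations gives the latency curve, and `max_context_for_scores` gives a first estimate of where standard attention runs out of memory, which the measurements can then confirm.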