Also known as: flash attention, FA2, FA3, flash-attn
TL;DR
FlashAttention is an I/O-aware attention kernel that tiles the computation in SRAM and fuses the softmax, avoiding the need to materialize the N×N attention matrix in HBM.
FlashAttention is an I/O-aware GPU kernel for scaled dot-product attention. It computes exactly the same output as a textbook attention implementation, softmax(QK^T / sqrt(d_k)) V, but reorganizes the operations so that the N×N attention matrix is never written to GPU main memory. By tiling the inputs into blocks that fit in fast on-chip SRAM and fusing the softmax across blocks, it cuts attention’s memory traffic by an order of magnitude and unlocks long-context inference at scale.
The bottleneck it solves
Naive attention has three steps: compute the scores S = QK^T / sqrt(d_k), apply softmax row-wise to get P = softmax(S), then multiply by V to produce the output O = PV. The intermediate matrix S (and P) is N×N, quadratic in sequence length. At N = 8192 and fp16, that’s 128 MB per head per layer, written to HBM (slow GPU global memory) and read back twice.
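The three steps can be written out directly. The following is a minimal NumPy sketch for a single head, not a GPU implementation; the function name `naive_attention` and the sizes `N = 256`, `d = 64` are chosen here for illustration only. It makes the two N×N intermediates explicit so their footprint can be measured.

```python
import math
import numpy as np

def naive_attention(Q, K, V):
    """Textbook single-head attention that materializes both N x N intermediates."""
    d_k = Q.shape[-1]
    S = Q @ K.T / math.sqrt(d_k)              # step 1: N x N score matrix
    S = S - S.max(axis=-1, keepdims=True)     # subtract row max for numerical stability
    P = np.exp(S)
    P = P / P.sum(axis=-1, keepdims=True)     # step 2: row-wise softmax, also N x N
    O = P @ V                                 # step 3: weighted sum of values
    return O, P

# Illustrative sizes (assumptions, not taken from the text).
rng = np.random.default_rng(0)
N, d = 256, 64
Q, K, V = (rng.standard_normal((N, d)).astype(np.float32) for _ in range(3))

O, P = naive_attention(Q, K, V)
intermediate_bytes = 2 * P.nbytes             # S and P have the same shape and dtype
print(O.shape, P.shape, intermediate_bytes)
```

Doubling `N` quadruples `intermediate_bytes` while the inputs and output only double, which is the imbalance the rest of this entry is about.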
The compute is fast; the memory traffic is the wall. Modern GPUs can perform roughly 200 floating-point operations for every byte they load from HBM. Naive attention comes nowhere near saturating that compute; it spends most of its time waiting for memory.
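A back-of-the-envelope estimate shows the gap. The sketch below uses a deliberately simple traffic model (matmul FLOPs only, every intermediate written to HBM once and read once) and assumed A100-class hardware figures: 312 TFLOP/s of FP16 tensor-core peak and about 2 TB/s of HBM bandwidth. The function name and the choice of `n = 8192`, `d = 64` are illustrative assumptions.

```python
def naive_attention_intensity(n, d, bytes_per_elem=2):
    """FLOPs per HBM byte for naive attention under a simplified traffic model."""
    # Two matmuls (Q @ K^T and P @ V), each 2 * n * n * d FLOPs; softmax FLOPs ignored.
    flops = 4 * n * n * d
    # Elements moved: read Q, K; write S; read S; write P; read P, V; write O.
    elems_moved = 4 * n * n + 4 * n * d
    return flops / (elems_moved * bytes_per_elem)

# Assumed A100-class peak numbers, used only to set the scale.
PEAK_FLOPS = 312e12          # FP16 tensor-core FLOP/s
HBM_BYTES_PER_S = 2.0e12     # HBM bandwidth in bytes/s

machine_balance = PEAK_FLOPS / HBM_BYTES_PER_S      # FLOPs the GPU can do per byte loaded
intensity = naive_attention_intensity(n=8192, d=64)  # FLOPs naive attention does per byte

print(f"machine balance: {machine_balance:.0f} FLOPs/byte")
print(f"naive attention: {intensity:.1f} FLOPs/byte")
```

Under these assumptions naive attention performs several times fewer operations per byte than the hardware can sustain, so the kernel is limited by bandwidth rather than arithmetic.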
The trick: tiling and online softmax
FlashAttention partitions Q, K, and V into blocks small enough to fit in SRAM. For each pair of (Q-block, K-block), it computes a partial score block S_ij = Q_i K_j^T / sqrt(d_k) in SRAM, runs the relevant portion of the softmax against the running normalizer, multiplies by the V-block, and accumulates the result. The full attention matrix never materializes in HBM; only the input and output tensors do.
The technically interesting part is the softmax. Softmax requires the row’s full sum of exponentials in the denominator, which seems to require seeing the whole row at once. FlashAttention uses an online softmax trick: process blocks left-to-right, keeping a running max m and running sum l. When a new block arrives with a larger max, rescale the prior accumulators by the exponential of the difference. This produces the exact softmax output, computed incrementally, with only O(N) extra memory (one max and one sum per row).
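The running-max and running-sum idea can be tried on a single row in isolation. This is a minimal NumPy sketch; the function name `online_softmax` and the example values are made up for illustration. Note that it re-reads the row at the end to produce probabilities, which a fused attention kernel avoids by rescaling its output accumulator instead.

```python
import numpy as np

def online_softmax(row, block_size):
    """Softmax of a 1-D array whose normalizer is built one block at a time."""
    m = -np.inf   # running max seen so far
    l = 0.0       # running sum of exp(x - m) over everything seen so far
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        m_new = max(m, block.max())
        # Rescale the old sum to the new max, then add this block's terms.
        l = l * np.exp(m - m_new) + np.exp(block - m_new).sum()
        m = m_new
    # Second pass over the row, only needed because this sketch returns probabilities.
    return np.exp(row - m) / l

row = np.array([1.0, 3.0, 2.0, 10.0, 0.5, 4.0])
streamed = online_softmax(row, block_size=2)

# Reference: ordinary softmax that sees the whole row at once.
e = np.exp(row - row.max())
reference = e / e.sum()
print(np.allclose(streamed, reference))
```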
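For each row, define the partial state (m, l, o) after processing j blocks: the running max m, the running sum of exponentials l, and the unnormalized output accumulator o. When a new block arrives with row-max m' and row-sum l' over its block (where l' sums exp(s - m') over that block's scores s, and o' is those same exponentials multiplied into the V-block):

    m_new = max(m, m')
    l_new = exp(m - m_new) * l + exp(m' - m_new) * l'
    o_new = exp(m - m_new) * o + exp(m' - m_new) * o'

After the last block, the row's output is o / l.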
This is mathematically identical to running softmax over the concatenated rows, just computed block-at-a-time. The rescaling factors absorb the change of normalizer when a new max appears.
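The block-wise update can be checked end to end against a reference that materializes the full matrix. The following is a NumPy sketch of the tiling and rescaling logic only, not of the GPU kernel; the names `tiled_attention`, `reference_attention`, `block_q`, and `block_k` are illustrative. It folds the new block's own rescaling into a single exponential taken against the updated max, which is algebraically the same update.

```python
import math
import numpy as np

def reference_attention(Q, K, V):
    """Full-matrix attention, used only to check the tiled version."""
    S = Q @ K.T / math.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block_q=32, block_k=32):
    """Block-at-a-time attention; the largest score array is block_q x block_k."""
    N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    O = np.empty((N, V.shape[1]))
    for qs in range(0, N, block_q):
        Qi = Q[qs:qs + block_q]
        rows = Qi.shape[0]
        m = np.full(rows, -np.inf)             # running row max
        l = np.zeros(rows)                     # running row sum of exponentials
        acc = np.zeros((rows, V.shape[1]))     # unnormalized output accumulator
        for ks in range(0, K.shape[0], block_k):
            Kj, Vj = K[ks:ks + block_k], V[ks:ks + block_k]
            S = (Qi @ Kj.T) * scale            # one tile of scores
            m_new = np.maximum(m, S.max(axis=1))
            alpha = np.exp(m - m_new)          # rescale factor for the old state
            P = np.exp(S - m_new[:, None])     # tile exponentials against the new max
            l = alpha * l + P.sum(axis=1)
            acc = alpha[:, None] * acc + P @ Vj
            m = m_new
        O[qs:qs + block_q] = acc / l[:, None]  # normalize once, after the last K block
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((100, 16)) for _ in range(3))
print(np.allclose(tiled_attention(Q, K, V), reference_attention(Q, K, V)))
```

The sequence length of 100 is deliberately not a multiple of the block size, so the last block in each loop is smaller than the rest.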
What it actually delivers
On A100 GPUs, FlashAttention 2 achieves up to around 70% of the theoretical peak FLOPS for attention, versus roughly 25-40% for the original FlashAttention and lower still for naive PyTorch attention, which is bandwidth-bound. End-to-end, this can translate to 2-4x faster training and 2-3x faster inference for transformer models, with the gains larger at longer contexts.
The bigger story is what it makes possible. Long-context training (32K, 128K, 1M tokens) is essentially impossible without FlashAttention or a close variant — the memory and bandwidth costs would be untenable. Every modern long-context LLM is trained and served with some flavor of FlashAttention or its descendants (xFormers, PyTorch SDPA, the various TensorRT-LLM kernels).
How it composes with the rest of the stack
FlashAttention stacks cleanly with grouped-query attention (the kernel doesn’t care how many K/V heads there are), with the KV cache (the prefill phase uses FlashAttention; the decode phase uses a specialized variant called Flash-Decoding), and with quantization. Most production inference engines, including vLLM, TensorRT-LLM, and SGLang, use FlashAttention or a FlashAttention-style fused kernel as the default attention path and only fall back to naive attention when the input doesn’t fit the kernel’s constraints.
FlashAttention is the canonical example of how modern ML performance is dominated by memory hierarchy rather than raw compute. The math of attention hasn’t changed since 2017; reorganizing which level of the memory hierarchy each tensor lives in, and when, was worth a 2-4x speedup across the entire field.
Go further
What's the difference between FlashAttention 1, 2, and 3?
FA1 (2022) introduced the tiling and softmax-fusion idea, getting 2-4x speedup on training. FA2 (2023) parallelized differently across the sequence dimension to better saturate modern GPUs, roughly doubling FA1's throughput. FA3 (2024) targets Hopper-class GPUs with asynchronous warp specialization and FP8 support, reaching up to about 75% of theoretical peak FLOPS in FP16 on H100. FA2 is the practical default in 2026; use FA3 if you're on H100/H200.
Does FlashAttention change attention's mathematical output?
No. FlashAttention is exact: it computes the same function as standard attention, and the outputs agree up to the floating-point rounding differences that come from performing the additions in a different order (so they are not guaranteed to be bit-identical). It changes the computation order, not the result. This is what distinguishes it from sparse-attention or low-rank approximations, which compute a different (cheaper) function.
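The rounding caveat comes from floating-point addition not being associative, which plain Python floats are enough to show. This small sketch is only an analogy for what happens when a kernel sums the same terms block by block instead of row by row.

```python
import math

# Same three numbers, summed in two different orders.
a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)

print(a == b)               # False: the two orders round differently
print(math.isclose(a, b))   # True: they agree to within rounding error
```

For the same reason, comparing a fused attention kernel against a reference is normally done with a tolerance rather than exact equality.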
Why doesn't materializing the attention matrix work at long context?
Memory and bandwidth. At 128K context, the N×N matrix in fp16 is 32 GB just for one head, one layer. With dozens of heads per layer, that cannot fit in HBM, much less SRAM. Even where a matrix does fit, without tiling you'd have to write and re-read every entry from HBM, which would dominate the runtime. FlashAttention's tile-then-reduce structure means the full matrix never exists anywhere; it's computed and consumed block by block.
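The arithmetic behind these sizes is short enough to check by hand or in a few lines. In this sketch the helper name `score_matrix_bytes` and the 128×128 tile size are illustrative assumptions; real kernels choose tile shapes based on head dimension and the GPU's SRAM capacity.

```python
def score_matrix_bytes(n, bytes_per_elem=2):
    """Bytes needed to hold one n x n matrix (fp16 by default)."""
    return n * n * bytes_per_elem

MIB = 2**20
GIB = 2**30

full_8k = score_matrix_bytes(8 * 1024)      # one head, one layer, 8K context
full_128k = score_matrix_bytes(128 * 1024)  # one head, one layer, 128K context
tile = score_matrix_bytes(128)              # one hypothetical 128 x 128 tile

print(f"8K context:   {full_8k / MIB:.0f} MiB")
print(f"128K context: {full_128k / GIB:.0f} GiB")
print(f"128x128 tile: {tile / 1024:.0f} KiB")
```

The full matrix grows with the square of the context length, while the tile size is a constant that does not depend on it.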