ATTENTION, FULLY ASSEMBLED
Section 13.3

FlashAttention — tile + online softmax + running output

This is the section the whole book has been building toward. We had the cost problem in §13.1: the N × N attention score matrix dominates memory at long contexts. We had the algorithmic key in §12.3: online softmax can stream over blocks while maintaining (m, ℓ). We had the tiling pattern in Ch.2 §2 and the microkernel pattern in Ch.2 §4. FlashAttention puts all four together. The algorithm (Dao 2022, “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”) tiles Q, K, V into blocks small enough to fit in GPU SRAM, maintains per-query-row running state (m, ℓ, O), and produces the same output as naïve attention up to floating-point roundoff. Extra memory drops from O(N²) to O(N). HBM traffic drops by a factor that grows with SRAM size: Dao 2022 shows Θ(N²d²/M) HBM accesses for SRAM size M, versus Θ(Nd + N²) for standard attention. Wall-clock attention speeds up 3-10×. Critically, nothing about the math changes: the output is exact, not an approximation. This is what makes 200K-context models practical.

What goes in SRAM, what stays in HBM

GPU memory has two tiers. HBM is the big, slow tier: 80 GB on an H100 at roughly 3 TB/s. SRAM is the small, fast tier: about 228 KB per streaming multiprocessor (SM) on an H100, 33 MB total across the chip. SRAM offers roughly 10× the bandwidth of HBM and much lower latency.

For naïve attention, the N × N score matrix doesn’t fit in SRAM (in fp32, per head, it’s 4 GB at N = 32K and 64 MB at N = 4K, both far bigger than SRAM). So naïve attention makes three round trips through HBM:

  1. Read Q and K from HBM, compute S = Q · Kᵀ, write S to HBM.
  2. Read S from HBM, compute the softmax in SRAM, write P back to HBM.
  3. Read P and V from HBM, compute P · V, write O to HBM.

The N × N score matrix gets written, read, written again, read again. The wall-clock attention time on a GPU is dominated by these HBM ↔ SRAM round trips, not by the FLOPs. Compute is “free” relative to memory traffic on modern GPUs.

The algorithm

FlashAttention’s structure: outer loop over query blocks, inner loop over key/value blocks. Each query block carries running state (m, ℓ, O) that gets updated as key/value blocks are consumed.

For each query tile Q_i of size B_q × d:
    Load Q_i into SRAM.
    Initialise: m_i = −∞, ℓ_i = 0, O_i = 0 (a B_q × d output tile).
    For each key/value tile (K_j, V_j) of size B_k × d:
        Load K_j and V_j into SRAM.
        1. Partial scores:   S_ij = Q_i · K_jᵀ / √d          shape B_q × B_k
        2. Block max:        m_ij = max(S_ij, axis=keys)      shape B_q (per query row)
        3. Block exp:        P_ij = exp(S_ij − m_ij)          shape B_q × B_k
        4. Block sum:        ℓ_ij = Σ_k P_ij[·, k]            shape B_q
        5. New running max:  m_new = max(m_i, m_ij)           shape B_q
        6. Rescale running state (the §12.3 identity, extended to O):
               α = exp(m_i − m_new)
               β = exp(m_ij − m_new)
               ℓ_i ← α · ℓ_i + β · ℓ_ij
               O_i ← α · O_i + β · (P_ij · V_j)
               m_i ← m_new
    After all K/V tiles:
        O_i ← O_i / ℓ_i   (one normalisation)
        Write O_i back to HBM.

The key insight (the one extension beyond §12.3 online softmax): O gets rescaled by the SAME α factor as ℓ. This works because O is “ℓ · (softmax-weighted V)” before the final normalisation — the same α that rescales ℓ to the new max basis also rescales the accumulated weighted-V to the same basis. The β factor brings the new block’s contribution into that basis.

After all K/V tiles are processed for one query block, divide by ℓ once. That single division is the only place the softmax denominator shows up; it produces the same probabilities the naïve softmax would.
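The α/β rule can be checked in a few lines of algebra. A sketch for one query row, writing s_k = q · k_k / √d for its scores and v_k for the value rows; T is the set of keys already consumed and B is the incoming block (m_T, m_B play the roles of m_i, m_ij above):

```latex
% Running state after consuming the keys in T
m_T = \max_{k \in T} s_k, \qquad
\ell_T = \sum_{k \in T} e^{\,s_k - m_T}, \qquad
O_T = \sum_{k \in T} e^{\,s_k - m_T}\, v_k

% Block state for the incoming block B
m_B = \max_{k \in B} s_k, \qquad
\ell_B = \sum_{k \in B} e^{\,s_k - m_B}, \qquad
(PV)_B = \sum_{k \in B} e^{\,s_k - m_B}\, v_k

% Combine in the basis m_{new} = \max(m_T, m_B),
% with \alpha = e^{m_T - m_{new}} and \beta = e^{m_B - m_{new}}
O_{T \cup B} = \sum_{k \in T \cup B} e^{\,s_k - m_{new}}\, v_k
  = \alpha\, O_T + \beta\, (PV)_B, \qquad
\ell_{T \cup B} = \alpha\, \ell_T + \beta\, \ell_B

% After the last block, one division gives the softmax-weighted sum of V
\frac{O}{\ell} = \frac{\sum_k e^{\,s_k - m}\, v_k}{\sum_k e^{\,s_k - m}}
  = \sum_k \operatorname{softmax}(s)_k \, v_k
```

The middle step uses only e^{s_k − m_new} = e^{s_k − m_T} · e^{m_T − m_new} (and the same with m_B), which is why O and ℓ share the same α and β.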

That is the entire algorithm. Three matrices to load per tile (Q_i, K_j, V_j); three scratch matrices to keep (S_ij, P_ij, the partial P·V); two running scalars per query row (m_i, ℓ_i); one output tile (O_i).


Per-row state:

  • m — scalar, the running max of scores seen so far (1 element).
  • ℓ — scalar, the running sum of exp(score − m) so far (1 element).
  • O — d-dim vector, the running (un-normalised) softmax-weighted sum of V vectors seen so far (d elements).

Total per row: 2 + d scalars (for d = 128, that’s 130 floats per query row).

Why (m, ℓ) alone isn’t enough:

(m, ℓ) by itself reconstructs the softmax DENOMINATOR. But the final attention output is the softmax-weighted sum of V vectors. To compute that, you need to know the weighted sum of V — and that weighting depends on the per-element exp(score − m) values, not just their sum.

If we only kept (m, ℓ) and not O, then after all K/V blocks were processed, we’d know exactly what the softmax weights should be — but we wouldn’t have applied them to the V vectors. We’d have to make a SECOND pass over K/V (reading them from HBM again) to compute the weighted sum, with the now-final softmax weights. That second pass defeats the whole point of FlashAttention — it puts back the HBM traffic we were trying to avoid.

Carrying O alongside (m, ℓ) lets us do everything in one pass. Each block contributes its partial P_blk · V_blk to the running O, rescaled into the current max basis. When the running max changes, O gets rescaled by the same α factor as ℓ. At the end, one division O / ℓ produces the final attention output. Single-pass, exact, in O(N) memory.

Now make it run

The C kernel below implements the algorithm above; a test harness compares its output, element by element, against a naïve full-matrix attention reference:

flash_attention.c — attention_flash C · tile-and-stream attention
#include <math.h>     /* sqrtf, expf, INFINITY */
#include <stdlib.h>   /* malloc, free */
#include <string.h>   /* memset */

/* ------- FlashAttention (tile + online softmax + running O) --------- */
static void attention_flash(
    const float* Q,
    const float* K,
    const float* V,
    float* O,
    int N, int d, int Bq, int Bk)
{
    const float scale = 1.0f / sqrtf((float)d);
    /* per-query-row running state */
    float* m_row = malloc((size_t)N * sizeof(float));
    float* l_row = malloc((size_t)N * sizeof(float));
    for (int i = 0; i < N; i++) { m_row[i] = -INFINITY; l_row[i] = 0.0f; }
    memset(O, 0, (size_t)N * d * sizeof(float));

    /* per-block tile scratch (kept tiny — would fit in SRAM in a GPU impl) */
    float* S_blk = malloc((size_t)Bq * Bk * sizeof(float));
    float* P_blk = malloc((size_t)Bq * Bk * sizeof(float));
    float* m_blk = malloc((size_t)Bq * sizeof(float));   /* per-row of this Q-block, this K-tile's max */
    float* l_blk = malloc((size_t)Bq * sizeof(float));   /* per-row of this Q-block, this K-tile's sum */

    /* Outer loop: query blocks. Within each, sweep K/V blocks; update running state. */
    for (int qi = 0; qi < N; qi += Bq) {
        int qend = qi + Bq; if (qend > N) qend = N;
        int qr = qend - qi;   /* rows in this Q-tile */

        for (int kj = 0; kj < N; kj += Bk) {
            int kend = kj + Bk; if (kend > N) kend = N;
            int kc = kend - kj;   /* cols in this K/V-tile */

            /* 1. Partial scores  S_blk = Q_tile · K_tileᵀ · scale */
            for (int ii = 0; ii < qr; ii++) {
                for (int jj = 0; jj < kc; jj++) {
                    float s = 0;
                    for (int l = 0; l < d; l++)
                        s += Q[(qi+ii)*d+l] * K[(kj+jj)*d+l];
                    S_blk[ii*Bk+jj] = s * scale;
                }
            }
            /* 2-4. Block max, exp, sum per row */
            for (int ii = 0; ii < qr; ii++) {
                float mx = -INFINITY;
                for (int jj = 0; jj < kc; jj++)
                    if (S_blk[ii*Bk+jj] > mx) mx = S_blk[ii*Bk+jj];
                m_blk[ii] = mx;
                float sum = 0;
                for (int jj = 0; jj < kc; jj++) {
                    float p = expf(S_blk[ii*Bk+jj] - mx);
                    P_blk[ii*Bk+jj] = p; sum += p;
                }
                l_blk[ii] = sum;
            }
            /* 5-6. Running state update */
            for (int ii = 0; ii < qr; ii++) {
                int row = qi + ii;
                float m_new = (m_row[row] > m_blk[ii]) ? m_row[row] : m_blk[ii];
                float alpha = expf(m_row[row] - m_new);   /* rescale factor for old O, ℓ */
                float beta  = expf(m_blk[ii] - m_new);    /* scale factor for new block */

                /* O_row  ←  α · O_row  +  β · (P_blk · V_tile) */
                for (int l = 0; l < d; l++) {
                    float pv = 0;
                    for (int jj = 0; jj < kc; jj++)
                        pv += P_blk[ii*Bk+jj] * V[(kj+jj)*d+l];
                    O[row*d+l] = alpha * O[row*d+l] + beta * pv;
                }
                /* ℓ_row ← α · ℓ_row + β · ℓ_blk;  m_row ← m_new */
                l_row[row] = alpha * l_row[row] + beta * l_blk[ii];
                m_row[row] = m_new;
            }
        }

        /* After all K/V tiles: one normalisation per row, O_row ← O_row / ℓ_row */
        for (int ii = 0; ii < qr; ii++) {
            int row = qi + ii;
            for (int l = 0; l < d; l++)
                O[row*d+l] /= l_row[row];
        }
    }

    free(S_blk); free(P_blk); free(m_blk); free(l_blk);
    free(m_row); free(l_row);
}

The verification sweeps every combination of query block size B_q ∈ {8, 16, 32, 64, 128} and key block size B_k ∈ {8, 16, 32, 64, 128}, comparing against naïve attention output (the reference). Output:

FlashAttention vs naïve — N=128, d=32
naïve attention output mean |O| = 0.1036
(diff is shown in absolute terms; output element magnitude is ~0.10)

B_q        B_k        max |flash − naive|
8          8          2.384e-07
8          16         2.682e-07
8          32         2.384e-07
8          64         2.533e-07
8          128        2.682e-07
16         8          2.384e-07
...
128        64         2.533e-07
128        128        2.682e-07

Across all 25 block-size combinations, the FlashAttention output matches naïve attention to within ~3 × 10⁻⁷, which is float roundoff. Output magnitudes are ~0.1, so the relative error is a few × 10⁻⁶. The N × N score matrix is never materialised in the FlashAttention kernel; only B_q × B_k scratch is (at the largest setting, 128 × 128 fp32 = 64 KB per scratch buffer, which fits in SRAM).
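The 64 KB figure counts one scratch buffer. A rough sketch of the whole per-step working set, under the simplifying assumption (ours, not the kernel’s) that every buffer uses the same element size; `tile_working_set_bytes` is a hypothetical helper:

```c
#include <stddef.h>

/* Hypothetical helper: bytes of on-chip working set for one
 * (Q-tile, K/V-tile) step, assuming one element size for every buffer.
 *   Q_i, O_i   : B_q x d each
 *   K_j, V_j   : B_k x d each
 *   S_ij, P_ij : B_q x B_k each
 *   m_i, l_i   : B_q each                                          */
static size_t tile_working_set_bytes(size_t Bq, size_t Bk, size_t d,
                                     size_t elem_bytes)
{
    size_t q_and_o = 2 * Bq * d;
    size_t k_and_v = 2 * Bk * d;
    size_t s_and_p = 2 * Bq * Bk;
    size_t m_and_l = 2 * Bq;
    return (q_and_o + k_and_v + s_and_p + m_and_l) * elem_bytes;
}
```

For the largest tested setting (B_q = B_k = 128, d = 32, fp32) it returns 197,632 bytes, about 193 KB, still under the ~228 KB of shared memory per H100 SM. A tighter kernel could shrink this, for example by computing P in place over S.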


Why block size doesn’t affect correctness: the FlashAttention algorithm computes, at the end, the EXACT same final softmax weights and weighted sum of V that naïve attention does — just incrementally.

The online-softmax recurrence (§12.3) guarantees that for any block partitioning of the input scores, the running (m, ℓ) computed by the rescale-and-combine rule is mathematically equal to the (m, ℓ) you’d get from naïve full-batch computation. Extension to O: the running output O is “ℓ · weighted-V” — rescaled exactly the same way as ℓ, by the same α and β factors. So the running O is mathematically equal to “ℓ · (softmax weights as if computed naïvely so far) · V”. After the final O / ℓ division, you get the exact softmax-weighted V sum.

This is the same property as Ch.2 §2’s matmul tiling: the three axes (M, N, K) can be partitioned freely; M and N produce independent partial outputs that don’t interact; K produces partial sums that add. FlashAttention extends this to softmax: instead of “partial sums that add,” it’s “partial state that combines via a rescale rule.” The state grows from 0 (matmul) to 2 + d (FlashAttention), but the principle is the same: tile freely, combine correctly.

What block size DOES change:

  • Memory footprint. Larger blocks need more SRAM. Largest practical block on H100 is ~128 × 128 in fp16 (32 KB) plus surrounding state.
  • Parallelism. Smaller B_q means more query tiles, which can be assigned to more streaming multiprocessors (more parallelism). Smaller B_k means more inner-loop iterations per query tile, each doing less work (a longer sequential chain per query tile).
  • HBM accesses. K and V are each read in full once per query tile, and there are N / B_q query tiles, so total K reads = (N / B_q) · N · d = N²d / B_q (and the same for V). Smaller B_q means more reads. Larger B_q is better for HBM traffic, but bounded by SRAM capacity.
  • Numerical roundoff. The 2.4-2.7 × 10⁻⁷ differences across block sizes in the output ARE real: different summation orders accumulate roundoff slightly differently. But they’re all float32 roundoff (around 10⁻⁶ relative to outputs of magnitude ~0.1), so any block size is “correct enough.”

The practical choice: pick the largest (B_q, B_k) that fit in SRAM after accounting for the running state. On H100 fp16, typical FlashAttention-2 picks B_q = B_k = 128, with d = 64-128. On A100 (smaller SRAM), B_q = B_k = 64.

The HBM-traffic accounting

The wall-clock speedup of FlashAttention comes from collapsed HBM access counts. Let’s count them.

Naïve attention HBM accesses (per attention layer, per head):

  • Read Q: N · d
  • Read K: N · d
  • Write S = Q · Kᵀ: N · N ← the killer
  • Read S: N · N
  • Write P (softmax): N · N ← also a killer
  • Read P: N · N
  • Read V: N · d
  • Write O: N · d
  • Total: ~4N² + 4Nd HBM accesses.

FlashAttention HBM accesses:

  • Read Q: N · d (each query tile loaded once)
  • Read K: N · d · (N / B_q) (all of K is re-read for each of the N / B_q query tiles)
  • Read V: same as K.
  • Write O: N · d
  • Total: ~2N²d / B_q for the K, V reads + 2Nd for Q, O = O(N²d / B_q).

Compared with naïve’s 4N² for the score and probability matrices, the dominant term shrinks by a factor of 2B_q / d. For d = 128 and B_q = 128, that is 2N² versus 4N², a factor of 2 fewer HBM accesses on the dominant term; for d = 64 it is a factor of 4, and the factor grows as SRAM allows larger B_q.
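The same accounting as a pair of hypothetical C helpers (`hbm_naive`, `hbm_flash`) counting element accesses. The dominant terms are 4N² and 2N²d / B_q, so their ratio is 2B_q / d. Like the list above, this model ignores caching, masking, and dropout; it also rounds the number of query tiles up when B_q doesn’t divide N:

```c
/* Simplified HBM element-access counts for one head.
 * Hypothetical helpers; ignores caching, masking, and dropout. */
static long long hbm_naive(long long N, long long d) {
    /* S write + S read + P write + P read, plus Q, K, V reads and O write */
    return 4 * N * N + 4 * N * d;
}

static long long hbm_flash(long long N, long long d, long long Bq) {
    long long q_tiles = (N + Bq - 1) / Bq;     /* ceil(N / B_q) */
    /* K and V re-read once per query tile; Q read and O written once */
    return 2 * q_tiles * N * d + 2 * N * d;
}
```

For N = 4096, d = 128, B_q = 128 these return 69,206,016 and 34,603,008, roughly 2× apart; with d = 64 the counts are 68,157,440 and 17,301,504, about 4× apart.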

The empirical wins from Dao 2022:


Memory-bound regime: Modern GPUs have ~10× the FLOPs they need given their memory bandwidth. For matrix-multiply-heavy operations, this is fine because there’s lots of compute reuse per byte loaded (each loaded element of A in C = A·B is reused for the entire row of B). For softmax-over-N-things, there’s almost no compute reuse — each loaded score gets one exp() and one add, then is done. So softmax is bandwidth-bound, not compute-bound.
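A back-of-envelope way to see the difference is arithmetic intensity, FLOPs performed per byte moved. The sketch below uses a deliberately crude model (each operand crosses memory once, exp counts as one FLOP, and softmax is charged one read and one write per score); the helper names are hypothetical:

```c
/* Rough arithmetic intensity (FLOPs per byte of memory traffic).
 * Hypothetical model: each operand moves through memory exactly once. */
static double matmul_intensity(double n, double elem_bytes) {
    double flops = 2.0 * n * n * n;            /* n^3 multiply-adds for n x n A*B */
    double bytes = 3.0 * n * n * elem_bytes;   /* read A, read B, write C */
    return flops / bytes;
}

static double softmax_intensity(double elem_bytes) {
    /* per score: subtract max, exp, accumulate, divide (exp counted as 1 op) */
    double flops = 4.0;
    double bytes = 2.0 * elem_bytes;           /* read score, write probability */
    return flops / bytes;
}
```

For n = 4096 in fp16 the matmul model gives about 1,365 FLOPs per byte, and it grows linearly with n; the softmax model gives 1 FLOP per byte no matter how large N gets. That is the sense in which softmax has “almost no compute reuse.”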

The naïve score matrix is the killer: writing N² scores to HBM, reading them back to do the softmax, writing the N² probabilities, then reading them AGAIN to multiply by V adds up to 4N² HBM accesses for the dominant intermediate. At N = 16K and fp16, that’s 4 · 16K · 16K · 2 bytes = 2 GB of HBM traffic per head, per layer, per forward pass. Multiplied by 32 heads × 32 layers × batch size, you saturate HBM bandwidth long before compute.
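The traffic arithmetic above, as a hypothetical pair of C helpers so the units are explicit (`elem_bytes` = 2 for fp16; `heads`, `layers`, `batch` are illustrative model parameters):

```c
#include <stdint.h>

/* Bytes moved through HBM for the naive score/probability matrices of
 * ONE head: write S, read S, write P, read P  =>  4 * N^2 elements. */
static uint64_t naive_score_bytes(uint64_t N, uint64_t elem_bytes) {
    return 4 * N * N * elem_bytes;
}

/* Scale up to a whole forward pass of a hypothetical model. */
static uint64_t naive_score_bytes_model(uint64_t N, uint64_t elem_bytes,
                                        uint64_t heads, uint64_t layers,
                                        uint64_t batch) {
    return naive_score_bytes(N, elem_bytes) * heads * layers * batch;
}
```

`naive_score_bytes(16384, 2)` is 2,147,483,648 bytes (2 GiB) for one head; with 32 heads, 32 layers, and batch size 1, `naive_score_bytes_model` gives 2⁴¹ bytes, about 2.2 TB of score-matrix traffic per forward pass.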

FlashAttention’s win: The score matrix never goes to HBM. It’s computed inside SRAM (in B_q × B_k tiles), softmaxed inside SRAM, and multiplied by V inside SRAM; only the final output O is written to HBM. HBM accesses for the dominant term drop from 4N² to ~2N²d / B_q (now coming from re-reading K and V), a reduction factor of 2B_q / d. For B_q = 128 and d = 64, that’s a factor of 4, and the factor grows with larger query tiles, which is why SRAM capacity matters.

Why this gives wall-clock speedup: Reducing HBM traffic by 8× when you’re HBM-bandwidth-bound = ~8× wall-clock speedup. Compute stays the same (still N²d total FLOPs); compute was never the bottleneck. The win is entirely from “not transferring data through the slow memory tier.”

The deeper insight (FlashAttention’s paper title): the algorithm is “IO-aware.” It’s not “smarter math” — the math is identical to naïve attention. It’s “smarter data movement” — the same math, organised to live in SRAM instead of HBM. This is the lesson for any high-performance kernel: it’s not what you compute, it’s what you move.

FlashAttention-2, FlashAttention-3, and what’s next

Dao 2023 “FlashAttention-2” kept the algorithm and refined the implementation:

Result: about 2× over FlashAttention-1, reaching 50-73% of theoretical peak FLOPs on A100.

Shah 2024 “FlashAttention-3” targets H100 specifically:

Result: up to ~740 TFLOPS in fp16 on H100 (~75% of theoretical peak, vs ~35% for FA-2), and close to 1.2 PFLOPS in fp8.

This whole sequence — naïve → FA-1 (algorithm change) → FA-2 (kernel polish) → FA-3 (hardware-specific) — is the pattern of every important ML kernel evolution. Algorithmic insight (the online softmax + running O) unlocks the optimisation; then 2-3 years of kernel-polishing extracts the remaining 2-3× from the same algorithm on increasingly specific hardware. Understanding the algorithm is what lets you read the kernel code; understanding the kernel is what lets you debug performance on new hardware. Ch.20 (Hardware) and Ch.22 (Inference) revisit this from the systems angle.

Closing the loop on Part III

That’s it for Part III. We started with embeddings (Ch.11), moved through softmax and online softmax (Ch.12), and assembled attention with full multi-head and FlashAttention (Ch.13). Everything from here uses what’s been built:

The math is mostly done. The systems engineering, the hardware nuances, and the recent research are what Part IV through Part VI cover.

END OF CH.13 — Attention, fully assembled.
§1 (QKV as soft dictionary lookup, 1/√d_k scaling, N² cost) · §2 (multi-head, KV cache, MQA → GQA) · §3 (FlashAttention: tile + online softmax + running output, exact up to float roundoff).

END OF PART III. The transformer is assembled.

Coming next: Ch.14 — LayerNorm, RMSNorm, why they go where they go.