Tiled GEMM microkernel
Section 2.2’s naïve matmul is the slowest fast operation on your machine. Three nested loops, all the right arithmetic, every entry of C a textbook dot product, and somewhere between one and two orders of magnitude slower than what the silicon will let you do. Missing SIMD is part of the gap (the benchmark below gets 3.6× from vectorization alone), but the structural problem is that naïve matmul streams B from main memory once per row of C. Memory bandwidth caps you out long before the arithmetic does. Tiling (keep a small working set in fast memory, reuse it before moving on) is what closes the gap. The kernel built in this section is also the architectural ancestor of FlashAttention: a tiled inner kernel surrounded by a loop nest. Same shape, same reasons.
Why naïve is slow: arithmetic intensity
Count FLOPs against data moved. A square N × N × N matmul does 2N³ floating-point operations (one multiply and one add per inner step). The inputs and output occupy 3N² floats, and every entry of A, B, and C must cross the memory bus at least once. The ratio:
$$\text{intensity} = \frac{2N^3}{3N^2} = \frac{2}{3}N \ \text{FLOPs per float touched}$$
For N = 1024 that’s about 680 FLOPs per float; for N = 64, about 40 (roughly 170 and 10 FLOPs per byte, at 4 bytes per float). The number improves with N, which sounds great until you ask whether a real kernel can actually realise it. A modern CPU’s roofline ridge point (peak FLOP/s divided by memory bandwidth) sits somewhere around 10–100 FLOPs per byte; a kernel whose realised intensity falls below that is memory-bound, not compute-bound.
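The arithmetic above, plus the roofline bound it feeds, as a minimal C sketch. It assumes FP32 and counts only compulsory traffic (each entry of A, B and C moved once); gemm_intensity_per_float and roofline_gflops are hypothetical helpers, not part of the chapter’s code, and any peak or bandwidth figures you pass in are your own measurements:

```c
/* Ideal arithmetic intensity of an N x N x N matmul, counting only compulsory
 * traffic: each of the 3*N*N entries of A, B and C is moved exactly once.
 * Hypothetical helper for illustration. */
static double gemm_intensity_per_float(double n) {
    double flops  = 2.0 * n * n * n;   /* one multiply + one add per inner step */
    double floats = 3.0 * n * n;       /* A, B and C, each touched once */
    return flops / floats;             /* = (2/3) * N */
}

/* Roofline bound (hypothetical helper): attainable GFLOP/s is the smaller of
 * the compute peak and (intensity in FLOPs/byte) * (bandwidth in GB/s). */
static double roofline_gflops(double peak_gflops, double bw_gbytes_per_s,
                              double flops_per_byte) {
    double mem_bound = flops_per_byte * bw_gbytes_per_s;
    return mem_bound < peak_gflops ? mem_bound : peak_gflops;
}
```

For N = 1024 the first helper returns about 682.7 FLOPs per float, or about 170.7 per byte after dividing by 4. Passing the naïve loop’s realised intensity (half a FLOP per byte at best) to roofline_gflops shows how far below peak a memory-bound kernel lands.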
The catch: naïve matmul doesn’t achieve its theoretical intensity. The inner loop in gemm_naive reads one element of A and one element of B, multiplies, adds, moves on. By the time you get back to that element of B for the next row of C, it’s no longer in cache, and you re-fetch it from main memory. Realised intensity collapses to O(1): two FLOPs per float of B fetched, which is at best half a FLOP per byte. You become memory-bound.
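For reference, a minimal sketch of the naïve ijk loop being described. The chapter’s gemm_naive is not shown in this section, so this version, named gemm_naive_sketch, is an assumption: row-major storage and a signature mirroring gemm_tiled_avx below. The comment marks the stride-n walk down B that defeats the cache:

```c
#include <stddef.h>

/* Naive ijk matmul, C = A * B, all row-major: A is m x k, B is k x n, C is m x n.
 * Hypothetical sketch of the gemm_naive shape; not the chapter's source. */
void gemm_naive_sketch(const float* A, const float* B, float* C,
                       int m, int n, int k) {
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            float s = 0.0f;                      /* one textbook dot product */
            for (int p = 0; p < k; p++) {
                /* A[i][p] walks a row (stride 1); B[p][j] walks a column
                 * (stride n floats), so each B read lands on a different line. */
                s += A[(size_t)i * k + p] * B[(size_t)p * n + j];
            }
            C[(size_t)i * n + j] = s;
        }
    }
}
```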
What tiling buys
The fix is to refactor the loop nest so that whenever you bring a tile of B into cache, you do as much work as possible on it before evicting. Mechanically: chop B (and matching parts of A and C) into tiles small enough to live in L1, then iterate. Each tile of B gets reused m times instead of once.
This is the K-axis tiling from §2.2 applied for a systems reason. Hop back to that section: K-tiling splits the contraction axis into chunks; each chunk produces a partial output. The partial outputs get summed (which is fine: addition is associative, up to floating-point rounding). The cost we noted in §2.2 (K-tiling needs an extra reduction) is now a feature: by chunking K, we can keep a small slab of B resident while doing many C updates against it. M-tiling and N-tiling, similarly, help by keeping the working set per tile inside cache.
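Written out, K-tiling only regroups one sum. Here k_c is the tile depth (the kc argument of gemm_tiled_avx) and T the number of K-tiles; each inner sum touches only the k_c rows of B in tile t, which is the slab that stays resident while every row of C uses it:

$$C_{ij} = \sum_{p=0}^{k-1} A_{ip}\,B_{pj} = \sum_{t=0}^{T-1} \underbrace{\sum_{p=t\,k_c}^{\min((t+1)k_c,\;k)-1} A_{ip}\,B_{pj}}_{\text{partial product of K-tile } t}, \qquad T = \left\lceil k / k_c \right\rceil$$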
In practice fast BLAS kernels (GotoBLAS, OpenBLAS, BLIS, MKL) tile three levels deep (register, L1, L2), with each level’s block sized to fit one tier of the memory hierarchy.
The innermost level — the microkernel — is what touches the SIMD intrinsics. Everything outside is just nested loops. That separation is the reason BLIS can support fifteen ISAs (AVX, AVX-512, NEON, SVE, POWER, MIPS…) with one or two hundred lines of ISA-specific code: only the microkernel changes per architecture; the outer tiling loops are portable.
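A minimal portable sketch of that split, assuming row-major storage and an 8-wide block. The names ukernel_fn, ukernel_ref_8 and gemm_with_ukernel are hypothetical and are not BLIS’s microkernel interface (which operates on packed MR × NR panels). The outer loops never change; an AVX2 or NEON kernel with the same signature could be passed in place of the plain-C stand-in:

```c
#include <stddef.h>
#include <string.h>

/* Hypothetical microkernel contract:
 * c_row[j..j+8) += a_row[0..k) . B[0..k)[j..j+8), B row-major with row stride n. */
typedef void (*ukernel_fn)(const float* a_row, const float* B, float* c_row,
                           int k, int n, int j);

/* Plain-C stand-in microkernel: the 8 "lanes" are an ordinary array. */
static void ukernel_ref_8(const float* a_row, const float* B, float* c_row,
                          int k, int n, int j) {
    float acc[8];
    for (int l = 0; l < 8; l++) acc[l] = c_row[j + l];     /* load C block once */
    for (int p = 0; p < k; p++) {
        float a = a_row[p];                                 /* one scalar of A */
        const float* b = B + (size_t)p * n + j;             /* 8 contiguous B values */
        for (int l = 0; l < 8; l++) acc[l] += a * b[l];     /* 8 lane updates */
    }
    for (int l = 0; l < 8; l++) c_row[j + l] = acc[l];     /* store once */
}

/* ISA-independent outer loops: identical whichever kernel is passed in. */
static void gemm_with_ukernel(const float* A, const float* B, float* C,
                              int m, int n, int k, ukernel_fn kern) {
    memset(C, 0, (size_t)m * n * sizeof(float));
    for (int i = 0; i < m; i++) {
        const float* a_row = A + (size_t)i * k;
        float* c_row = C + (size_t)i * n;
        int j = 0;
        for (; j + 8 <= n; j += 8) kern(a_row, B, c_row, k, n, j);
        for (; j < n; j++) {                                 /* scalar tail, n % 8 columns */
            float s = 0.0f;
            for (int p = 0; p < k; p++) s += a_row[p] * B[(size_t)p * n + j];
            c_row[j] = s;
        }
    }
}
```

Calling gemm_with_ukernel(A, B, C, m, n, k, ukernel_ref_8) computes C = A · B; swapping the last argument is the only per-ISA change.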
The microkernel, traced
Consider one row of C being computed by the AVX2 microkernel, one 8-wide block at a time. Each K step broadcasts a single element of A across 8 SIMD lanes (the _mm256_set1_ps intrinsic), loads 8 contiguous elements of B, and FMA-accumulates into the register-resident chunk of C.
Three things to feel:
- The accumulator lives in a register, not in memory. Eight floats of C stay in one __m256 across the entire K loop; the store happens only at the very end. This is the same “carry partial results in lanes, reduce/store once” pattern as §1.2 (the dot product), generalized: now there are 8 independent dot products accumulating in parallel, one per N-lane.
- A is broadcast; B is vectorized. Each K step uses only one element of A but eight of B. The asymmetry isn’t accidental: it’s what makes the kernel SIMD-friendly given row-major B and C. (With everything column-major, the asymmetry flips: A is loaded as a vector along M and B is broadcast.)
- Each B load is stride-1. Each K step reads B[p][j..j+8]: 8 consecutive floats, 32 bytes, half of a typical 64-byte cache line, fetched by one vector load instead of eight scalar ones. Successive K steps move down B by a whole row (n floats), so the walk along K is strided; that is exactly why keeping a slab of B resident in cache matters.
Read the AVX kernel and the NEON twin side by side: they’re the same kernel modulo lane count. AVX2 fits 8 lanes per register, so the inner block is 8-wide; NEON fits 4, so 4-wide. The broadcast (_mm256_set1_ps / vdupq_n_f32), the vector load (_mm256_loadu_ps / vld1q_f32), and the FMA (_mm256_fmadd_ps / vfmaq_f32) translate one-for-one. The microkernel concept is portable; the dialect isn’t (Ch.1 §1.4 covered this).
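A third dialect makes the same point: GCC and Clang’s generic vector extensions let you write the broadcast, load and multiply-add once and leave the lowering (AVX2, two 4-wide NEON halves, or scalar code) to the compiler. A sketch assuming a GCC- or Clang-compatible compiler; row_block_8_generic is a hypothetical name, and whether a * b + c becomes a single fused FMA depends on the target and the compiler’s floating-point contraction settings:

```c
#include <stddef.h>
#include <string.h>

/* 8 floats in one generic vector (GCC/Clang extension), lowered per target. */
typedef float v8f __attribute__((vector_size(32)));

/* Same contract as row_block_8: c_row[j..j+8) += a_row[0..k) . B[0..k)[j..j+8). */
static inline void row_block_8_generic(const float* a_row, const float* B,
                                       float* c_row, int k, int n, int j) {
    v8f c;
    memcpy(&c, c_row + j, sizeof c);                   /* unaligned load of C block */
    for (int p = 0; p < k; p++) {
        float s = a_row[p];
        v8f a = {s, s, s, s, s, s, s, s};              /* broadcast A[i][p] */
        v8f b;
        memcpy(&b, B + (size_t)p * n + j, sizeof b);   /* B[p][j..j+8), contiguous */
        c = a * b + c;                                 /* lane-wise multiply-add */
    }
    memcpy(c_row + j, &c, sizeof c);                   /* store once */
}
```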
#include <immintrin.h>  /* AVX2 + FMA intrinsics; build with -mavx2 -mfma */

/* row_block_8: C[i][j..j+8] += A[i][0..k) · B[0..k)[j..j+8], with the eight
 * C values held in one register. Called over the full K range by the
 * un-tiled version and called per-K-block by the tiled one. */
static inline void
row_block_8(const float* a_row, const float* B, float* c_row,
            int k, int n, int j) {
    __m256 c = _mm256_loadu_ps(c_row + j);                  /* current C accumulator */
    for (int p = 0; p < k; p++) {
        __m256 a = _mm256_set1_ps(a_row[p]);                 /* broadcast A[i][p] across 8 lanes */
        __m256 b = _mm256_loadu_ps(B + (size_t)p * n + j);   /* B[p][j..j+8] */
        c = _mm256_fmadd_ps(a, b, c);                        /* c += a · b, 8 FMAs in one instr */
    }
    _mm256_storeu_ps(c_row + j, c);
}
That little function is the workhorse: 11 lines, one FMA per K-step, accumulator in a register. The outer gemm_micro_avx is just a loop nest over (i, j) calling it; the K-tiled version gemm_tiled_avx is one more loop on the outside that chunks K so a slab of B stays resident:
#include <string.h>  /* memset */

/* gemm_tiled_avx: C = A · B, all row-major (A is m×k, B is k×n, C is m×n). The
 * outer K-tile loop is the K-axis tiling from §2.2. For each K-tile we
 * call the same vector microkernel on every (i, j), which now reuses
 * the tile of B that was just touched, instead of streaming all of B
 * from main memory for every row of C. */
void gemm_tiled_avx(const float* A, const float* B, float* C,
                    int m, int n, int k, int kc) {
    memset(C, 0, (size_t)m * n * sizeof(float));
    if (kc <= 0) kc = 64;                                    /* sane default */
    for (int pc = 0; pc < k; pc += kc) {
        int kb = (pc + kc <= k) ? kc : (k - pc);             /* this tile's K extent */
        for (int i = 0; i < m; i++) {
            const float* a_row = A + (size_t)i * k + pc;     /* row strip of A */
            float* c_row = C + (size_t)i * n;
            int j = 0;
            for (; j + 8 <= n; j += 8) {
                row_block_8(a_row, B + (size_t)pc * n, c_row, kb, n, j);
            }
            for (; j < n; j++) {                             /* scalar tail: n % 8 columns */
                float s = 0.0f;
                for (int p = 0; p < kb; p++) s += a_row[p] * B[(size_t)(pc + p) * n + j];
                c_row[j] += s;                               /* add this K-tile's partial */
            }
        }
    }
}
Now run it
The benchmark (bench_gemm.c) does a 384×384×384 matmul three ways: naïve scalar, vectorized microkernel without blocking, and microkernel with K-blocking. On the Apple Silicon CI host:
benchmark: gemm N×N×N = 384 · K-tile = 64
wall (s) GFLOPS
naive (scalar, ijk) 0.0399 2.84
micro (SIMD, no blocking) 0.0111 10.20
tiled (SIMD, K-block= 64) 0.0065 17.49
all three agree.
speedup: micro vs naive = 3.59x ; tiled vs micro = 1.71x
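The GFLOPS column follows from the 2N³ FLOP count divided by wall time. A one-line check, assuming bench_gemm.c counts FLOPs that way (gemm_gflops is a hypothetical helper):

```c
/* GFLOP/s for an m x n x k matmul that took `seconds`:
 * 2*m*n*k FLOPs (one multiply + one add per inner step), scaled to 1e9. */
static double gemm_gflops(double m, double n, double k, double seconds) {
    return 2.0 * m * n * k / seconds / 1e9;
}
```

With N = 384 and 0.0399 s it returns about 2.84, matching the naïve row. The tiled row comes out near 17.4 from the printed 0.0065 s; the reported 17.49 presumably comes from the unrounded time.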
Three things in those numbers:
- 3.6× from vectorization alone. Naïve to microkernel: same algorithm, same arithmetic, just held in a SIMD register through the K loop. The FMA throughput was always there; the naïve version couldn’t reach it.
- 1.7× from K-blocking on top. Even at N = 384, where B is 384² × 4 B = 576 KiB (several times larger than a typical L1 data cache), blocking K to 64 keeps a slab of B in cache. At larger N the gap is expected to widen, since the untiled version streams ever more of B per row of C. This is the memory-hierarchy speedup, not a compute one.
- Still well short of peak. Apple Silicon’s NEON peak is ~50–100 GFLOPS per core for FP32, so 17.5 GFLOPS is roughly 17–35% of it. Closing that gap requires register-level tiling (MR × NR larger than 1×8), software pipelining, and packing the K-strip of A into contiguous memory before each microkernel call: the next layer of optimization, which is BLIS-grade engineering and out of scope here. The point of §4 is not “build the world’s fastest gemm.” The point is why the structural pattern looks the way it does.
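The cache arithmetic behind the K-blocking bullet, with k_c = 64 and N = 384 at 4 bytes per float. The slab the blocked run keeps hot is one sixth the size of B; note that 96 KiB still exceeds the 32–48 KiB L1 data caches common on x86 cores, where it would be served from L2 instead:

$$\underbrace{N \times N \times 4\ \text{B}}_{\text{all of } B} = 384 \times 384 \times 4\ \text{B} = 589{,}824\ \text{B} = 576\ \text{KiB}, \qquad \underbrace{k_c \times N \times 4\ \text{B}}_{\text{one K-slab}} = 64 \times 384 \times 4\ \text{B} = 98{,}304\ \text{B} = 96\ \text{KiB}$$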
Intensity = 2N³ / 3N² = (2/3)N flops per float touched. Asymptotically O(N), so matmul should be highly compute-bound at any reasonable size — exactly the property that makes GPUs scream on it.
The naïve loop fails because it touches B by streaming through the same elements N times — once per row of C — and at each pass the elements have been evicted from cache and must be re-fetched from main memory. Realised intensity drops to O(1). Tiling fixes this by reusing each loaded tile of B many times before evicting it.
The forward connection — why this is the FlashAttention substrate
Attention’s hot path is two matmuls: S = QKᵀ and O = softmax(S) V. Both have the same structural shape as the matmuls we just tiled. The catch: the softmax in between normalizes across the whole sequence-length axis, so you can’t split that axis into tiles and simply add the partial results; each tile’s contribution depends on the row’s maximum and normalizer, which aren’t known until every tile has been seen. FlashAttention’s contribution is exactly: use the gemm microkernel pattern, but carry the softmax running statistics across tiles. The architecture is gemm-microkernel + online-softmax bookkeeping.
So everything in §2.4 — register tiles, the K-block loop, the broadcast-and-FMA inner kernel — is the substrate you’ll inherit unchanged in Ch.13. The difference is one extra piece of state per tile, not a different overall shape. Once you’ve internalised the microkernel pattern, FlashAttention is incremental.
The arc of Ch.2, looking back. §1 said matrices are functions. §2 said matmul is composition with three axes. §3 said orthogonal matrices preserve dot products. §4 said the tiled microkernel is what makes matmul fast — and that pattern is FlashAttention’s substrate.
Stand them next to each other: “matrices are functions” is the algebra; “three axes” is the structure; “orthogonal preserves” is the geometry; “tiled microkernel” is the systems. Every later chapter that uses linear algebra leans on one of these four pieces, and several lean on more than one. Attention is matmul (§1) composed (§2) and tiled (§4); rotation-based quantization is dot products (§3) re-scored against a compressed (Ch.24) function (§1). Same four pieces, recombined.
Without K-tiling, the inner loop touches the whole matrix B once per row of C. The rows of B you read in the first K steps have been evicted from cache by the time you reach p = k − 1, so each new row of C re-fetches them from L2 or main memory: N passes over B in total.
K-tiling makes you process a slab of B (size kc × N) before moving to the next slab. When that slab fits in cache (ideally L1), you reuse it M times before evicting it. The partial-sum recombination is just addition (+= into C), which is cheap, and the saving in memory traffic pays for it many times over. The K-axis reduction being associative is what makes this trivial; it’s the same reason §1’s horizontal sum was easy. (FlashAttention’s reduction is not a plain sum: partial results must be rescaled whenever the running maximum changes, which is why it carries running statistics. The K-tiling skeleton is the same.)
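A back-of-envelope version of that trade, under a deliberately idealized model (an assumption, not a measurement): without K-tiling every row of C re-reads all of B from memory; with K-tiling each kc × N slab of B is read once and stays resident, but C is read and written once per K-tile. A’s traffic is ignored, and traffic_untiled and traffic_ktiled are hypothetical helpers:

```c
/* Idealized traffic model, in floats moved to/from memory (A ignored).
 * Untiled: every row of C re-streams all k x n of B; C written once. */
static long long traffic_untiled(long long m, long long n, long long k) {
    return m * k * n      /* B re-read for each of the m rows of C */
         + m * n;         /* C written once */
}

/* K-tiled: each kc x n slab of B read once and reused by all m rows;
 * C read and written once per K-tile. */
static long long traffic_ktiled(long long m, long long n, long long k, long long kc) {
    long long tiles = (k + kc - 1) / kc;   /* ceil(k / kc) K-tiles */
    return k * n                          /* all of B, once */
         + 2 * m * n * tiles;             /* C read + write per K-tile */
}
```

For N = 384 and kc = 64 the model gives 56,770,560 floats untiled versus 1,916,928 tiled, about 30× less traffic. The measured speedup was only 1.71×, a reminder of everything the model leaves out: A’s traffic, partial cache hits in the untiled run, and the compute ceiling.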
The microkernel computes 8 columns of C at once — one SIMD register holds 8 partial sums, one per output column. At each K step we add A[i][p] · B[p][j..j+8] to that register. The 8 lanes of B are the 8 different columns of C we’re updating; they need to be vectorized. The single element of A is the same scalar for all 8 columns, so it gets broadcast.
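If everything were column-major (as in Fortran BLAS): A[·][p] would be contiguous and could be loaded as a vector, while B[p][·] would be strided. The asymmetry would flip: vectorize the M axis (8 rows of C at once) and broadcast B. The microkernel’s shape mirrors the layout of the data it is fed. Production BLAS avoids writing a kernel per layout: the CBLAS row-major interface is typically implemented by calling the column-major routine on the transposed problem (Cᵀ = BᵀAᵀ), and libraries such as GotoBLAS and BLIS pack blocks of A and B into contiguous panels so one microkernel serves every input layout.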
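The flipped kernel, sketched portably with a plain array standing in for SIMD lanes. It assumes all three matrices are column-major (element (i, j) of an r-row matrix at index i + j*r); col_block_8 is a hypothetical name. Now the 8 lanes are 8 rows of one column of C, A supplies the contiguous vector, and B supplies the broadcast scalar:

```c
#include <stddef.h>

/* Column-major microkernel: C[i..i+8)[j] += sum_p A[i..i+8)[p] * B[p][j].
 * A is m x k (leading dim m), B is k x n (leading dim k), C is m x n (leading dim m). */
static void col_block_8(const float* A, const float* B, float* C,
                        int m, int k, int i, int j) {
    float acc[8];
    float* c_col = C + (size_t)j * m + i;             /* 8 contiguous rows of column j */
    for (int l = 0; l < 8; l++) acc[l] = c_col[l];
    for (int p = 0; p < k; p++) {
        const float* a = A + (size_t)p * m + i;       /* A[i..i+8)[p]: contiguous vector */
        float b = B[(size_t)j * k + p];               /* B[p][j]: broadcast scalar */
        for (int l = 0; l < 8; l++) acc[l] += a[l] * b;
    }
    for (int l = 0; l < 8; l++) c_col[l] = acc[l];    /* store once */
}
```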
END OF CH.2 — Matrices as transformations.
§1 (matrices as functions) · §2 (matmul as composition, the three axes) · §3 (orthogonal matrices preserve dot products) · §4 (tiled microkernel — FlashAttention’s substrate).
All four sections compile and run in CI. The chapter’s nine recall items chain back through Ch.1 and forward to Ch.13 (FlashAttention) and Ch.25 (TurboQuant). Coming next: Ch.3 — Floating point, integers, and quantization error.