TRAINING AT SCALE
Section 24.3

Context parallelism + the multi-node systems problem

By 2024, frontier models train with 128K+ context windows, and some research targets 1M+ tokens. At those lengths, even FlashAttention is constrained by single-GPU memory: the activations for one long sequence take more memory than a 70B model’s weights. Context parallelism (CP) is the parallelism axis that splits the SEQUENCE dimension across GPUs. Ring Attention (Liu 2023) combines context parallelism with FlashAttention’s tile-and-stream pattern to enable training at unprecedented context lengths. This section walks through CP, then steps back to summarise the multi-node systems-engineering picture and close Part V.

Why context parallelism is needed

The activation memory for one transformer layer at sequence length N and hidden size d:

Per-layer activation memory at sequence length N (decoder-only, 2 bytes per element):

  QKV projection input/output:        4 · N · d · 2 bytes = 8 · N · d
  Attention score matrix (or stream): N² · 2 bytes (avoided by FlashAttention, but intermediate tiles still need SRAM)
  Attention output:                   N · d · 2 bytes
  FFN intermediate (4× expansion):    4 · N · d · 2 bytes
  FFN output:                         N · d · 2 bytes
  Approximate total per layer:        ~20 · N · d bytes for forward activations

For Llama 3 70B at N = 1M, d = 8192:

  Per layer: 20 · 10^6 · 8192 ≈ 164 GB
  × 80 layers (with checkpointing, store only ~10): ~1.6 TB just for activations
  Memory budget on H100: 80 GB. Doesn't fit. Need to SPLIT THE SEQUENCE.
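The estimate above is easy to reproduce. The following is a minimal sketch of the same arithmetic; the function name and its parameters are illustrative, and the coefficients are simply the ones listed in the text (score matrix excluded, 2 bytes per element).

```python
def activation_bytes_per_layer(n_tokens: int, d_model: int, bytes_per_elem: int = 2) -> int:
    """Rough forward-activation footprint of one decoder layer (no score matrix)."""
    qkv = 4 * n_tokens * d_model        # QKV projection input + outputs
    attn_out = n_tokens * d_model       # attention output
    ffn_mid = 4 * n_tokens * d_model    # FFN intermediate at 4x expansion
    ffn_out = n_tokens * d_model        # FFN output
    return (qkv + attn_out + ffn_mid + ffn_out) * bytes_per_elem


per_layer = activation_bytes_per_layer(n_tokens=1_000_000, d_model=8192)
stored_layers = 10                      # layers kept under checkpointing, as in the text

print(f"per layer:       {per_layer / 1e9:.0f} GB")
print(f"stored in total: {per_layer * stored_layers / 1e12:.2f} TB")
print(f"one layer fits in 80 GB: {per_layer <= 80e9}")
```

Changing `n_tokens` to 128 * 1024 shows why 128K contexts are already uncomfortable on a single 80 GB device once weights and optimizer state are counted too.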

The fix: split the sequence into chunks, each on a different GPU.

Ring Attention

Liu 2023 “Ring Attention with Blockwise Transformers for Near-Infinite Context” introduced the canonical context parallelism algorithm.

Ring Attention setup:

  CP_size GPUs in a logical ring.
  Sequence of N tokens split into CP_size CHUNKS of N/CP_size tokens each.
  Each GPU holds: its chunk of Q, K, V plus all model weights.

Attention computation: every Q chunk needs to attend to every K, V chunk. The "ring" approach:

  - Each GPU holds Q chunk i, K chunk i, V chunk i locally.
  - Each GPU computes Q_i · K_i^T (its local diagonal attention).
  - Send K_i, V_i to the NEXT GPU in the ring; receive K_{i-1}, V_{i-1} from the previous one.
  - Compute Q_i · K_{i-1}^T, accumulate.
  - Keep rotating, CP_size - 1 rotations in total, until all (Q_i, K_j, V_j) pairs are computed.

At each step, use FlashAttention's online softmax recurrence (Ch.12 §3) to combine the new (K, V) chunk's attention contribution with the running result.

Memory per GPU: only its own Q, K, V chunks (N/CP_size · d elements each) plus model weights. For CP=8 and N=1M: per-GPU sequence chunk is 125K tokens, which is much more manageable.

Communication: K, V tensors rotate around the ring.

  Per rotation, each GPU sends and receives one K chunk and one V chunk: 2 · (N/CP_size) · d · 2 bytes.
  Total per layer, forward: (CP_size - 1) · 2 · (N/CP_size) · d · 2 bytes (approaching 4 · N · d bytes for large CP_size).
  For N=1M, d=8192, CP=8: 7 rotations · ~4.1 GB ≈ 29 GB per layer per GPU.
  At NVLink 900 GB/s: ~32 ms per layer forward, roughly double for forward + backward.
  ~64 ms × 80 layers ≈ 5 s of communication per step if none of it is hidden behind compute.

Significant but tractable.
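To make the rotation pattern concrete, here is a small sketch that only tracks chunk indices, with no tensors involved. The helper name `ring_schedule` is made up for this illustration; it assumes every GPU forwards to its right-hand neighbour each round, as described above.

```python
def ring_schedule(cp_size: int) -> list[list[int]]:
    """schedule[r][i] = index of the K/V chunk that GPU i holds during round r."""
    return [[(i - r) % cp_size for i in range(cp_size)] for r in range(cp_size)]


schedule = ring_schedule(4)
for r, held in enumerate(schedule):
    print(f"round {r}: " + "  ".join(f"GPU{i}<-KV{j}" for i, j in enumerate(held)))

# Across round 0 (local) plus cp_size - 1 rotations, each GPU sees every chunk once.
for i in range(4):
    assert sorted(row[i] for row in schedule) == [0, 1, 2, 3]
```

Round 0 is the local block; each later round is one rotation, so CP_size - 1 transfers are enough for full coverage.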

Context parallelism is the newest parallelism axis. It’s specifically needed when sequence length is the binding memory constraint, which is increasingly common in 2024+ models.

— think, then check —

Setup:

8 GPUs in a logical ring. Sequence of 1M tokens split into 8 chunks of 125K each.

GPU i holds Q_i (125K × d), K_i (125K × d), V_i (125K × d) — its local chunk.

Goal: each Q_i must attend to all 8 K, V chunks, eventually producing attention output for its 125K tokens.

Round 0 (local):

GPU i computes Q_i · K_i^T (its local diagonal block of the attention matrix).

Apply softmax (with running statistics — the online softmax from §12.3, since later rounds will contribute more).

Compute Q_i · K_i^T → weights → weights · V_i → partial output O_i.

Maintain (m, ℓ, O) running state per query row.

Rounds 1 through 7:

Step 1: GPU i sends the K, V chunk it currently holds to GPU (i+1) mod 8 and receives a chunk from GPU (i-1) mod 8. In round r, GPU i therefore holds K_((i-r) mod 8), V_((i-r) mod 8); in round 1 that is K_(i-1), V_(i-1).

Step 2: Compute Q_i · K_((i-r) mod 8)^T → partial scores. Update the running (m, ℓ, O) with this contribution (using the §12.3 online softmax recurrence).

Repeat: rotate K, V chunks around the ring. After 7 rotations, every GPU has computed attention from its Q to every K, V chunk in the sequence.

Final:

Each GPU i holds the unnormalised accumulator O_i for its 125K queries, together with the running sum ℓ_i. Divide by ℓ_i to normalise.

Output Y_i = O_i / ℓ_i. Each GPU holds its slice of the output sequence.
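The walk-through above can be checked end to end on a toy problem. The sketch below simulates the ring inside one Python process with plain lists. It is an illustration under simplifying assumptions, not a real multi-GPU implementation: a single head, no causal mask, no 1/sqrt(d) scaling, and `cp_size` must divide the sequence length. The function names are invented for this example.

```python
import math
import random


def full_attention(Q, K, V):
    """Reference: softmax(Q K^T) V with all keys in one place."""
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        z = sum(p)
        out.append([sum(pi * v[c] for pi, v in zip(p, V)) / z for c in range(len(V[0]))])
    return out


def ring_attention(Q, K, V, cp_size):
    """Simulated ring: 'GPU' i owns chunk i of Q and sees one K/V chunk per round."""
    n, d = len(Q), len(V[0])
    size = n // cp_size                                  # assumes cp_size divides n
    chunk = lambda X, j: X[j * size:(j + 1) * size]
    out = []
    for i in range(cp_size):
        state = [(-math.inf, 0.0, [0.0] * d) for _ in range(size)]   # (m, l, O) per row
        for r in range(cp_size):                         # round r holds chunk (i - r) mod cp
            Kc, Vc = chunk(K, (i - r) % cp_size), chunk(V, (i - r) % cp_size)
            for row, q in enumerate(chunk(Q, i)):
                m, l, O = state[row]
                s = [sum(a * b for a, b in zip(q, k)) for k in Kc]
                m_new = max(m, max(s))
                alpha = math.exp(m - m_new)              # rescales the old state
                p = [math.exp(x - m_new) for x in s]
                l = alpha * l + sum(p)
                O = [alpha * o + sum(pi * v[c] for pi, v in zip(p, Vc))
                     for c, o in enumerate(O)]
                state[row] = (m_new, l, O)
        out.extend([[o / l for o in O] for (_, l, O) in state])      # Y = O / l
    return out


random.seed(0)
n, d = 8, 4
Q, K, V = ([[random.gauss(0, 1) for _ in range(d)] for _ in range(n)] for _ in range(3))
ref, ring = full_attention(Q, K, V), ring_attention(Q, K, V, cp_size=4)
max_err = max(abs(a - b) for ra, rb in zip(ref, ring) for a, b in zip(ra, rb))
print(f"max abs difference vs. full attention: {max_err:.2e}")
```

The two results agree up to floating-point rounding, which is the point of carrying (m, ℓ, O) from round to round instead of normalising each block on its own.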

Memory per GPU during computation:

  • Own Q, K, V chunks: 3 × 125K × d × 2 bytes = ~6 GB (for d=8192).
  • One additional K, V pair being rotated: 2 × 125K × d × 2 = ~4 GB.
  • Running state (m, ℓ, O): 125K × d × 2 bytes ≈ 2 GB.
  • Total: ~12 GB per GPU for attention. Plus model weights, etc.

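This is far less than one GPU would need for the full sequence. The other per-layer activations shrink by the same factor of CP: the ~164 GB per layer estimated earlier becomes roughly 20 GB per GPU at CP=8.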
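The per-GPU attention buffers listed above can be tallied with a few lines. This is a sketch of the same bookkeeping; the function and dictionary keys are illustrative names, and the per-row m and ℓ scalars are left out because they are negligible next to the d-wide tensors.

```python
def cp_attention_bytes(n_tokens: int, d_model: int, cp_size: int, bytes_per_elem: int = 2) -> dict:
    chunk = (n_tokens // cp_size) * d_model * bytes_per_elem   # one (N/CP x d) tensor
    return {
        "own_qkv": 3 * chunk,        # local Q, K, V
        "in_flight_kv": 2 * chunk,   # K, V pair arriving from the ring
        "running_o": chunk,          # accumulator O
    }


parts = cp_attention_bytes(1_000_000, 8192, cp_size=8)
total = sum(parts.values())
for name, nbytes in parts.items():
    print(f"{name:>13}: {nbytes / 1e9:.1f} GB")
print(f"{'total':>13}: {total / 1e9:.1f} GB")
```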

Communication cost:

Per layer, forward: each rotation moves one K chunk and one V chunk, 2 × 125K × d × 2 bytes ≈ 4 GB; × 7 rotations ≈ 29 GB of K, V traffic per GPU.

Per layer (forward + backward): roughly double, ~57 GB.

At NVLink 900 GB/s: ~64 ms per layer. Per step (80 layers): ~5 seconds.

That’s significant overhead. Modern implementations overlap K, V transfer with compute, reducing effective overhead to ~30%.

For 1M-context training, this is the price paid. CP is the only way to fit; communication is the price.
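A back-of-the-envelope calculator for this traffic is sketched below. It states its own assumptions: both K and V are counted, there are CP - 1 rotations per layer, the backward pass is assumed to cost about as much again as the forward pass, and nothing overlaps with compute. The function name and parameters are illustrative.

```python
def ring_comm(n_tokens, d_model, cp_size, link_gb_per_s, n_layers, bytes_per_elem=2):
    kv_chunk = 2 * (n_tokens // cp_size) * d_model * bytes_per_elem   # K and V for one chunk
    fwd = (cp_size - 1) * kv_chunk        # bytes each GPU sends per layer, forward pass
    fwd_bwd = 2 * fwd                     # ASSUMPTION: backward roughly doubles the traffic
    per_layer_s = fwd_bwd / (link_gb_per_s * 1e9)
    return fwd, fwd_bwd, per_layer_s, per_layer_s * n_layers


fwd, fwd_bwd, per_layer_s, step_s = ring_comm(
    n_tokens=1_000_000, d_model=8192, cp_size=8, link_gb_per_s=900, n_layers=80
)
print(f"forward, per layer:      {fwd / 1e9:.1f} GB")
print(f"fwd + bwd, per layer:    {fwd_bwd / 1e9:.1f} GB ({per_layer_s * 1e3:.0f} ms)")
print(f"per step, no overlap:    {step_s:.1f} s")
```

Because the cost is linear in `d_model`, a model whose K and V are narrower than the hidden size (for example through grouped-query attention) would move proportionally less data; the figures here follow the text in using d = 8192 for both.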

How everything combines — Llama 3 / DeepSeek V3 in practice

Llama 3 70B 128K-context training (rough, illustrative config):

  DP = 16  (across nodes, batch parallelism)
  PP = 4   (4-stage pipeline within node groups)
  TP = 8   (within node, tensor parallel each attn/FFN layer)
  CP = 4   (split 128K sequence into 4 × 32K chunks)

  GPUs per model replica = PP · TP · CP = 4 · 8 · 4 = 128.
  Total = DP · 128 = 16 · 128 = 2048 GPUs. With more GPUs, add more DP replicas.

DeepSeek V3 671B MoE training (PP, TP and EP values as given in the DeepSeek-V3 technical report):

  DP = X      (ZeRO-1 data parallelism)
  PP = 16
  TP = 1      (no tensor parallelism during training)
  EP = 64     (expert parallel, spanning 8 nodes)
  CP = 1-2    (depends on sequence length)

Each parallelism axis has its OWN bandwidth requirements that map to the cluster topology. Designing the configuration is a multi-dimensional optimisation problem solved by internal lab tools.
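A quick sanity check on how the axes multiply, using the illustrative Llama-style values from the text. The dictionary name is made up for this sketch; note that the product of all four axes is the total GPU count, while one model replica spans only the non-DP axes.

```python
from math import prod

llama_like = {"DP": 16, "PP": 4, "TP": 8, "CP": 4}   # illustrative values from the text

gpus_total = prod(llama_like.values())
gpus_per_replica = gpus_total // llama_like["DP"]     # PP * TP * CP
tokens_per_cp_rank = 128 * 1024 // llama_like["CP"]   # 128K sequence split over CP ranks

print(f"GPUs per model replica: {gpus_per_replica}")
print(f"total GPUs:             {gpus_total}")
print(f"tokens per CP rank:     {tokens_per_cp_rank}")
```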

Closing Part V — the systems engineering picture

We’ve covered the full systems stack of running LLMs: hardware, runtimes, inference, training, and quantization.

The big picture: a frontier LLM is a systems engineering project as much as it is a machine learning project. The math (Parts I-III) was the foundation. The architecture (Part III/IV) is the model. But what TURNS BOTH into something that costs $100M to train and runs at 1000 tokens/sec for users is the systems work covered in Part V.

— think, then check —

The challenge:

Each GPU sees only 1/CP of the K, V tensors at a time. It needs to compute attention scores and apply softmax — but softmax requires the maximum and sum across the FULL sequence of scores.

Naive approach: gather all K, V to one GPU, compute softmax, scatter back. Memory infeasible (defeats CP).

Alternative: each GPU computes softmax just over its local block. But then attention is no longer correct (uses local max instead of global).

The online softmax fix:

Each GPU maintains running (m, ℓ, O) per query row across the sequence of K, V chunks it processes:

  • m: running max of scores seen so far for this row.
  • ℓ: running sum of exp(score - m) over the scores seen so far.
  • O: running weighted sum of V.

When a new K, V chunk arrives from the ring:

1. Compute partial scores Q_local · K_chunk^T.

2. Get the local max m_chunk, the unnormalised weights P_chunk = exp(scores - m_chunk), and the local sum ℓ_chunk (the row sum of P_chunk).

3. Apply the online recurrence:

m_new = max(m, m_chunk)

α = exp(m - m_new) (rescale factor for old state)

β = exp(m_chunk - m_new) (rescale factor for new contribution)

ℓ_new = α · ℓ + β · ℓ_chunk

O_new = α · O + β · (P_chunk · V_chunk)

4. Update m, ℓ, O.

After processing all CP_size chunks: the running (m, ℓ, O) is mathematically identical to what you’d compute with all K, V in one place.

Final attention output: Y = O / ℓ.
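The recurrence can be written as a merge of two partial states, which also makes it easy to see that the grouping of chunks does not matter. The sketch below works on a single query row with hand-picked scores; `chunk_state` and `merge` are illustrative names, and P_chunk is taken to be exp(scores - m_chunk), matching the β rescaling in step 3.

```python
import math
from functools import reduce


def chunk_state(scores, values):
    """Local statistics for one K/V chunk: (m_chunk, l_chunk, P_chunk . V_chunk)."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]             # P_chunk, relative to the local max
    o = [sum(pi * v[c] for pi, v in zip(p, values)) for c in range(len(values[0]))]
    return m, sum(p), o


def merge(a, b):
    """Combine two partial states with the alpha/beta rescaling."""
    (m_a, l_a, o_a), (m_b, l_b, o_b) = a, b
    m_new = max(m_a, m_b)
    alpha, beta = math.exp(m_a - m_new), math.exp(m_b - m_new)
    return m_new, alpha * l_a + beta * l_b, [alpha * x + beta * y for x, y in zip(o_a, o_b)]


scores = [2.0, -1.0, 0.5, 3.0, 0.0, 1.5]              # one query row against 6 keys
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 1.0], [0.5, 0.5], [3.0, -3.0]]
chunks = [chunk_state(scores[i:i + 2], values[i:i + 2]) for i in (0, 2, 4)]

m, l, o = reduce(merge, chunks)                       # left to right, as chunks arrive
y_online = [x / l for x in o]

m2, l2, o2 = merge(chunks[0], merge(chunks[1], chunks[2]))   # a different grouping
y_regrouped = [x / l2 for x in o2]

z = sum(math.exp(s) for s in scores)                  # direct softmax for comparison
y_direct = [sum(math.exp(s) * v[c] for s, v in zip(scores, values)) / z for c in range(2)]

print(y_online, y_regrouped, y_direct)
```

All three outputs agree up to rounding, so each GPU can process the chunks in whatever order the ring delivers them.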

What would fail without online softmax:

You’d compute softmax over each K, V chunk in isolation. The weights would be normalised to the LOCAL maximum of that chunk, not the global maximum. Combining them after the fact would require careful re-normalisation that’s mathematically equivalent to the online recurrence.

Without the online recurrence, you’d need to:

  • Either materialise ALL scores then softmax (memory infeasible).
  • Or use approximations (correctness lost).

The online softmax is the bridge that makes “compute softmax in pieces” mathematically equivalent to “compute softmax over the whole.” Without it, Ring Attention can’t be exact.

The full assembly:

Ring Attention is: Context Parallelism (sequence split) + Ring rotation (K, V transfer pattern) + Online Softmax (mathematical correctness across chunks) + FlashAttention (efficient SRAM-staged computation within each chunk).

Each piece comes from a different paper, but they compose cleanly because the online softmax recurrence is associative (Ch.12 §3). The Ch.12 work — what felt like a small efficiency trick — became the structural enabler of training models with 1M+ token contexts.

— think, then check —

1. Hardware budget:

  • 100B model in mixed-precision needs ~1.6 TB total state.
  • At Chinchilla-optimal D = 20·N = 2T tokens → 6·N·D ≈ 1.2·10²⁴ FLOPs.
  • On H100 at 50% MFU (~1 PFLOP/s dense BF16 peak per GPU): ~500 PFLOP/s sustained on 1000 H100s → ~4 weeks.
  • Budget: ~0.7M GPU-hours, roughly $1.5-3M of compute at an assumed $2-4 per GPU-hour. + engineering team + data pipeline + storage.
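The budget arithmetic is worth scripting, since a dropped exponent changes the answer by an order of magnitude. This is a rough sketch: the 6·N·D rule is the usual training-FLOPs approximation, the 16 bytes per parameter matches the ~1.6 TB state figure, and the per-GPU peak of 989 TFLOP/s (H100 SXM, dense BF16) is an assumption of this example rather than a number from the text.

```python
params = 100e9
tokens = 20 * params                     # Chinchilla-style rule of thumb from the text
train_flops = 6 * params * tokens        # standard 6*N*D approximation
state_bytes = 16 * params                # weights + grads + Adam state in mixed precision

peak_flops_per_gpu = 989e12              # ASSUMPTION: H100 SXM dense BF16 peak
mfu = 0.5
n_gpus = 1000
sustained = peak_flops_per_gpu * mfu * n_gpus

seconds = train_flops / sustained
weeks = seconds / (7 * 24 * 3600)
gpu_hours = n_gpus * seconds / 3600

print(f"training FLOPs:  {train_flops:.2e}")
print(f"training state:  {state_bytes / 1e12:.1f} TB")
print(f"sustained:       {sustained / 1e15:.0f} PFLOP/s")
print(f"wall clock:      {weeks:.1f} weeks on {n_gpus} GPUs")
print(f"GPU-hours:       {gpu_hours:,.0f}")
```

Lowering `mfu` to a more conservative value stretches the schedule proportionally, which is usually the first sensitivity worth checking.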

2. Cluster topology:

  • 1024 H100s = 128 nodes × 8 GPUs/node.
  • NVLink within node (900 GB/s); InfiniBand between (50 GB/s).
  • Need to map parallelism to bandwidth hierarchy.

3. Parallelism configuration:

  • TP = 8 (intra-node, splits each matmul). Communication frequent → needs NVLink.
  • PP = 4 (split layers across 4 node-groups). Communication per-microbatch → can use InfiniBand.
  • DP = 1024 / (8 · 4) = 32 (data parallel replicas). Per-step all-reduce → InfiniBand acceptable.
  • Microbatches K = 64-128 for good PP efficiency.
  • Effective global batch: 32 × 64 microbatches × 32K tokens = 64M tokens per step.
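The mapping above reduces to a few integer relations that are handy to assert before launching anything. The sketch below assumes each microbatch is one 32K-token sequence and uses the common (PP - 1) / (microbatches + PP - 1) estimate for the pipeline bubble of a GPipe/1F1B-style schedule; both are assumptions of this example.

```python
n_gpus, gpus_per_node = 1024, 8
tp, pp = 8, 4                              # values from the configuration above
microbatches, seq_len = 64, 32 * 1024      # ASSUMPTION: one 32K-token sequence per microbatch

assert n_gpus % (tp * pp) == 0, "TP * PP must divide the GPU count"
assert tp <= gpus_per_node, "TP traffic should stay on NVLink inside one node"

dp = n_gpus // (tp * pp)
global_batch_tokens = dp * microbatches * seq_len
bubble = (pp - 1) / (microbatches + pp - 1)    # idle fraction of the pipeline schedule
steps = 2e12 / global_batch_tokens             # optimizer steps to consume 2T tokens

print(f"DP replicas:          {dp}")
print(f"global batch:         {global_batch_tokens / 2**20:.0f}M tokens")
print(f"pipeline bubble:      {bubble:.1%}")
print(f"steps for 2T tokens:  {steps:,.0f}")
```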

4. Memory optimisation:

  • FSDP / ZeRO Stage 1 across the DP replicas of each TP × PP shard (shards the Adam state).
  • Activation checkpointing (recompute during backward instead of storing).
  • bf16 training with fp32 master copy.

5. Data pipeline:

  • 2T tokens of high-quality data — FineWeb + curated additions (Ch.16).
  • Tokenization, dedup (MinHash), quality classification.
  • Streaming dataloader: shards across DP replicas; no full data ever loaded.

6. Optimization recipe:

  • AdamW with β1=0.9, β2=0.95, weight decay 0.1.
  • Linear warmup (1K steps) → cosine decay over remaining steps.
  • Peak LR ~1.5e-4 for 100B model.
  • Gradient clipping at 1.0.
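The warmup-then-cosine schedule in the recipe can be sketched as a pure function of the step index. The name `lr_at` and the `min_lr` floor are assumptions of this example (the text does not say what the schedule decays to); the peak and warmup length are the ones quoted above.

```python
import math


def lr_at(step: int, total_steps: int, peak_lr: float = 1.5e-4,
          warmup_steps: int = 1000, min_lr: float = 0.0) -> float:
    """Linear warmup to peak_lr, then cosine decay to min_lr at total_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))


total = 30_000
for step in (0, 500, 999, 1000, 15_500, 30_000):
    print(f"step {step:>6}: lr = {lr_at(step, total):.3e}")
```

In a training loop this value would be written into the optimizer's learning rate once per step, alongside AdamW's weight decay and the gradient-norm clip.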

7. Monitoring and recovery:

  • Checkpoint every ~30-60 minutes (state too big to checkpoint more often).
  • Monitor loss curves, gradient norms, learning rate.
  • Have plan for OOM / NCCL hangs / GPU failure (~5-10 GPU failures per week at this scale).
  • Use a tool like W&B or Tensorboard for distributed metrics.

8. Hyperparameters at scale:

  • Maximum learning rate often empirically derived from smaller-scale runs (Hoffmann scaling).
  • Warmup length proportional to total steps.
  • RoPE base scaling for long contexts.

9. Inference plan:

  • At end of training, you have a 100B fp16 model = 200 GB.
  • For serving: quantize to int8 or int4 for cheaper deployment.
  • Continuous batching, PagedAttention, speculative decoding.
  • Or distil to a smaller model for cheaper inference.
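The effect of quantization on the serving footprint is simple to estimate. This sketch counts weight storage only; it ignores the KV cache, activations, and the small overhead of quantization scales, and the 80 GB device size is just the H100 figure used earlier. The names are illustrative.

```python
import math

BYTES_PER_PARAM = {"fp16": 2.0, "int8": 1.0, "int4": 0.5}


def weight_gb(n_params: float, fmt: str) -> float:
    """Weight storage only; ignores KV cache, activations and quantization scales."""
    return n_params * BYTES_PER_PARAM[fmt] / 1e9


for fmt in BYTES_PER_PARAM:
    gb = weight_gb(100e9, fmt)
    print(f"{fmt}: {gb:.0f} GB of weights, at least {math.ceil(gb / 80)} x 80 GB GPUs")
```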

10. The whole picture:

From the math up through the systems: every layer of the stack matters. The 100B-parameter training run is shaped by communication primitives (Ch.21), memory hierarchy (Ch.21), framework efficiency (Ch.22), inference economics (Ch.23), parallelism choices (Ch.24), and quantization for deployment (Ch.25).

The trajectory of LLM scale (1B → 7B → 70B → 671B) has been enabled by systems-engineering breakthroughs as much as by ML breakthroughs. FlashAttention enabled long context; ZeRO enabled large models; PagedAttention enabled serving economics; speculative decoding enabled low-latency generation; quantization enabled cheap deployment. Each one unlocked the next scale.

This is the world a senior systems engineer is moving INTO. The math wasn’t optional; it was the foundation. The systems engineering is where the work happens.

END OF CH.24 — Training at scale.
§1 (DP + ZeRO/FSDP: sharding the bookkeeping, the all-reduce cost) · §2 (TP + PP + EP: 3D parallelism grid, the pipeline bubble) · §3 (CP + Ring Attention: 1M+ token training, closing the systems picture).

END OF PART V — The Systems That Run Them. Hardware, runtimes, inference, training, quantization. The mathematics we built up in Parts I-IV is RUNNABLE at scale because of these systems contributions. The book’s technical content is now complete.

Coming next: Part VI. Ch.26 — Vector search and ANN, including the rotation-based quantization that started the conversation that began this book. Ch.27 — Reading research like a researcher, and the close.