QUANTIZATION IN PRACTICE
Section 24.3

Quantization-aware training — STE, BitNet, QLoRA

§24.1 and §24.2 covered post-training quantization (PTQ): train in fp16, then quantize for deployment. That gets you to ~4 bits with a modest perplexity increase. Below 4 bits, PTQ degrades sharply and you need a different approach: quantization-aware training (QAT) puts the quantizer inside the training loop. The forward pass uses quantized weights; gradients flow through as if the quantizer didn’t exist (the straight-through estimator, STE). The network learns weights that are quantization-friendly. This unlocks 2-bit and even 1-bit weights with much smaller quality loss. The same trick, STE through a non-differentiable operation, also powers BitNet (1.58-bit models trained from scratch). QLoRA, the technique that lets you fine-tune a 70B model on a single 80 GB GPU, is a close relative: it trains fp16 adapters on top of a frozen 4-bit base. This section closes the chapter by walking the math, the kernel, and the production techniques.

The straight-through estimator (STE)

The quantization operation Q(W) = scale · round(W / scale) has zero derivative almost everywhere (it’s a step function) and undefined derivative at the step boundaries. Plain backprop through it would produce no gradient. The straight-through estimator sidesteps this by pretending in the backward pass that Q is the identity:

Forward pass:
  W_q = Q(W)                 (quantize: step function)
  Y   = X · W_q              (matmul uses quantized weights)

Backward pass:
  ∂L/∂W_q = X^T · ∂L/∂Y      (standard matmul gradient)
  ∂L/∂W   = ∂L/∂W_q          ← STE: pretend Q is identity

Equivalently: ∂Q/∂W ≈ 1 (everywhere)

The STE is “wrong” — Q’s true derivative isn’t 1. But it’s wrong in a way that works empirically: the gradient still points in roughly the right direction (small perturbations of W produce small perturbations of Q(W) on average, even though pointwise it jumps), and the network learns weights that, when quantized, give good outputs.

The fake-quantize-in-forward + STE-in-backward pattern is the entire idea behind QAT. The kernel below trains a 2-layer MLP two ways and measures the deployment cost:

qat_ste.c — fake_quantize_int4 C · QAT vs PTQ on a tiny MLP
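#include <math.h>   /* fabsf, roundf */

#ifndef QK
#define QK 32       /* assumed block size; the excerpt does not show the real value */
#endif

/* Symmetric int4 fake-quantize (per-block of QK): quantize each block to the
 * int4 grid, then immediately dequantize, so W_fq holds fp32 values that are
 * exactly representable in int4. */
static void fake_quantize_int4(const float* W, float* W_fq, int n) {
    for (int b = 0; b < n; b += QK) {
        int end = b + QK; if (end > n) end = n;
        /* Per-block scale from the largest magnitude (absmax). */
        float amax = 0;
        for (int i = b; i < end; i++) {
            float a = fabsf(W[i]); if (a > amax) amax = a;
        }
        float scale = amax / 7.0f;
        if (scale == 0) scale = 1.0f;   /* all-zero block: avoid dividing by zero */
        for (int i = b; i < end; i++) {
            int q = (int)roundf(W[i] / scale);
            if (q >  7) q =  7;         /* clamp to the int4 range [-8, 7] */
            if (q < -8) q = -8;
            W_fq[i] = q * scale;        /* dequantize back to fp32 */
        }
    }
}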

/* Generate a synthetic regression dataset: y = sin(sum(x)) + noise. */

Output:

Mode 1: train in fp32 (no QAT)
  fp32  ep   0  loss = 0.79713
  fp32  ep 100  loss = 0.37293
  fp32  ep 199  loss = 0.34471

Mode 2: train with QAT (forward pass uses Q(W) via STE)
   qat  ep   0  loss = 0.80807
   qat  ep 100  loss = 0.37556
   qat  ep 199  loss = 0.34672

--- Final eval ---
Model              fp32 deploy   int4 deploy   gap
trained fp32       0.34449       0.35018       +1.65%
trained QAT (STE)  0.34799       0.34424       -1.08%

Read carefully: the fp32-trained model has loss 0.344 in fp32 deployment and 0.350 in int4, a 1.65% degradation. The QAT-trained model has slightly worse fp32 loss (0.348), but its loss does not degrade when deployed in int4; it actually improves slightly, to 0.344 (−1.08%). On the deployment metric (int4), QAT wins: 0.344 versus 0.350. It traded a tiny bit of fp32 quality for a model whose weights are pre-adapted to be quantized.

— think, then check —

The mathematical problem:

The quantizer Q(W) = scale · round(W / scale) is a step function. Its derivative is 0 almost everywhere and undefined at the step transitions. A pure chain-rule backprop through it would produce zero gradient and the network couldn’t learn.

What STE does:

Pretend Q’s derivative is 1 in the backward pass. The forward uses Q(W); the backward computes gradients as if Q were just the identity.

This is provably “wrong” — Q is not the identity. But for the purpose of training, it gives a workable signal.

Why it works:

(1) On AVERAGE, Q approximates the identity. Any single weight either stays in its bin or jumps a full step, but averaged over where W sits within its quantization bin, E[Q(W + δ) − Q(W)] ≈ δ: a small perturbation of W produces, in expectation, the same small perturbation of Q(W). So while the pointwise derivative is wrong, the expected derivative behaves like 1.

(2) The loss landscape isn’t sharp. Quantization noise produces local fluctuations in the loss, but the BROAD landscape (averaged over many quantization-boundary crossings) is smooth. The STE gradient points in the right direction in this smoothed sense.

(3) The fp32 “shadow” weights matter, not the gradient at the quantization boundary itself. The optimiser updates the fp32 weights W; the quantized weights Q(W) are derived. As W moves smoothly through space, Q(W) jumps occasionally — but the fp32 W’s trajectory is what learns, and STE gives a reasonable signal for that trajectory.

(4) Empirically. Bengio 2013, Hinton 2012 lectures, and a decade of follow-up work showed STE is the workhorse for training low-bit networks. The “wrong” gradient is good enough.

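The deeper insight: STE is a particular instance of a broader pattern in deep learning: when a forward operation is non-differentiable, replace its backward pass with the identity (or a smooth surrogate). The same trick appears in Gumbel-Softmax (a smooth surrogate for discrete sampling, with a straight-through variant) and in hard attention (argmax). These are all “wrong” gradients that work. REINFORCE attacks the same problem from the other side: its score-function gradient is unbiased but high-variance, which is one reason the biased but low-variance straight-through family is popular.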

Learned Step Size Quantization (LSQ)

A refinement: instead of using a fixed scale per block, make the scale a learnable parameter. The optimizer trains both the weights and the per-block scale. Esser 2019 “Learned Step Size Quantization”:

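LSQ adds the scale s as a parameter:

  x_q = round(W / s), clipped to the integer range [q_min, q_max]
  W_q = s · x_q                          (forward)

  ∂L/∂W = ∂L/∂W_q                        (STE on the round, for unclipped weights)
  ∂L/∂s = ∂L/∂W_q · ∂W_q/∂s              (real gradient for s)

  ∂W_q/∂s = round(W/s) + s · ∂round(W/s)/∂s
          ≈ round(W/s) − W/s             (STE: ∂round(u)/∂u ≈ 1, so ∂round(W/s)/∂s ≈ −W/s²)
          = x_q − W/s                    (unclipped weights)
          = x_q                          (clipped weights, where x_q is the constant q_min or q_max)

so: ∂L/∂s ≈ Σ over block of ∂L/∂W_q[i] · (x_q[i] − W[i]/s), with the W[i]/s term dropped for clipped weights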

LSQ lets the model discover the optimal scale per block during training, instead of using the absmax heuristic. Typical gains: 0.2-0.5 perplexity over absmax-based QAT at the same bpw. Standard in modern QAT pipelines for sub-4-bit quantization.

BitNet — 1-bit (and 1.58-bit) from scratch

Wang 2023 “BitNet” took QAT to its extreme: 1-bit weights, trained from scratch.

BitNet's weight quantizer:
  W_q = sign(W) ∈ {-1, +1}               (1-bit, no scale at all per weight)
  plus per-tensor scale α = mean |W|     (one fp16 number for the whole tensor)

Forward:
  Y = X · (α · sign(W))                  (matmul = additions and subtractions; no fp multiplies!)

Backward:
  STE on sign → ∂L/∂W = α · X^T · ∂L/∂Y

A 1-bit weight matrix has no per-element scale and only one sign bit per weight. The matmul Y = X · W becomes a sequence of additions and subtractions — no multiplications. On custom hardware, this is dramatically cheaper than even int8 matmul.
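The add/subtract structure is easy to see in a scalar sketch. The function names, the row-major layout, and the one-byte-per-sign storage below are assumptions made for readability; a real kernel would pack the signs into bits.

```c
#include <math.h>
#include <stddef.h>

/* Per-tensor scale alpha = mean |W| over all n weights. */
static float absmean(const float *w, size_t n) {
    float sum = 0.0f;
    for (size_t i = 0; i < n; i++) sum += fabsf(w[i]);
    return n ? sum / (float)n : 0.0f;
}

/* y = alpha * (x . sign(W)) for a row-major d_in x d_out matrix whose signs are
 * stored one per byte (nonzero = +1, zero = -1). The inner loop only adds or
 * subtracts; the single multiply by alpha happens once per output. */
static void binary_matvec(const float *x, const unsigned char *sign_pos,
                          size_t d_in, size_t d_out, float alpha, float *y) {
    for (size_t j = 0; j < d_out; j++) {
        float acc = 0.0f;
        for (size_t i = 0; i < d_in; i++) {
            if (sign_pos[i * d_out + j]) acc += x[i];
            else                         acc -= x[i];
        }
        y[j] = alpha * acc;
    }
}
```

The only floating-point multiply left is the final scaling by α, one per output element rather than one per weight.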

Ma 2024 “BitNet b1.58” refined this to three states, {-1, 0, +1}, or 1.58 bits per weight (log₂ 3 ≈ 1.585). The third “0” state lets the network gate connections, which empirically recovers most of the quality gap to fp16. BitNet b1.58 reportedly matches fp16 Llama at sizes ≥ 3B params on standard benchmarks, with 8× memory reduction and similar compute reduction on hardware that exploits the ternary structure.

BitNet b1.58 is the strongest evidence that quantization-aware training can fundamentally change the cost / quality frontier — not just compress an already-trained model, but produce a fundamentally cheaper model that’s as good.

The catch: BitNet has to be trained from scratch. You can’t take a Llama 3 70B fp16 checkpoint and convert it to 1.58-bit without losing massive quality, because the fp16 weights aren’t in a 1.58-bit-friendly configuration. Training from scratch costs roughly the same as training fp16, because the optimizer still updates full-precision shadow weights with full-precision gradients; only the weights used in the forward pass are ternary. So BitNet is a deployment win, not a training win.

— think, then check —

The crucial structural difference:

fp16 training settles into a weight distribution with continuous-valued weights spanning a wide range of small magnitudes: a near-continuous Gaussian shape with std ~0.02. This distribution has nothing in common with {-1, 0, +1}. Rounding it to ternary is catastrophic.

1.58-bit training-from-scratch with STE settles into a fundamentally different weight distribution. Throughout training, the model’s effective forward pass uses ternary weights — so the network LEARNS to encode information in ternary form. The continuous “shadow” weights that the optimizer updates are constantly being rounded; they evolve toward configurations that ROUND to good ternary patterns.

What changes in the loss landscape:

fp16 training: gradient descent in a continuous space, exploring a Gaussian-shaped weight distribution.

1.58-bit QAT: gradient descent in the same continuous space BUT the loss is computed using ternary weights. The optimizer learns to find ternary-aligned local minima — points where small perturbations of the shadow weights don’t change the ternary result, but the ternary configuration is locally optimal.

Why this works at all:

The Lottery Ticket Hypothesis (Frankle 2018) and follow-up work showed that sparse subnetworks WITHIN a dense network can match the dense network’s performance. That result concerns sparsity rather than ternary weights, so treat it as an analogy: heavily constrained networks can still carry the function. BitNet searches for such structures directly by constraining the forward pass to ternary weights throughout training, instead of trying to find them after the fact via pruning + quantization of an already-trained dense model.

Why PTQ fails:

fp16 weights and ternary weights live in different “neighborhoods” of weight space. There’s no smooth path from a typical fp16 weight value (e.g., 0.0184) to its rounded ternary equivalent (0) — the rounding throws away the value. PTQ assumes the two regimes are close; they’re not, for 1.58-bit. QAT navigates the network to a ternary-friendly region of weight space directly.

The training cost: BitNet has to be trained fully from scratch — you cannot “convert” Llama 3 to BitNet b1.58 without retraining. So the win is at deployment (8× memory, much cheaper compute) but the upfront training investment is the same as fp16.

QLoRA — the production trick

The most impactful application of QAT-adjacent techniques is QLoRA — Dettmers 2023 — the technique that put fine-tuning of 70B models within reach of solo researchers.

QLoRA fine-tunes a model with FROZEN 4-bit base weights and fp16 LoRA adapters:

Layer's effective weight at fine-tuning time:
  W_eff = dequantize(W_base, NF4) + B · A
          ↑                         ↑
          4-bit, frozen,            fp16, trainable,
          no gradient               small rank r

Forward pass:
  Y = X · W_eff = X · dequantize(W_base) + X · (B · A)

Backward pass:
  ∂L/∂W_base = none                     (W_base is frozen; no update)
  ∂L/∂A      = B^T · (X^T · ∂L/∂Y)      (standard matmul backward)
  ∂L/∂B      = (X^T · ∂L/∂Y) · A^T

Critically: gradients flow THROUGH the dequantized W_base (it's not a barrier), but DON'T modify it (it's a constant during fine-tuning).

The data type used for the base is NF4 (NormalFloat 4-bit), which uses 16 quantization levels positioned at the quantiles of a normal distribution. Since LLM weights ARE approximately normally distributed, NF4 represents them more accurately than evenly-spaced int4. The choice is matched to the empirical weight distribution.
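Mechanically, a codebook data type like NF4 amounts to "normalize each block by its absmax, then store the index of the nearest level". The sketch below shows that mechanism with the 16-entry level table passed in by the caller; the actual NF4 level values are not reproduced here, and the function names are hypothetical.

```c
#include <math.h>
#include <stddef.h>

/* Quantize one block against a 16-entry codebook of levels in [-1, 1].
 * Each weight is divided by the block's absmax and mapped to the index of the
 * nearest level. Returns absmax, which is stored alongside the 4-bit indices. */
static float codebook_quantize_block(const float *w, size_t n,
                                     const float levels[16], unsigned char *idx) {
    float absmax = 0.0f;
    for (size_t i = 0; i < n; i++) {
        float a = fabsf(w[i]);
        if (a > absmax) absmax = a;
    }
    float inv = absmax > 0.0f ? 1.0f / absmax : 0.0f;
    for (size_t i = 0; i < n; i++) {
        float v = w[i] * inv;                   /* normalized to [-1, 1] */
        unsigned char best = 0;
        for (unsigned char k = 1; k < 16; k++)
            if (fabsf(v - levels[k]) < fabsf(v - levels[best])) best = k;
        idx[i] = best;
    }
    return absmax;
}

/* Dequantize: look the level up and rescale by the block's absmax. */
static float codebook_dequantize(unsigned char idx, const float levels[16],
                                 float absmax) {
    return levels[idx] * absmax;
}
```

Passing an evenly spaced table gives an int4-like grid; passing levels placed at normal quantiles gives the NF4 behaviour described above, with more levels near zero where most weights sit.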

Why this works as well as fine-tuning the full fp16 model:

  1. The LoRA adapters B · A have full fp16 precision, so within their rank-r subspace they can represent corrections to W_base without any quantization error.
  2. The frozen 4-bit base is “good enough” — its quantization error is in directions the LoRA adapter can correct.
  3. The base + adapter combination at inference is effectively a higher-precision matrix than 4-bit alone.

The memory math is brutal in QLoRA’s favor. A 70B model:

Full fp16 fine-tuning:
  Model weights:      70B · 2 bytes = 140 GB
  Gradients:          70B · 2 bytes = 140 GB
  Adam state (m, v):  70B · 8 bytes = 560 GB
  Total:              ≈ 840 GB (requires multi-GPU sharded training)

QLoRA fine-tuning (rank-16 LoRA, ~1% of params trainable):
  Base weights:       70B · 0.5 bytes ≈ 35 GB    (4-bit NF4)
  LoRA params:        0.7B · 2 bytes ≈ 1.4 GB    (only fine-tuned)
  LoRA gradients:     0.7B · 2 bytes ≈ 1.4 GB
  Adam state:         0.7B · 8 bytes ≈ 5.6 GB
  Total:              ≈ 43 GB (fits on a single H100 (80 GB) with room for activations)
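The same accounting as a pair of small functions, so the numbers can be rerun for other model sizes. This is a back-of-the-envelope sketch with hypothetical function names: it counts only weights, gradients, and Adam state, and ignores activations and the per-block scale overhead of the 4-bit format.

```c
/* Fine-tuning memory estimates in GB (1 GB = 1e9 bytes), mirroring the
 * accounting in the text: weights, gradients, and Adam's two moment buffers.
 * Activations are not counted. */
static double full_ft_gb(double n_params) {
    double weights = n_params * 2.0;    /* fp16 weights */
    double grads   = n_params * 2.0;    /* fp16 gradients */
    double adam    = n_params * 8.0;    /* m and v, 4 bytes each */
    return (weights + grads + adam) / 1e9;
}

static double qlora_ft_gb(double n_params, double trainable_fraction) {
    double n_lora = n_params * trainable_fraction;
    double base   = n_params * 0.5;     /* 4-bit base weights */
    double lora   = n_lora * 2.0;       /* fp16 adapter weights */
    double grads  = n_lora * 2.0;       /* fp16 adapter gradients */
    double adam   = n_lora * 8.0;       /* m and v for the adapters only */
    return (base + lora + grads + adam) / 1e9;
}
```

For 70e9 parameters, `full_ft_gb` gives 840 and `qlora_ft_gb` with a trainable fraction of 0.01 gives 43.4, matching the totals above.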

QLoRA is the workhorse of single-GPU LLM customization. Many “fine-tune your own Llama” tutorials on Hugging Face go through QLoRA.

— think, then check —

Tensors and dtypes at fine-tuning time:

For each linear layer L in the base model:

  • W_base — 4-bit NF4, frozen. Stored on GPU. Dequantized to fp16 on-the-fly inside the layer’s forward pass.
  • A — fp16, trainable. Shape (r, d_out). Small (r typically 16-64).
  • B — fp16, trainable. Shape (d_in, r). Initialized to zero so the LoRA adapter starts as a no-op.
  • m, v (Adam state) — fp16 or fp32, for A and B only.
  • Gradient buffers ∂L/∂A, ∂L/∂B — fp16. For A and B only.

Forward pass for layer L:

W_eff = dequantize_nf4(W_base) + B @ A (W_eff in fp16, held briefly)

Y = X @ W_eff (matmul in fp16)

The dequantized W_base is materialized briefly, used for the matmul, then discarded. Activation X is checkpointed for backward.

Backward pass for layer L:

Incoming: ∂L/∂Y, X.

∂L/∂W_eff = X^T @ ∂L/∂Y (the gradient w.r.t. the effective weight — straightforward matmul backward)

∂L/∂A = B^T @ ∂L/∂W_eff (chain rule through W_eff = … + B @ A)

∂L/∂B = ∂L/∂W_eff @ A^T

∂L/∂W_base = ∂L/∂W_eff ← computed BUT discarded (W_base is frozen)

∂L/∂X = ∂L/∂Y @ W_eff^T (passed back to previous layer)
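The forward and adapter-gradient steps above can be sketched in plain C for one layer. The function names and row-major layout are assumptions, and W is taken to be the already-dequantized base. The products are regrouped as (X B) A, (X B)^T G, and X^T (G A^T) with G = ∂L/∂Y, which is mathematically the same as the formulas above but never forms a d_in × d_out gradient.

```c
#include <stddef.h>
#include <string.h>

/* C (m x n) += A (m x k) * B (k x n), all row-major. */
static void matmul_acc(const float *a, const float *b, float *c,
                       size_t m, size_t k, size_t n) {
    for (size_t i = 0; i < m; i++)
        for (size_t p = 0; p < k; p++)
            for (size_t j = 0; j < n; j++)
                c[i * n + j] += a[i * k + p] * b[p * n + j];
}

/* Forward: Y = X W + (X B) A. Shapes: X (n x d_in), W (d_in x d_out),
 * B (d_in x r), A (r x d_out). XB (n x r) is kept for the backward pass. */
static void lora_forward(const float *x, const float *w, const float *b,
                         const float *a, float *xb, float *y,
                         size_t n, size_t d_in, size_t d_out, size_t r) {
    memset(y, 0, n * d_out * sizeof *y);
    memset(xb, 0, n * r * sizeof *xb);
    matmul_acc(x, w, y, n, d_in, d_out);    /* frozen base path */
    matmul_acc(x, b, xb, n, d_in, r);       /* adapter: project down to rank r */
    matmul_acc(xb, a, y, n, r, d_out);      /* adapter: project back up */
}

/* Backward for the adapters only: gA = (X B)^T G and gB = X^T (G A^T).
 * There is no gradient buffer for W: the base is frozen. */
static void lora_backward(const float *x, const float *xb, const float *a,
                          const float *g, float *ga, float *gb,
                          size_t n, size_t d_in, size_t d_out, size_t r) {
    memset(ga, 0, r * d_out * sizeof *ga);
    memset(gb, 0, d_in * r * sizeof *gb);
    for (size_t s = 0; s < n; s++) {
        for (size_t k = 0; k < r; k++) {
            float gat = 0.0f;               /* (G A^T)[s][k] */
            for (size_t j = 0; j < d_out; j++) {
                ga[k * d_out + j] += xb[s * r + k] * g[s * d_out + j];
                gat += g[s * d_out + j] * a[k * d_out + j];
            }
            for (size_t i = 0; i < d_in; i++)
                gb[i * r + k] += x[s * d_in + i] * gat;
        }
    }
}
```

The only buffers that scale with d_in × d_out are W itself and whatever the caller uses to hold its dequantized copy; everything the optimizer touches is of size r × d_out or d_in × r.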

What’s critical:

  1. The 4-bit W_base is NEVER updated. We only read it during forward (and to compute ∂L/∂X during backward), but we don’t track gradients for it.
  2. The dequantization happens on-the-fly during forward — no fp16 copy of W_base is stored long-term. This is the memory win.
  3. The LoRA adapters A, B are tiny (rank r ≈ 16) but full-precision (fp16). They learn to compensate for both the base model’s deficiencies on the new task AND the quantization noise in W_base.
  4. The optimizer only updates A and B. Adam state, gradients, and parameter copies are needed only for these — typically <1% of the parameter count.

Why this matches full fine-tuning quality:

The LoRA adapters can express the same range of weight modifications as full fine-tuning at the granularity that matters for the fine-tuning task (low-rank updates capture most task-specific changes). The 4-bit base introduces a ~0.04 perplexity increase vs the fp16 base, which is recovered by the LoRA adapters. Net: QLoRA fine-tunes match full fp16 fine-tunes within ~0.1 perplexity at roughly 5% of the memory cost (≈43 GB vs ≈840 GB in the 70B estimate above).

QLoRA is arguably the most important deployment-side training technique of the post-Chinchilla era. Many Llama derivatives on Hugging Face that say “fine-tuned” were probably QLoRA-fine-tuned.

The picture across the chapter

Four techniques, four different scopes:

          Train                     Quantize                 Use case
PTQ:      fp16                      after train              Deploy a pre-trained model in less memory
QAT:      with Q(W) in forward,     during train             Get past PTQ's quality floor (≤3 bits)
          STE backward                                       at the cost of full retrain or fine-tune
BitNet:   1.58-bit in forward,      from scratch             Train a fundamentally cheaper model
          STE backward                                       if you control the training run
QLoRA:    4-bit base frozen,        base PTQ-d once;         Fine-tune a quantized base on a single GPU
          fp16 adapters trained     adapters trained fp16

Where each lives in production:

END OF CH.24 — Quantization.
§1 (PTQ basics, blockwise vs per-tensor, outliers, LLM.int8 → GPTQ → AWQ) · §2 (the GGML family: q4_0 to q6_K with exact byte layouts, q4_K_M model naming, IQ-quants + imatrix) · §3 (QAT + STE, BitNet b1.58, QLoRA — the workhorses of low-bit training and single-GPU fine-tuning).

Coming next: back to the normal chapter order. Ch.16 — Pretraining, Chinchilla scaling laws, the token budget.