Vector Jacobians & the VJP
Ch.4 §3 gave you the chain rule and a 120-line scalar autograd. Every op had a closed-form derivative (one number per input → one number per output), and the backward pass was a switch statement that multiplied them. Now scale up. Real neural-network ops are vector → vector (matmul: ℝⁿ → ℝᵐ) or tensor → tensor (attention, normalisation). Each op’s derivative is no longer a number; it’s a Jacobian matrix of shape m × n (Ch.4 §2). A typical neural network has Jacobians of shape (millions × millions), which would be catastrophic to store. The whole trick of backprop, the move that lets it scale to a 405B-parameter LLaMA, is to never construct those Jacobians at all. Instead, every op exposes a vector-Jacobian product (VJP): a function that, given an output gradient, produces the input gradient directly. The VJP is structurally a few matmuls or elementwise ops, whereas the full Jacobian is a giant matrix that only ever gets contracted with the gradient anyway. We skip the materialisation, do the contraction directly, and pay roughly two matmuls per linear layer in the backward pass (about twice that layer’s forward cost).
Why we don’t construct the Jacobian
Pick a typical transformer linear layer: input dimension d = 4096, output dimension m = 4096. The Jacobian ∂y/∂x is m × d = 4096² = 16,777,216 entries. In float32 that’s 67 MB, per layer. A 32-layer model would need 32 × 67 MB ≈ 2.1 GB of Jacobians alone if we tried to build them, and that’s for a single sample, not a real production batch.
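The arithmetic above, spelled out as a quick sketch. It assumes 4-byte float32 entries, decimal MB/GB/TB, and an example batch of 1024 samples (our choice for illustration, not a figure from any particular model):

```python
# Memory needed to materialise dy/dx for every linear layer (sizes from the text).
d = m = 4096              # input and output dimension of one linear layer
bytes_per_float32 = 4
layers = 32
batch = 1024              # example batch size (assumption)

entries = m * d                         # the Jacobian dy/dx has m x d entries
per_layer = entries * bytes_per_float32
per_sample = layers * per_layer
per_batch = batch * per_sample

print(entries)              # 16777216
print(per_layer / 1e6)      # 67.108864  (MB)
print(per_sample / 1e9)     # 2.147483648  (GB)
print(per_batch / 1e12)     # 2.199023255552  (TB)
```

Terabytes of storage for something we would only ever multiply by a vector.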
Worse: we wouldn’t actually use the Jacobian for anything except multiplication. The chain rule (Ch.4 §3) says J(f∘g) = J(f) · J(g): we compose Jacobians by matmul. The gradient flowing back into an op’s input x (or parameter θ) is ∂L/∂x = (∂L/∂y) · J, where J = ∂y/∂x and ∂L/∂y is a row vector. So all we ever do with J is multiply an upstream gradient vector by it, and the result is itself a vector.
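So we never actually need J; we only need the function “apply Jᵀ to a vector.” That function is the VJP: given the upstream gradient v = ∂L/∂y, it returns Jᵀv, which is ∂L/∂x written as a column vector (the transpose of the row vector vᵀJ).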
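For matmul, ReLU, and softmax, the building blocks of most networks, the VJP has a clean closed form that never costs more than multiplying by the full Jacobian and is often far cheaper (for an elementwise op, J is an n × n diagonal matrix, while the VJP is O(n)). We’ll derive the big three next.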
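A minimal NumPy sketch of the difference, using elementwise tanh as a stand-in op (our choice for illustration; it is not one of the big three). Its Jacobian is diagonal, so building it costs O(n²) memory, while the VJP is a single O(n) elementwise product:

```python
import numpy as np

def tanh_jacobian(x):
    # Full Jacobian of y = tanh(x) applied elementwise: an n x n diagonal matrix.
    return np.diag(1.0 - np.tanh(x) ** 2)

def tanh_vjp(x, v):
    # J^T v without building J: a diagonal J reduces to an elementwise product.
    return v * (1.0 - np.tanh(x) ** 2)

rng = np.random.default_rng(0)
n = 5
x = rng.standard_normal(n)
v = rng.standard_normal(n)                 # upstream gradient dL/dy

via_jacobian = tanh_jacobian(x).T @ v      # O(n^2) memory, O(n^2) compute
via_vjp = tanh_vjp(x, v)                   # O(n) memory, O(n) compute
print(np.allclose(via_jacobian, via_vjp))  # True
```

At n = 4096 the float64 diagonal matrix alone would take 4096² × 8 bytes ≈ 134 MB, almost all of it zeros.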
Recap: at roughly 2 GB of Jacobians per sample, a real training batch (e.g. 1024 samples) would need terabytes. And since we only ever use J through its product with an upstream gradient vector (the backward pass through the chain rule), we never materialise it: each op type implements its VJP directly with a few matmuls or elementwise ops, at no more compute than applying J (often far less) and a tiny fraction of the memory.
This is why PyTorch’s torch.autograd.Function.backward takes a grad_output and returns the corresponding input gradients without ever constructing the Jacobian. JAX’s jax.vjp exposes the VJP directly, and TensorFlow’s tf.GradientTape (like the older tf.gradients) does the same under the hood. VJP-style backward is the universal idiom of modern reverse-mode autodiff.
The big three VJPs
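Matmul VJP (the linear layer at the heart of every model). Forward: y = Wx with W of shape (m, n). Given v = ∂L/∂y: ∂L/∂W = v xᵀ (an outer product, shape m × n) and ∂L/∂x = Wᵀv (shape n).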
The cost: one outer product (m · n multiplies) and one matrix-vector product (m · n multiply-adds). Total: 2mn, twice the forward pass y = Wx (m · n multiply-adds). So the backward pass of one linear layer costs ~2× its forward pass. This is where the rule of thumb “training costs ~3× a forward pass” comes from: one forward (1×) plus one backward (~2×).
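A NumPy sketch of the matmul VJP, checked against central finite differences. The toy loss L = c · y (with a random vector c, our assumption) makes ∂L/∂y = c exactly, so c plays the role of v:

```python
import numpy as np

def linear_vjp(W, x, v):
    # Backward of y = W @ x given v = dL/dy (shape (m,)).
    dW = np.outer(v, x)   # v x^T, shape (m, n): m*n multiplies
    dx = W.T @ v          # W^T v, shape (n,): m*n multiply-adds
    return dW, dx

rng = np.random.default_rng(0)
m, n = 3, 4
W = rng.standard_normal((m, n))
x = rng.standard_normal(n)
c = rng.standard_normal(m)

def loss(W, x):
    # Toy scalar loss L = c . (W x), so dL/dy = c.
    return c @ (W @ x)

dW, dx = linear_vjp(W, x, c)

# Central finite differences for every entry of x and W.
eps = 1e-6
dx_fd = np.array([(loss(W, x + eps * e) - loss(W, x - eps * e)) / (2 * eps)
                  for e in np.eye(n)])
dW_fd = np.zeros_like(W)
for i in range(m):
    for j in range(n):
        E = np.zeros_like(W)
        E[i, j] = eps
        dW_fd[i, j] = (loss(W + E, x) - loss(W - E, x)) / (2 * eps)

print(np.allclose(dx, dx_fd, atol=1e-6), np.allclose(dW, dW_fd, atol=1e-6))  # True True
```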
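ReLU VJP (the workhorse nonlinearity; Ch.10 §2 has more). Forward: y = max(0, x) elementwise. Given v = ∂L/∂y: ∂L/∂x = v ⊙ 1[x > 0], i.e. pass the gradient through where the input was positive and zero it elsewhere.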
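Cost: O(n), one mask and one elementwise multiply, where the full Jacobian would be an n × n diagonal matrix. It is among the cheapest backward passes of any common op.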
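A matching NumPy sketch for ReLU. The forward returns its input as the saved value; we assign gradient 0 at exactly x = 0, a common convention (ReLU is not differentiable there):

```python
import numpy as np

def relu_forward(x):
    # Returns the output and the value the backward pass will need (the input).
    return np.maximum(x, 0.0), x

def relu_vjp(saved_x, v):
    # dL/dx = v where the input was positive, 0 elsewhere: one mask, one multiply.
    return v * (saved_x > 0)

x = np.array([-2.0, -0.5, 0.0, 0.5, 3.0])
v = np.array([10.0, 20.0, 30.0, 40.0, 50.0])   # upstream gradient dL/dy
y, saved = relu_forward(x)
dx = relu_vjp(saved, v)
print(dx)   # gradient survives only at the last two positions: 40 and 50
```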
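Softmax + cross-entropy VJP (the standard classification head, where the magic happens). Forward: p = softmax(z) over C classes and L = −Σᵢ yᵢ log pᵢ with a one-hot target y. Backward: ∂L/∂z = p − y.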
This is the beautiful one. The softmax Jacobian alone is a dense C × C matrix (∂pᵢ/∂zⱼ = pᵢ(δᵢⱼ − pⱼ)): ugly to compute, expensive to multiply. But composed with cross-entropy, whose gradient is ∂L/∂pᵢ = −yᵢ/pᵢ, the 1/pᵢ cancels the leading pᵢ in each Jacobian row, and Σᵢ yᵢ = 1 collapses what remains. The composed VJP is just p − y: the softmax you already computed in the forward pass, minus the one-hot target, with no C × C matrix anywhere.
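The chain-rule step written out in full, assuming only that the target y sums to 1 (true for a one-hot vector):

```latex
\begin{aligned}
L &= -\sum_i y_i \log p_i, \qquad p_i = \frac{e^{z_i}}{\sum_k e^{z_k}}, \qquad \sum_i y_i = 1 \\
\frac{\partial L}{\partial z_j}
  &= \sum_i \frac{\partial L}{\partial p_i}\,\frac{\partial p_i}{\partial z_j}
   = \sum_i \left(-\frac{y_i}{p_i}\right) p_i\,(\delta_{ij} - p_j) \\
  &= -\sum_i y_i\,\delta_{ij} + p_j \sum_i y_i
   = -y_j + p_j
   = p_j - y_j .
\end{aligned}
```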
This cancellation is why the major frameworks fuse softmax + cross-entropy into one op (F.cross_entropy in PyTorch, optax.softmax_cross_entropy in the JAX ecosystem, tf.nn.softmax_cross_entropy_with_logits in TensorFlow): both for numerical stability (the log-sum-exp trick) and to expose the clean p − y backward.
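A NumPy sketch of a fused softmax + cross-entropy, in the spirit of these framework ops but not their actual implementation. The forward computes log p stably via log-sum-exp and saves p; the backward is p − y. As a cross-check, it also pushes cross-entropy’s own gradient −y/p through a standalone softmax VJP, which does not need the C × C matrix either: since J = diag(p) − ppᵀ is symmetric, Jᵀv = p ⊙ (v − p·v).

```python
import numpy as np

def softmax_ce_forward(z, y):
    # log p = z - logsumexp(z), with the max subtracted so exp never overflows.
    zmax = z.max()
    log_p = z - (zmax + np.log(np.sum(np.exp(z - zmax))))
    loss = -np.sum(y * log_p)
    return loss, np.exp(log_p)     # save p for the backward pass

def softmax_ce_vjp(p, y):
    # Gradient of the fused loss w.r.t. the logits z.
    return p - y

def softmax_vjp(p, v):
    # Standalone softmax VJP: J = diag(p) - p p^T is symmetric, so
    # J^T v = p * (v - p . v), computed without any C x C matrix.
    return p * (v - p @ v)

z = np.array([1000.0, 1001.0, 1002.0])   # large logits: a naive np.exp(z) would overflow
y = np.array([0.0, 0.0, 1.0])            # one-hot target
loss, p = softmax_ce_forward(z, y)
grad_fused = softmax_ce_vjp(p, y)

# Two-step route: dL/dp = -y / p, then back through the softmax VJP.
grad_two_step = softmax_vjp(p, -y / p)
print(loss, np.allclose(grad_fused, grad_two_step))   # finite loss, True
```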
[Interactive figure: BackpropFlow. Forward pass: values flow left → right. Backward pass: gradients flow right → left, each node applying its VJP formula.]
Deriving the matmul VJP. Let v = ∂L/∂y. By the chain rule, ∂L/∂Wᵢⱼ = Σₖ (∂L/∂yₖ) · (∂yₖ/∂Wᵢⱼ). For y = Wx, ∂yₖ/∂Wᵢⱼ = δₖᵢ · xⱼ (only yᵢ depends on Wᵢⱼ, through the term Wᵢⱼxⱼ). So ∂L/∂Wᵢⱼ = (∂L/∂yᵢ) · xⱼ = vᵢ · xⱼ. In matrix form: ∂L/∂W = v · xᵀ (outer product, m × n).
For ∂L/∂xⱼ: ∂yₖ/∂xⱼ = Wₖⱼ. So ∂L/∂xⱼ = Σₖ vₖ · Wₖⱼ = (Wᵀv)ⱼ. In matrix form: ∂L/∂x = Wᵀv (length n).
Shapes: v is shape (m,), x is shape (n,), so v · xᵀ is shape (m, n) — matches W. Wᵀ is (n, m), v is (m,), so Wᵀv is (n,) — matches x. Each VJP “undoes” the forward shape transformation, producing a gradient of the same shape as its corresponding input.
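The same shape rule carries over to a batch. A NumPy sketch, assuming the row-major layout where X has shape (B, n) and Y = X Wᵀ (the convention torch.nn.Linear uses): the outer product becomes a matmul, and because W is shared across samples its gradient sums the per-sample outer products.

```python
import numpy as np

def linear_vjp_batched(W, X, V):
    # Y = X @ W.T with X (B, n), W (m, n); V = dL/dY has shape (B, m).
    dW = V.T @ X   # (m, B) @ (B, n) -> (m, n): sum over the batch of v x^T
    dX = V @ W     # (B, m) @ (m, n) -> (B, n): row b is W^T v_b
    return dW, dX

rng = np.random.default_rng(1)
B, m, n = 8, 3, 4
W = rng.standard_normal((m, n))
X = rng.standard_normal((B, n))
V = rng.standard_normal((B, m))
dW, dX = linear_vjp_batched(W, X, V)

# Same result as looping over samples with the single-sample formulas.
dW_loop = sum(np.outer(V[b], X[b]) for b in range(B))
dX_loop = np.stack([W.T @ V[b] for b in range(B)])
print(dW.shape == W.shape, dX.shape == X.shape)        # True True
print(np.allclose(dW, dW_loop), np.allclose(dX, dX_loop))  # True True
```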
How to read a VJP
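The pattern that recurs for every op: the backward takes the upstream gradient v = ∂L/∂y plus values saved during the forward pass, and returns one gradient per input, each with the same shape as that input.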
The “saved values” matter. For matmul’s VJP you need x (to form v xᵀ) and W (to form Wᵀv). For ReLU’s VJP you need the input x (to know where it was positive). For softmax+CE you need p and the target y (to form p − y). Every op holds onto enough forward-pass state to compute its own backward. That state is the activation memory, and for deep networks it is often the dominant memory cost during training, ahead of the parameters themselves. (Gradient checkpointing, covered in Ch.23, is the standard trick to reduce it by recomputing activations on demand.)
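Putting the pieces together: a NumPy sketch of a network shaped like the BackpropFlow figure (2 inputs, 3 hidden units, linear + ReLU + MSE). The single output, the absence of biases, the random weights and the ½ factor in the loss are our assumptions. The forward pass returns exactly the saved values each VJP needs; the backward pass applies the VJPs in reverse order, and finite differences confirm the result:

```python
import numpy as np

rng = np.random.default_rng(0)
W1 = rng.standard_normal((3, 2))   # 2 inputs -> 3 hidden units
W2 = rng.standard_normal((1, 3))   # 3 hidden -> 1 output (assumed)
x = np.array([0.5, -1.0])
t = np.array([0.25])               # regression target

def forward(W1, W2, x, t):
    h_pre = W1 @ x                          # linear
    h = np.maximum(h_pre, 0.0)              # ReLU
    y_hat = W2 @ h                          # linear
    loss = 0.5 * np.sum((y_hat - t) ** 2)   # squared error with the 1/2 convention
    saved = (x, h_pre, h, y_hat)            # activation memory kept for the backward pass
    return loss, saved

def backward(W1, W2, t, saved):
    x, h_pre, h, y_hat = saved
    v = y_hat - t              # loss VJP: prediction - target
    dW2 = np.outer(v, h)       # matmul VJP uses the saved input h
    v = W2.T @ v               # ... and the weight W2
    v = v * (h_pre > 0)        # ReLU VJP uses the saved pre-activation
    dW1 = np.outer(v, x)       # matmul VJP uses the saved input x
    return dW1, dW2

loss, saved = forward(W1, W2, x, t)
dW1, dW2 = backward(W1, W2, t, saved)

def numerical(W, f, eps=1e-6):
    # Central finite differences over every entry of W.
    g = np.zeros_like(W)
    for idx in np.ndindex(W.shape):
        E = np.zeros_like(W)
        E[idx] = eps
        g[idx] = (f(W + E) - f(W - E)) / (2 * eps)
    return g

dW1_fd = numerical(W1, lambda W: forward(W, W2, x, t)[0])
dW2_fd = numerical(W2, lambda W: forward(W1, W, x, t)[0])
print(np.allclose(dW1, dW1_fd, atol=1e-5), np.allclose(dW2, dW2_fd, atol=1e-5))
```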
Why softmax + cross-entropy collapses
The clean p − y derivative for cross-entropy on a softmax output isn’t magic; it’s a consequence of cross-entropy being the matching loss for softmax. The general principle (look up canonical link function in any generalised-linear-models text): when your loss is the negative log-likelihood under the canonical link, the gradient w.r.t. the linear predictor (the pre-activation, e.g. the logits) is always prediction − target.
- Linear regression: ŷ = θᵀx, squared-error loss ½(ŷ − y)² (MSE up to a constant factor) → gradient is (ŷ − y) x.
- Logistic regression: p = σ(θᵀx), binary cross-entropy loss → gradient is (p − y) x.
- Softmax classification: p = softmax(Wx), cross-entropy loss → gradient is (p − y) xᵀ.
Three different problems, same form of gradient. This is why MSE for regression, BCE for binary classification, and CE for multinomial classification are the “natural” losses for their problems — they’re the ones whose gradients collapse to the clean form.
Principle (generalised linear models): for any exponential-family distribution, the canonical link function is the one that makes the linear predictor equal to the distribution’s natural parameter. When you pair the canonical link with the distribution’s negative log-likelihood as the loss, ∂L/∂(linear predictor) = prediction − target, regardless of distribution (assuming unit dispersion; a Gaussian with variance σ² gives (ŷ − y)/σ²).
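Examples (canonical link, loss): identity link + squared error for the Gaussian; logit link (sigmoid prediction) + BCE for the Bernoulli; multinomial logit (softmax prediction) + CE for the multinomial; log link (exp prediction) + Poisson NLL for the Poisson. All have the same gradient form.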
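A quick numerical check of the principle: a NumPy sketch that differentiates each negative log-likelihood with central finite differences with respect to the linear predictor η and compares the result with prediction − target. The η and y values are arbitrary test inputs, the Gaussian case assumes unit variance, and the Poisson loss drops the constant log(y!) term:

```python
import numpy as np

def sigmoid(eta):
    return 1.0 / (1.0 + np.exp(-eta))

def softmax(eta):
    e = np.exp(eta - eta.max())   # shift for numerical stability
    return e / e.sum()

# name: (inverse link eta -> prediction, NLL as a function of eta, eta, target y)
CASES = {
    "gaussian (identity link, unit variance)": (
        lambda eta: eta,
        lambda eta, y: 0.5 * np.sum((eta - y) ** 2),
        np.array([0.3, -1.2]), np.array([1.0, 0.5])),
    "bernoulli (logit link)": (
        sigmoid,
        lambda eta, y: -np.sum(y * np.log(sigmoid(eta)) + (1 - y) * np.log(1 - sigmoid(eta))),
        np.array([0.3, -1.2]), np.array([1.0, 0.0])),
    "multinomial (softmax)": (
        softmax,
        lambda eta, y: -np.sum(y * np.log(softmax(eta))),
        np.array([0.3, -1.2, 2.0]), np.array([0.0, 1.0, 0.0])),
    "poisson (log link)": (
        np.exp,
        lambda eta, y: np.sum(np.exp(eta) - y * eta),   # drops the constant log(y!)
        np.array([0.3, -1.2]), np.array([2.0, 0.0])),
}

def numerical_grad(f, eta, eps=1e-6):
    # Central differences, one coordinate at a time.
    grad = np.zeros_like(eta)
    for i in range(eta.size):
        step = np.zeros_like(eta)
        step[i] = eps
        grad[i] = (f(eta + step) - f(eta - step)) / (2 * eps)
    return grad

results = {}
for name, (inv_link, loss, eta, y) in CASES.items():
    numeric = numerical_grad(lambda e: loss(e, y), eta)
    results[name] = np.allclose(numeric, inv_link(eta) - y, atol=1e-6)
    print(f"{name}: dL/deta == prediction - target? {results[name]}")
```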
Operational consequences:
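- Backward pass for the loss head is one vector subtraction. Naively, the head’s backward would go through a dense C × C softmax Jacobian per sample (C can be the full vocabulary for a language model); in the fused form it is one of the cheapest steps in the whole backward pass.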
- Numerical stability: combining softmax + CE into a single op (log-softmax + nll loss) avoids overflow at large logits AND exposes this clean gradient form natively. Every framework fuses these.
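- When you pair a NON-canonical loss with a head (e.g., MSE with softmax outputs), the gradient is messy AND the optimisation is harder: the prediction − target term gets multiplied by the output nonlinearity’s derivative, which vanishes when that nonlinearity saturates, so a confidently wrong prediction receives almost no gradient. Stick to canonical pairings unless you have a specific reason not to.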
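A small NumPy sketch of the saturation problem for a sigmoid head, taking a confidently wrong prediction (logit z = −10, target 1; values chosen for illustration). With BCE the gradient w.r.t. z is p − y; with ½(p − y)² it picks up an extra σ′(z) = p(1 − p) factor. Finite differences confirm both closed forms:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

z, y = -10.0, 1.0   # confidently wrong: p is about 4.5e-5, target is 1
p = sigmoid(z)

grad_bce = p - y                    # canonical pairing (sigmoid + BCE)
grad_mse = (p - y) * p * (1 - p)    # non-canonical (sigmoid + squared error)

def bce(z):
    return -(y * np.log(sigmoid(z)) + (1 - y) * np.log(1 - sigmoid(z)))

def mse(z):
    return 0.5 * (sigmoid(z) - y) ** 2

eps = 1e-6
bce_fd = (bce(z + eps) - bce(z - eps)) / (2 * eps)
mse_fd = (mse(z + eps) - mse(z - eps)) / (2 * eps)
print(grad_bce, grad_mse)   # about -1.0 versus about -4.5e-5
```

The squared-error gradient is roughly twenty thousand times smaller exactly when the model most needs correcting.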
This is why the conventional choices (regression: MSE+linear; binary: BCE+sigmoid; multi-class: CE+softmax) are not arbitrary tradition — they’re the choices whose math collapses cleanly. Modern losses (focal, contrastive) deviate from canonical only when there’s a specific structural reason; even then, they usually preserve the prediction − target structure approximately.
END OF CH.9 §1 — Vector Jacobians & the VJP.
Built: BackpropFlow viz (toggle between forward and backward through a 2-input, 3-hidden linear+ReLU+MSE network; see VJP formulas at each node). Three recall items: easy (why we use VJP not Jacobian), medium (deriving matmul VJP from chain rule), hard (canonical link function principle behind softmax+CE).
Coming next: §9.2 — The backprop algorithm. Tape, topological order, activation memory, gradient checkpointing.