SOFTMAX & THE EXPONENTIAL FAMILY
Section 12.1

Softmax — properties + numerical stability

Almost every probability distribution an LLM produces is the output of a softmax. Cross-entropy loss takes a softmax output and a one-hot target. Attention weights are a softmax over attention scores (Ch.13). Sampling for generation reads from a softmax distribution. The operation is conceptually a one-liner (turn logits into probabilities by exponentiating and normalising), but the floating-point details get you in trouble within minutes if you don’t know them. This section covers (1) the definition and three properties that make softmax the right choice, (2) the max-subtraction trick that every production implementation uses, and (3) the temperature parameter that controls how “sharp” or “diffuse” the resulting distribution is.

The definition

For a vector of real-valued logits z = (z_1, …, z_C):

softmax(z)_i = exp(z_i) / Σ_j exp(z_j)

Properties (by construction):
  • softmax(z)_i > 0 for every i (strictly positive in exact arithmetic; in floating point, tiny entries can still underflow to 0)
  • Σ_i softmax(z)_i = 1 (the outputs sum to 1, so they form a probability distribution)
  • softmax(z + c · 1) = softmax(z) for any constant c (shift invariance: adding the same constant to every logit doesn't change the output)

The shift invariance is the one operational property to commit to memory; it’s what enables the max-subtraction trick below.

Three structural reasons softmax is the default for “turn this real vector into a probability distribution”:

  1. Smooth. Argmax is piecewise constant: its gradient is zero almost everywhere and undefined where two logits tie. Softmax is smooth and differentiable everywhere, which is critical for gradient-based training.
  2. Probabilistic interpretation. The output is a valid probability distribution. You can sample from it; you can take cross-entropy with a label; you can compute KL divergence against another distribution. The softmax output behaves “as if” it came from a probabilistic model.
  3. Canonical link for multinomial classification. The softmax + cross-entropy pair gives the clean ∂L/∂z = p − y gradient (Ch.9 §1). This is the canonical-link-function identity that recurs anywhere “predict a category” matters.
Recall check (easy): the definition and shift invariance.

softmax(z)_i = exp(z_i) / Σ_j exp(z_j).

Properties:

  • Output is non-negative everywhere.
  • Outputs sum to 1 — a valid probability distribution.
  • Shift-invariance: softmax(z + c·1) = softmax(z) for any constant c.

The shift-invariance follows from exp(z_i + c) / Σ exp(z_j + c) = e^c · exp(z_i) / (e^c · Σ exp(z_j)) = exp(z_i) / Σ exp(z_j). The c cancels in numerator and denominator.

This is what makes the max-subtraction trick valid: replace z with z − max(z), guaranteeing all exponents are ≤ 0, so all the exp values are ≤ 1 — no overflow risk. The output is mathematically identical to the original softmax.

The numerical landmines

A naïve implementation of softmax — literally compute the exponentials and normalise — fails on real inputs in two ways:

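Logits z arrive with arbitrary scale; real inference logits often hit |z| ≈ 50–100. Naïve exp(z_i) in float32 (max ≈ 3.4 × 10³⁸, smallest normal ≈ 1.2 × 10⁻³⁸):

  • z_i = 80: exp(80) ≈ 5.5 × 10³⁴. Still finite, but only about four orders of magnitude below the float32 max; expf overflows once z_i exceeds about 88.7.
  • z_i = 100: exp(100) ≈ 2.7 × 10⁴³. OVERFLOW to Inf.
  • z_i = −100: exp(−100) ≈ 3.7 × 10⁻⁴⁴. UNDERFLOW: below the smallest normal float32, so it survives only as a denormal with a few bits of precision (and becomes exactly 0 for z_i below about −104, or under flush-to-zero).

If even one entry overflows, the sum is Inf: that entry becomes Inf / Inf = NaN and every other entry becomes 0. If many entries underflow, the small ones become exactly zero and you lose the gradient signal through them; if all of them underflow, the denominator is 0 and every output is 0 / 0 = NaN.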

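Production training routinely sees logits in the ±50 range; long-context inference can spike to ±100. Logits around +50 already overflow exp in float16 (max ≈ 65504, so exp overflows once z exceeds about 11), and spikes toward +100 overflow float32 as well. Naïve softmax fails on real workloads.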

The max-subtraction trick

The fix is one line, exploiting the shift invariance:

m = max_j z_j
softmax(z)_i = exp(z_i − m) / Σ_j exp(z_j − m)

After subtracting m:
  • z_i − m ≤ 0 for every i (since m is the max)
  • exp(z_i − m) ≤ 1 (no overflow)
  • exp(m − m) = exp(0) = 1 for the maximal entry (at least one term is exactly 1)

So the sum lies in [1, C], where C is the vocab size: finite and safe. The denominator can never underflow to zero; the exp(0) = 1 term guarantees it.

Max-subtraction (applied in log space, the same idea is called the log-sum-exp trick) is what every production softmax kernel uses; PyTorch, JAX, TensorFlow and llama.cpp all do this. The implementation is three short loops:

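#include <math.h>   /* expf */

/* Numerically stable softmax: out[i] = exp(z[i] - m) / sum_j exp(z[j] - m), m = max(z). Assumes N >= 1. */
void softmax_stable(const float* z, float* out, int N) {
    float m = z[0];
    for (int i = 1; i < N; i++) if (z[i] > m) m = z[i];   // pass 1: find max
    float sum = 0.0f;
    for (int i = 0; i < N; i++) {
        out[i] = expf(z[i] - m);                          // pass 2: shift then exp; every value <= 1
        sum += out[i];                                    // sum >= 1 (the max contributes exp(0) = 1)
    }
    for (int i = 0; i < N; i++) out[i] /= sum;            // pass 3: normalise
}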

Three passes over the input: find max, compute exponentials, divide by sum. O(N) work, O(1) extra state.

The pass count matters more than it sounds. The naïve “compute exp and sum, then divide” version makes two passes; max-subtraction adds a third, just to find the max. For a 128K vocabulary that is an extra full sweep over the logits, and memory bandwidth, not arithmetic, is the dominant cost of large-vocab softmax at inference time. The online softmax algorithm in §12.3 fuses the max and sum passes into a single pass by keeping running state (a running max plus a running sum that is rescaled whenever the max changes); in attention, the final normalisation can be folded in as well, which is what makes FlashAttention’s block-streaming approach possible.

Recall check (medium): the max-subtraction trick in a production NaN scenario.

The likeliest cause: their softmax computation doesn’t use the max-subtraction trick. Production logits often hit |z| ≈ 50–100, which is enough to overflow naïve exp(z_i) to Inf (float32 overflows once z exceeds about 88.7; float16 already above about 11). Any sum containing Inf is Inf, and the overflowed entry becomes Inf / Inf = NaN.

One-line fix: subtract max(z) from every logit before exponentiating. By shift-invariance, this doesn’t change the result mathematically. But every exp(z_i − max) is now ≤ 1, so overflow is impossible. The denominator is also at least 1 (since exp(0) = 1 for the maximum logit), so it cannot underflow to zero either; individual tiny terms may still underflow to 0, which is harmless.

Standard implementation in every framework:

1. m = max(z)
2. out_i = exp(z_i − m)
3. out_i /= Σ out_j

Also fused into log-softmax + cross-entropy (the F.cross_entropy / nn.CrossEntropyLoss path) to avoid the intermediate softmax materialisation entirely. Production code path: log_softmax(z) = z − m − log(Σ exp(z − m)); cross-entropy is just the negative of the true class’s log-softmax. No softmax probabilities are constructed explicitly.

Temperature — controlling the sharpness

The softmax has a tuning knob hidden inside it that’s exposed at inference time. Temperature τ scales the logits before the softmax:

softmax_τ(z)_i = exp(z_i / τ) / Σ_j exp(z_j / τ)

  • τ → 0⁺ (sharp): softmax approaches argmax; the most probable token dominates
  • τ = 1 (standard): the original softmax
  • τ → ∞ (diffuse): softmax approaches uniform; all tokens get roughly equal probability

Temperature directly controls how much “randomness” the model uses at inference. τ = 0.7 is a common choice: slightly sharper than the trained distribution, biasing toward higher-probability completions while still leaving some diversity. τ = 0 gives greedy decoding (always the most probable token); since z_i / 0 is undefined, it is implemented as a plain argmax or approximated with a tiny τ such as 10⁻⁶. Raising τ to around 1.5 flattens the distribution, which suits creative writing: more diversity, but more errors.

Operational rule: temperature is for inference-time control, not training. Standard language-model training uses τ = 1: the model is optimised against the canonical softmax distribution, and changing τ at training time changes the gradient. At inference, you can pick whatever τ you want without retraining; the underlying logits are unchanged.

Recall check (hard): temperature’s entropy interpretation and its consequence for sampling.

Entropy of a probability distribution p is H(p) = -Σ p_i log p_i. For softmax with temperature τ:

τ → 0: distribution concentrates on the argmax → H(p) → 0 (deterministic).

τ = 1: original distribution → standard entropy.

τ → ∞: distribution becomes uniform → H(p) → log(C) (maximum possible).

So temperature directly controls the entropy of the output distribution. Lower τ = lower entropy = more deterministic = more focused; higher τ = higher entropy = more diverse.

Operational consequence: sampling from a softmax with low τ produces a sequence that’s high-confidence and predictable but susceptible to repetition (the model commits to one path and stays there). High τ produces creative but error-prone output (the model gets lured into low-probability tokens that may not make sense). The trade is between consistency and diversity.

Production sampling almost never uses pure temperature alone. It is combined with top-k (only sample from the k most probable tokens) or top-p (only sample from the smallest set of tokens whose cumulative probability reaches p). Both truncate the low-probability tail, capping the worst-case errors of high-temperature sampling while preserving diversity among plausible tokens. A common combination is temperature 0.7–1.0 with top-p 0.9–0.95, empirically a sweet spot for creative generation that’s still coherent.
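A minimal C sketch of the truncation step, assuming temperature has already been applied to produce the probability vector p (the order in which temperature, top-k and top-p are applied is an implementation choice). `filter_top_k_top_p` is a hypothetical helper: it applies top-k first, then top-p within the survivors, zeroes everything else and renormalises; the fixed buffer assumes n ≤ 1024.

```c
#include <math.h>
#include <stdlib.h>

typedef struct { float prob; int idx; } Entry;

/* qsort comparator: sort by descending probability. */
static int by_prob_desc(const void *a, const void *b) {
    float pa = ((const Entry *)a)->prob, pb = ((const Entry *)b)->prob;
    return (pa < pb) - (pa > pb);
}

/* k <= 0 disables top-k; top_p >= 1 effectively disables top-p. Assumes n <= 1024. */
void filter_top_k_top_p(const float *p, float *out, int n, int k, float top_p) {
    Entry e[1024];
    for (int i = 0; i < n; i++) { e[i].prob = p[i]; e[i].idx = i; out[i] = 0.0f; }
    qsort(e, n, sizeof(Entry), by_prob_desc);
    int keep = (k > 0 && k < n) ? k : n;          /* top-k cut */
    float cum = 0.0f;
    for (int r = 0; r < keep; r++) {              /* top-p cut: smallest prefix reaching top_p */
        cum += e[r].prob;
        if (cum >= top_p) { keep = r + 1; break; }
    }
    float kept = 0.0f;
    for (int r = 0; r < keep; r++) kept += e[r].prob;
    for (int r = 0; r < keep; r++) out[e[r].idx] = e[r].prob / kept;   /* renormalise */
}
```

With p = (0.05, 0.5, 0.1, 0.3, 0.05) and top-p 0.75, only the 0.5 and 0.3 tokens survive, renormalised to 0.625 and 0.375; sampling then draws from this truncated distribution.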

END OF CH.12 §1 — Softmax + numerical stability.
Three recall items: easy (definition + shift invariance), medium (max-subtraction trick under a production NaN scenario), hard (temperature’s entropy interpretation and its operational consequence for sampling).
Coming next: §12.2 — Cross-entropy as KL divergence; the canonical-link argument made information-theoretic.