Why naive softmax causes NaNs — and how LogSumExp saves your run
TL;DR: Implementing softmax as exp(logits)/sum(exp(logits)) will occasionally overflow or underflow on real hardware and produce infinite losses and NaN gradients — wasting GPU hours and derailing training. The simple, bulletproof fix is to compute a numerically stable cross-entropy from raw logits using the LogSumExp trick (or use your framework’s fused softmax+loss like PyTorch’s CrossEntropyLoss or TensorFlow’s tf.nn.softmax_cross_entropy_with_logits). This prevents infinities, keeps gradients finite, and is standard practice for production training.
An analogy that makes the bug obvious
Softmax is like turning raw exam scores into percentiles. Naive exp-and-normalize is like blowing up the top student’s score until the scale breaks: one very large score dominates and the rest fall to zero, and the math on your GPU interprets that as “infinite” or “zero” in ways that break training. Shifting the scores before exponentiating compresses the scale so nothing blows up — that’s LogSumExp.
The symptom: sudden NaN losses
Teams scale up batch sizes, use larger models, or try mixed-precision training and then a run dies mid-epoch with NaN loss or NaN gradients. Once NaNs appear, they propagate through parameters and optimizers; the usual reaction is a frustrating restart and a forensic hunt for bugs. Often the culprit is a naive softmax implementation that didn’t account for finite-precision arithmetic on GPUs.
Key terms (one-line definitions)
- Logits — raw model scores before any probability mapping.
- Overflow / underflow — numbers too large or too small to represent in the floating-point format your hardware uses (e.g., FP32 or FP16).
- LogSumExp — a numerically stable identity for computing log of a sum of exponentials.
- Numerically stable softmax — softmax computed or fused with cross-entropy so intermediate values don’t become Inf or 0.
Why naive softmax fails (short, concrete)
Softmax is exp(x_i) / sum_j exp(x_j). If any x_i is very large (say 1000), exp(1000) overflows to +∞ in FP32/FP16 and exp(-1000) underflows to 0. If the correct class probability becomes exactly 0, -log(0) → +∞, and backprop produces NaN gradients. The math on paper is fine; the finite precision of hardware is the problem.
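You can see the finite-precision failure without any GPU. Note one difference in behavior: Python's math module raises an exception on overflow, while tensor libraries like PyTorch silently saturate to +Inf — but the root cause, a finite dynamic range, is identical:

```python
import math

# Overflow: exp(1000) exceeds the largest representable double,
# so Python raises (a GPU tensor op would return +inf instead).
try:
    math.exp(1000.0)
except OverflowError:
    print("exp(1000) overflows")

# Underflow: exp(-1000) silently rounds to exactly 0.0,
# and -log(0) in a loss would then be +inf.
print(math.exp(-1000.0))  # prints 0.0
```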
A tiny numeric example
| Logits | Naive exp results | Shifted exp results (subtract max) |
|---|---|---|
| [2.0, 1.0, 0.1] | [e^2≈7.39, e^1≈2.72, e^0.1≈1.11] | [e^0≈1.00, e^-1≈0.37, e^-1.9≈0.15] |
| [1000.0, 1.0, -1000.0] | [e^1000→+∞, e^1≈2.72, e^-1000≈0] | [e^0=1.0, e^-999≈0, e^-2000≈0] |
| [3.0, 2.0, 1.0] | [e^3≈20.09, e^2≈7.39, e^1≈2.72] | [e^0≈1.00, e^-1≈0.37, e^-2≈0.14] |
The second row shows the problem: naive exponentiation produces +∞ and 0. Subtracting the per-sample max maps the largest value to zero (exp(0)=1), avoiding overflow and keeping everything finite.
The LogSumExp identity (the safe way)
Use this identity to compute the log of a sum of exponentials safely:
logsumexp(x) = m + log(sum(exp(x - m))) where m = max(x).
Explanation: subtracting m makes the largest term 0 so exp(0)=1 and all other exponents are ≤1, preventing overflow. After summing those safe exponentials, add m back in log-space. When computing cross-entropy from logits you can operate in the log domain and avoid explicit probabilities entirely.
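The identity above fits in a few lines of framework-free Python (illustrative only — in production, use your framework's built-in, e.g. torch.logsumexp):

```python
import math

def logsumexp(xs):
    """Stable log(sum(exp(xs))): shift by the max before exponentiating."""
    m = max(xs)
    # Every shifted exponent is <= 0, so each exp() lies in (0, 1] -- no overflow.
    return m + math.log(sum(math.exp(x - m) for x in xs))

print(logsumexp([1000.0, 1.0, -1000.0]))  # ≈ 1000.0, where the naive form overflows
```

The naive math.log(sum(math.exp(x) for x in xs)) raises OverflowError on this input; the shifted version returns a finite answer that matches the naive one wherever both are computable.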
How this plugs into cross-entropy
Cross-entropy for a sample with logits x and correct class k can be written:
loss = -x_k + logsumexp(x)
Using the logsumexp identity above yields a numerically stable loss expression that doesn’t require computing softmax probabilities first.
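As a sketch, the loss = -x_k + logsumexp(x) formula becomes (pure Python, illustrative only — frameworks fuse this for you):

```python
import math

def cross_entropy_from_logits(x, k):
    """Stable cross-entropy: loss = -x_k + logsumexp(x), no probabilities needed."""
    m = max(x)
    # m cancels the large scale: -x_k + m + log(sum of exponents that are all <= 1)
    return -x[k] + m + math.log(sum(math.exp(xi - m) for xi in x))

print(cross_entropy_from_logits([1000.0, 1.0, -1000.0], 0))  # ≈ 0.0, finite
```

Note that no softmax probability is ever materialized, so nothing can hit exactly 0 or +Inf on the way to the loss.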
Minimal PyTorch reproduction and fix
Naive (can explode):
```python
import torch

logits = torch.tensor([[1000.0, 1.0, -1000.0]])
targets = torch.tensor([0])

# naive softmax + negative log-likelihood (poor style)
probs = torch.exp(logits) / torch.exp(logits).sum(dim=1, keepdim=True)
loss = -torch.log(probs[0, targets[0]])
print(loss)  # tensor(nan): exp(1000) overflows to inf, inf/inf -> nan
```
Stable (preferred):
```python
import torch

logits = torch.tensor([[1000.0, 1.0, -1000.0]])
targets = torch.tensor([0])

# recommended: use fused, stable CrossEntropyLoss
loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(logits, targets)
print(loss)  # tensor(0.): finite, thanks to the internal log-domain trick
```
Or if you implement manually, use log-softmax + NLL:
```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1000.0, 1.0, -1000.0]])
targets = torch.tensor([0])

log_probs = F.log_softmax(logits, dim=1)  # max-shifted internally, stays finite
loss = F.nll_loss(log_probs, targets)
```
These use the log-domain trick under the hood and keep intermediate values finite.
Mixed-precision and production notes
Lower-precision formats (FP16, TF32, BFloat16) reduce dynamic range and increase the odds of hitting overflow/underflow. Best practices:
- Prefer fused softmax+cross-entropy ops provided by frameworks — they’re optimized for speed and numerical stability.
- Use automatic mixed precision (AMP) and loss scaling (e.g., torch.cuda.amp) to prevent tiny gradients or overflow in FP16 paths.
- Run simple extreme-logit unit tests in CI to make sure no NaNs appear for pathological inputs.
Monitoring, debugging and quick checks
- At the start of runs, log the per-batch max and min of logits; extremes are an early warning sign.
- Fail fast: assert torch.isfinite(loss).all() for the first N batches (simple CI test name: test_extreme_logits_no_nan).
- Check gradients: if any param.grad has NaNs, print a few top-layer logits — the problem is often numerical, not model architecture.
- If you see a sudden spike to +Inf or NaN in loss, first verify you aren’t using a manual exp-and-normalize softmax anywhere in your forward pass.
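The fail-fast CI check suggested above can be sketched framework-free using the stable loss formula from earlier (illustrative only; in a real suite you would call your framework's loss and torch.isfinite instead of math.isfinite):

```python
import math

def stable_ce(x, k):
    """Stable cross-entropy from logits: -x_k + logsumexp(x)."""
    m = max(x)
    return -x[k] + m + math.log(sum(math.exp(xi - m) for xi in x))

def test_extreme_logits_no_nan():
    # pathological inputs that break a naive exp-and-normalize softmax
    for logits in ([1000.0, 1.0, -1000.0], [1e4, -1e4, 0.0], [0.0, 0.0, 0.0]):
        loss = stable_ce(logits, 0)
        assert math.isfinite(loss), f"non-finite loss for {logits}"

test_extreme_logits_no_nan()
print("ok")
```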
Edge cases and false positives
NaNs can come from other sources too — bad data (divide-by-zero, corrupted labels), unstable optimizers (very large learning rate), or numerical issues in batchnorm/layernorm. Use the checks above to triage: if only the softmax path uses exp/normalize manually, that’s the usual culprit.
How to prove it (1–2 line derivation)
Naive cross-entropy: -log( exp(x_k) / sum_j exp(x_j) ) = -x_k + log(sum_j exp(x_j)). Replace log(sum_j exp(x_j)) with m + log(sum_j exp(x_j - m)) to get a stable formula that never requires computing exp of the original large x_j.
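In LaTeX form, the same derivation reads:

```latex
-\log\frac{e^{x_k}}{\sum_j e^{x_j}}
  = -x_k + \log\sum_j e^{x_j}
  = -x_k + m + \log\sum_j e^{x_j - m},
  \qquad m = \max_j x_j .
```

Every exponent in the final sum is at most 0, so each term lies in (0, 1] and the sum is at least 1 (the j = argmax term contributes exactly 1), keeping both the log and the loss finite.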
Actionable checklist for engineering managers and ML leads
- Ask the team: do we use framework fused CrossEntropyLoss (PyTorch) / tf.nn.softmax_cross_entropy_with_logits (TensorFlow)? If not, prioritize switching.
- Ensure mixed-precision runs use AMP and loss scaling, and that fused ops are available in the FP16 code path.
- Add a CI test that feeds extreme logits into the loss function and asserts no Inf/NaN in loss or gradients.
- Monitor first N batches’ logits and loss finiteness when scaling experiments to new batch sizes or multi-GPU setups.
FAQ — quick answers
- Why does naive softmax sometimes produce NaNs?
Because exp(logits) can overflow to +Inf or underflow to 0 on finite-precision hardware, and taking -log(0) gives +Inf in the loss, which produces NaN gradients during backprop.
- How do I implement a numerically stable classification loss?
Either use your framework’s fused CrossEntropyLoss (which accepts logits directly) or compute loss = -x_k + logsumexp(x) with logsumexp(x) = m + log(sum(exp(x - m))).
- Are explicit probabilities necessary during training?
No — log-domain quantities (log-softmax or fused log-sum formulation) are sufficient for loss and gradients and are numerically safer than computing probabilities first.
- Does label smoothing or clipping logits fix overflow?
Those can reduce extremes but are band-aids. Subtracting the per-sample max (LogSumExp) is the correct numerical fix; label smoothing addresses overconfidence, not overflow mechanics.
Further reading / references
- Search for: “LogSumExp numerical stability” for math intuition and proofs.
- PyTorch docs: CrossEntropyLoss (look up “CrossEntropyLoss logits”) — this is the fused, stable op to use.
- TensorFlow docs: tf.nn.softmax_cross_entropy_with_logits — the framework equivalent that accepts logits.
Treat fused, log-domain softmax + cross-entropy as basic hygiene for any production ML pipeline. It’s a tiny change in code with outsized impact on training stability, reliability, and cluster costs.