A Multiple-Choice Exam, Translated to Math
A multiple-choice question gives you four options and asks you to pick one. You might feel most confident about option B (60%), less confident about C (20%), and almost rule out A and D (10% each). The teacher then reveals that the correct answer was C: you were confidently wrong, and your “loss” for this question is higher than if you had been confidently right or even uniformly unsure (25% on every option).
That informal reasoning is exactly the math behind every classification network. Softmax converts a vector of raw scores into a probability distribution. Cross-entropy measures how badly that distribution differs from the truth. Together they are the loss for the health-classification head we will attach to the backbone in Chapter 11 — one of the two heads of every dual-task model in this book.
Softmax: From Logits to Probabilities
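Given a vector of raw scores (called logits) $\mathbf{z} = (z_1, \dots, z_C)$, softmax produces a probability distribution $\mathbf{p}$:
$$p_i = \mathrm{softmax}(\mathbf{z})_i = \frac{e^{z_i}}{\sum_{j=1}^{C} e^{z_j}}, \qquad i = 1, \dots, C$$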
Three properties make it the right choice. (1) Positive: $p_i > 0$ for every $i$, because $e^{z_i} > 0$. (2) Sums to 1: $\sum_i p_i = 1$, because the divisor is the sum of the numerators. (3) Smooth and differentiable: gradients flow back to every $z_i$. The C-class health head in Chapter 11 has $C = 3$: normal, degrading, critical.
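To make the definition and the three properties concrete, here is a minimal NumPy sketch. The logits [2, 1, -1] for the three health classes are hypothetical, and the helper names `softmax` and `softmax_jacobian` are illustrative. Property (3) is checked through the standard Jacobian $\partial p_i / \partial z_k = p_i(\delta_{ik} - p_k)$ against central finite differences. This naive form has no overflow protection; the max-subtract trick in the Numerical Stability section addresses that.

```python
import numpy as np

def softmax(z: np.ndarray) -> np.ndarray:
    """Naive softmax of a 1-D logit vector (no max-subtract yet)."""
    e = np.exp(z)
    return e / e.sum()

def softmax_jacobian(z: np.ndarray) -> np.ndarray:
    """Analytic Jacobian J[i, k] = dp_i/dz_k = p_i * (delta_ik - p_k)."""
    p = softmax(z)
    return np.diag(p) - np.outer(p, p)

# Hypothetical logits for the 3-class health head (normal, degrading, critical).
z = np.array([2.0, 1.0, -1.0])
p = softmax(z)
print(p.round(4), p.sum())        # properties (1) and (2): positive, sums to 1

# Central finite differences confirm the analytic Jacobian (property 3).
eps = 1e-6
J_num = np.empty((3, 3))
for k in range(3):
    dz = np.zeros(3)
    dz[k] = eps
    J_num[:, k] = (softmax(z + dz) - softmax(z - dz)) / (2 * eps)
print(np.abs(J_num - softmax_jacobian(z)).max())
```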
Interactive: Drag Logits, Watch Probabilities
The visualization below lets you drag four token logits and see the softmax probabilities update live. Move the temperature slider: the temperature $T$ divides every logit before the softmax, $p_i = e^{z_i/T} / \sum_j e^{z_j/T}$, so a high temperature flattens the distribution toward uniform and a low temperature sharpens it toward one-hot. The same temperature knob appears in language-model sampling, where it trades diversity against determinism.
Interactive Softmax Visualizer: adjust the input logits with sliders and watch the probabilities change in real time.
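Temperature divides every logit by $T$ before the softmax. A minimal NumPy sketch, assuming four hypothetical token logits and an illustrative helper `softmax_t`, shows the distribution sharpening toward one-hot as $T$ shrinks and flattening toward uniform as $T$ grows:

```python
import numpy as np

def softmax_t(z: np.ndarray, T: float = 1.0) -> np.ndarray:
    """Temperature-scaled softmax, softmax(z / T), with max-subtract for safety."""
    s = z / T
    e = np.exp(s - s.max())
    return e / e.sum()

# Hypothetical four token logits, as in the visualizer.
z = np.array([3.0, 1.5, 0.5, -1.0])
for T in (0.25, 1.0, 4.0, 100.0):
    print(f"T={T:>6}: {softmax_t(z, T).round(3)}")
```

Setting $T = 1$ recovers the ordinary softmax; subtracting the maximum inside the helper changes nothing mathematically.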
Cross-Entropy: Distance Between Distributions
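With the predicted distribution $\mathbf{p}$ and the target distribution $\mathbf{y}$ (one-hot for hard labels), the cross-entropy is
$$H(\mathbf{y}, \mathbf{p}) = -\sum_{i=1}^{C} y_i \log p_i$$
For a one-hot $\mathbf{y}$ with the correct class at index $c$, this collapses to
$$\mathcal{L} = -\log p_c$$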
Three intuitions. (a) If $p_c = 1$ (perfect): loss is $-\log 1 = 0$. (b) If $p_c = 0.5$: loss is $\log 2 \approx 0.69$. (c) If $p_c \to 0$: loss $\to \infty$. Confidence on the wrong answer is punished without bound.
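A small NumPy check that the full sum and the one-hot shortcut agree, using a hypothetical prediction over the three health classes, followed by the loss at a few values of $p_c$. The helper name `cross_entropy` is illustrative.

```python
import numpy as np

def cross_entropy(y, p):
    """Full definition H(y, p) = -sum_i y_i log p_i (assumes every p_i > 0)."""
    return float(-(y * np.log(p)).sum())

# Hypothetical 3-class prediction; the true class is c = 1 ("degrading").
p = np.array([0.7, 0.2, 0.1])
c = 1
y = np.eye(3)[c]                      # one-hot target

full = cross_entropy(y, p)
shortcut = float(-np.log(p[c]))       # one-hot collapse: -log p_c
print(round(full, 4), round(shortcut, 4))

# Intuitions (a)-(c): log(1/p_c) equals -log p_c and blows up as p_c -> 0.
for p_c in (1.0, 0.9, 0.5, 0.1, 0.01, 1e-6):
    print(f"p_c = {p_c:<8g} loss = {np.log(1.0 / p_c):.3f}")
```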
Interactive: The Cross-Entropy Loss
Watch the loss spike as the model becomes confidently wrong. The gradient flows back through log and softmax to update the logits; composed, the two reduce to the simple form $\partial \mathcal{L} / \partial z_i = p_i - y_i$.
Interactive Cross-Entropy Loss: see how the loss penalizes wrong predictions (panels: the model's probability distribution and a calculation breakdown).
The negative log penalizes confident wrong predictions severely. If the model assigns 1% to the correct answer, loss = -log(0.01) ≈ 4.61; if it assigns 90%, loss = -log(0.9) ≈ 0.105. This steep penalty pushes the model to be both correct and confident. Because y is one-hot encoded, only the probability of the correct class enters the loss.
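The gradient claim can be checked numerically. A short NumPy sketch, with hypothetical logits where the model favours the wrong class, compares the standard analytic gradient $\partial\mathcal{L}/\partial z_i = p_i - y_i$ with central finite differences; the helper names are illustrative.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())          # max-subtract for stability
    return e / e.sum()

def ce_from_logits(z, c):
    """Loss -log softmax(z)_c for one example with integer target c."""
    return -np.log(softmax(z)[c])

z = np.array([1.0, 3.0, 0.5])        # hypothetical logits: model favours class 1
c = 0                                # ...but the true class is 0
y = np.eye(3)[c]                     # one-hot target, used only for the gradient

grad_analytic = softmax(z) - y       # dL/dz_i = p_i - y_i
eps = 1e-6
basis = np.eye(3)
grad_numeric = np.array([
    (ce_from_logits(z + eps * basis[i], c) - ce_from_logits(z - eps * basis[i], c)) / (2 * eps)
    for i in range(3)
])
print(grad_analytic.round(4))
print(np.abs(grad_analytic - grad_numeric).max())
```

The entry for the true class is $p_c - 1 < 0$, so a gradient-descent step raises that logit, while every other logit is pushed down in proportion to its probability.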
Numerical Stability: The Max-Subtract Trick
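Naive softmax is a numerical landmine. In float64, $e^{1000}$ overflows to inf and $e^{-1000}$ underflows to 0. The fix is to exploit shift invariance: for any constant $m$,
$$\mathrm{softmax}(\mathbf{z})_i = \frac{e^{-m} e^{z_i}}{e^{-m} \sum_j e^{z_j}} = \frac{e^{z_i - m}}{\sum_j e^{z_j - m}}$$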
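Choose $m = \max_j z_j$. After subtracting, every exponent satisfies $z_i - m \le 0$, so $e^{z_i - m} \le 1$ and nothing can overflow. The biggest logit becomes 0 and $e^0 = 1$, so the denominator is at least 1 and can never underflow to zero. Terms far below the maximum may still underflow to 0, but only where the true probability is negligible; the result is mathematically identical.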
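A minimal NumPy sketch of the problem and the fix, assuming float64 arrays and illustrative helper names. Shifting the logits by 1000 leaves the distribution unchanged, yet the naive version returns nan while the stable one does not. The last lines show why libraries also provide a log-softmax in log-sum-exp form, $\log p_i = (z_i - m) - \log \sum_j e^{z_j - m}$: taking the log of an underflowed probability gives $-\infty$.

```python
import numpy as np

def softmax_naive(z):
    e = np.exp(z)
    return e / e.sum()

def softmax_stable(z):
    """Max-subtract trick: shift by m = max(z); every exponent is then <= 0."""
    e = np.exp(z - z.max())
    return e / e.sum()

def log_softmax(z):
    """log p_i = (z_i - m) - log(sum_j exp(z_j - m)): the log-sum-exp form."""
    s = z - z.max()
    return s - np.log(np.exp(s).sum())

z_small = np.array([1.0, 2.0, 3.0])
z_big = z_small + 1000.0             # same distribution, shifted by a constant

with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive(z_big))      # exp overflows: inf / inf -> nan
print(softmax_stable(z_big))
print(np.allclose(softmax_stable(z_big), softmax_naive(z_small)))

# A logit far below the max underflows to probability 0, so log(softmax)
# returns -inf, whereas log_softmax stays finite.
z_wide = np.array([0.0, -1000.0])
with np.errstate(divide="ignore"):
    print(np.log(softmax_stable(z_wide)), log_softmax(z_wide))
```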
PyTorch's F.softmax and F.cross_entropy apply this trick internally; if you build a custom layer, replicate it.
Python: Softmax and CE From Scratch
Twenty lines of NumPy on a four-engine, three-class C-MAPSS-style batch. The numerical values match what PyTorch's F.cross_entropy returns to four decimal places.
Reading the per-row losses
Row 2 (logits , target 2) produces softmax probability 0.96 on the correct class — loss 0.041. Row 3 (logits , target 1) is much less confident — probability 0.40, loss 0.91. The average (0.37) is what the optimiser drives toward 0 during training.
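A minimal NumPy sketch of the same from-scratch pipeline for a batch: row-wise log-softmax with max-subtract, then $-\log p_c$ picked out with integer targets and averaged. The logits are hypothetical, not the batch discussed above, so the per-row numbers differ; the helper names are illustrative.

```python
import numpy as np

def log_softmax(logits):
    """Row-wise log-softmax for a (batch, C) array, via max-subtract."""
    s = logits - logits.max(axis=1, keepdims=True)
    return s - np.log(np.exp(s).sum(axis=1, keepdims=True))

def cross_entropy(logits, targets):
    """Mean CE over the batch; targets are integer class IDs, shape (batch,)."""
    logp = log_softmax(logits)
    per_row = -logp[np.arange(len(targets)), targets]   # -log p_c for each row
    return per_row.mean(), per_row

# Hypothetical 4-engine x 3-class batch (normal, degrading, critical).
logits = np.array([
    [2.5, 0.3, -1.0],
    [0.2, 1.8, 0.4],
    [-0.5, 0.9, 3.0],
    [1.0, 1.4, 0.6],
])
targets = np.array([0, 1, 2, 1])

mean_loss, per_row = cross_entropy(logits, targets)
print(per_row.round(3), round(float(mean_loss), 3))
```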
PyTorch: F.cross_entropy in One Line
Production code never writes the softmax + log + NLL pipeline by hand: F.cross_entropy fuses log_softmax and the negative log-likelihood into a single, numerically stable, GPU-friendly call. The CRITICAL convention is that you pass raw logits, not pre-softmaxed probabilities.
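A minimal PyTorch sketch, assuming random hypothetical logits for a batch of 4 examples and 3 classes. It checks that `F.cross_entropy` matches `F.nll_loss(F.log_softmax(...))` and shows the double-softmax pitfall: once probabilities are fed in, even near-perfect logits leave the loss stuck near $\log(1 + (C-1)/e)$.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)               # hypothetical raw scores, shape (batch, C)
targets = torch.tensor([0, 1, 2, 1])     # integer class IDs (dtype torch.int64)

loss = F.cross_entropy(logits, targets)  # correct: raw logits in

# The same value, with the two fused steps written out explicitly.
manual = F.nll_loss(F.log_softmax(logits, dim=-1), targets)

# The pitfall: we apply softmax, then cross_entropy applies log_softmax again.
wrong = F.cross_entropy(F.softmax(logits, dim=-1), targets)
print(f"correct={loss.item():.4f} manual={manual.item():.4f} wrong={wrong.item():.4f}")

# Even near-perfect logits cannot drive the wrong loss toward 0: its floor
# for C classes is log(1 + (C - 1) / e), about 0.551 when C = 3.
perfect = torch.tensor([[20.0, -20.0, -20.0]])
target0 = torch.tensor([0])
right_val = F.cross_entropy(perfect, target0).item()
wrong_val = F.cross_entropy(F.softmax(perfect, dim=-1), target0).item()
print(f"{right_val:.6f} {wrong_val:.4f}")
```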
F.cross_entropy(F.softmax(logits, dim=-1), targets) feeds probabilities where logits are expected. The model will still train, but the loss is wrong: cross_entropy applies log_softmax to values that are already probabilities, so the loss can never approach 0 and the gradients are squashed. Always pass raw logits.
Cross-Entropy Beyond RUL
| Domain | Classes | Famous use |
|---|---|---|
| RUL health classification (this book) | 3 (normal / degrading / critical) | Auxiliary head in dual-task model |
| Image classification | 1000 (ImageNet) / 10 (MNIST) | Standard ResNet / ViT loss |
| Language modelling | ~30,000 to 50,000 (vocab size) | GPT next-token and BERT masked-token prediction |
| Speech recognition | ~30 characters | DeepSpeech; its CTC loss generalises CE to unaligned sequences |
| Medical diagnosis | ~10-100 conditions | Cancer subtype classification |
| Recommendation | Item catalogue | Sampled softmax in retrieval models |
| Genomics | 4 nucleotides | DNA base prediction (DNA language models) |
| Self-driving | ~20 object classes | Pedestrian / cyclist / sign detection |
The mathematics never changes. What changes is $C$, the number of classes, and the dataset.
The Three Pitfalls
- Passing probabilities instead of logits. F.cross_entropy(logits, targets) is correct; F.cross_entropy(softmax(logits), targets) is silently wrong. PyTorch's naming is easy to misread: the function should arguably be called “cross_entropy_with_logits” (TensorFlow's tf.nn.softmax_cross_entropy_with_logits does exactly this).
- Ignoring class imbalance. Use F.cross_entropy(weight=class_weights) to up-weight rare classes, or use focal loss (Section 11.3), which down-weights well-classified examples automatically.
- Passing one-hot targets. Pass integer class IDs, torch.tensor([0, 1, 2, 1]), not torch.eye(3)[[0, 1, 2, 1]]. Recent PyTorch versions also accept float class-probability targets, but integer IDs are the standard form.
The point. Softmax is the differentiable bridge from real-valued logits to a probability distribution. Cross-entropy is the loss that punishes confident mistakes. Together they form the classification head used in every dual-task model from Chapter 11 onward.
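A minimal PyTorch sketch of the `weight` argument, assuming hypothetical imbalanced labels. The inverse-frequency formula $w_k = N / (C \cdot n_k)$, the same heuristic as scikit-learn's "balanced" class weights, is one common choice and not necessarily the one used later in the book. The final check makes PyTorch's mean reduction explicit: with weights, it divides by the summed weights of the targets rather than by the batch size.

```python
import torch
import torch.nn.functional as F

# Hypothetical imbalanced labels: most cycles are "normal" (class 0).
targets = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 2, 0])
counts = torch.bincount(targets, minlength=3).float()   # assumes every class appears

# Inverse-frequency weights: each class then carries equal total weight N / C.
class_weights = counts.sum() / (len(counts) * counts)

torch.manual_seed(0)
logits = torch.randn(len(targets), 3)   # hypothetical raw scores

weighted = F.cross_entropy(logits, targets, weight=class_weights)

# Reproduce the weighted mean by hand: sum(w_y * loss) / sum(w_y).
per_row = F.cross_entropy(logits, targets, reduction="none")
w = class_weights[targets]
manual = (w * per_row).sum() / w.sum()
print(class_weights, weighted.item(), manual.item())
```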
Takeaway
- Softmax = soft argmax. Converts logits to a probability distribution that sums to 1.
- Cross-entropy = surprise. $\mathcal{L} = -\log p_c$: zero when the model nails the right class, blows up when it is confidently wrong.
- Always max-subtract before exp. Otherwise $e^{z_i}$ overflows for large logits. The result is mathematically identical because softmax is shift-invariant.
- F.cross_entropy expects raw logits. Never pre-softmax the input; the function fuses log_softmax + NLL internally for stability.
- Targets are integer class IDs. Not one-hot. PyTorch convention.