Chapter 3
Section 13 of 121

Softmax & Cross-Entropy

Mathematical Preliminaries

A Multiple-Choice Exam, Translated to Math

A multiple-choice question gives you four options and asks you to pick one. You might feel most confident about option B (60%), less confident about C (20%), and almost rule out A and D (10% each). The teacher then reveals the correct answer was C: you were confidently wrong, and your “loss” for this question is higher than if you had been confidently right or even uniformly unsure (25% on every option).

That informal reasoning is exactly the math behind every classification network. Softmax converts a vector of raw scores into a probability distribution. Cross-entropy measures how badly that distribution differs from the truth. Together they are the loss for the health-classification head we will attach to the backbone in Chapter 11 — one of the two heads of every dual-task model in this book.

The mental model. Softmax = a soft argmax that outputs probabilities. Cross-entropy = the “surprise” when the true class shows up.
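A quick numeric check of the exam story, as a minimal Python sketch. It assumes a belief of 60% on B, 20% on C and 10% each on A and D (illustrative numbers chosen to sum to 1) and scores each outcome with the per-question cross-entropy $-\log p_{\text{true}}$ defined later in this section:

```python
import math

# Illustrative belief over the four options (sums to 1).
belief = {"A": 0.10, "B": 0.60, "C": 0.20, "D": 0.10}


def surprise(p_true: float) -> float:
    """Per-question loss: -log of the probability placed on the true answer."""
    return -math.log(p_true)


confidently_wrong = surprise(belief["C"])  # truth is C, but you favoured B
uniformly_unsure = surprise(0.25)          # 25% on each of the four options
confidently_right = surprise(belief["B"])  # had the truth been B

print(f"confidently wrong : {confidently_wrong:.3f}")  # 1.609
print(f"uniformly unsure  : {uniformly_unsure:.3f}")   # 1.386
print(f"confidently right : {confidently_right:.3f}")  # 0.511
```

Being confidently wrong costs more than spreading the bet evenly, which in turn costs more than being confidently right.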

Softmax: From Logits to Probabilities

Given a vector of $C$ raw scores (called logits) $\mathbf{z} \in \mathbb{R}^{C}$, softmax produces a probability distribution $\mathbf{p} \in [0,1]^{C}$:

$$p_c \;=\; \frac{e^{z_c}}{\sum_{k=1}^{C} e^{z_k}}.$$

Three properties make it the right choice. (1) Positive: $e^{z_c} > 0$ for every $c$, so no probability is negative. (2) Sums to 1: the denominator is the sum of all the numerators. (3) Smooth and differentiable: gradients flow back to every $z_c$. The health head in Chapter 11 has $C = 3$ classes: normal, degrading, critical.
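The formula translates directly into a few lines of Python. This is a minimal sketch with no overflow protection (the max-subtract trick later in this section adds it), applied to illustrative logits for four classes that match the visualizer's defaults:

```python
import math


def softmax_naive(z: list[float]) -> list[float]:
    """Literal translation of p_c = exp(z_c) / sum_k exp(z_k).
    No overflow protection: large logits would make math.exp fail."""
    exps = [math.exp(v) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


# Illustrative logits for cat, dog, bird, fish.
z = [2.0, 1.0, 0.5, -1.0]
p = softmax_naive(z)

print([round(v, 4) for v in p])   # [0.6095, 0.2242, 0.136, 0.0303]
print(all(v > 0 for v in p))      # True  (property 1: positive)
print(round(sum(p), 10))          # 1.0   (property 2: sums to 1)
```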

Interactive: Drag Logits, Watch Probabilities

The visualization below lets you drag four logits (cat, dog, bird, fish) and watch the softmax probabilities update live. The temperature slider divides every logit by a temperature $T$ before the softmax: high temperature flattens the distribution toward uniform, low temperature sharpens it toward one-hot. The same temperature knob shows up in language-model sampling, where it trades creativity against determinism.
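Temperature scaling can be sketched in NumPy as softmax applied to $\mathbf{z}/T$. The function name softmax_t and the three temperatures below are illustrative choices; the logits are the visualizer's defaults:

```python
import numpy as np


def softmax_t(z: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Temperature-scaled softmax: softmax(z / T).
    T > 1 flattens the distribution, T < 1 sharpens it, T = 1 is plain softmax."""
    scaled = z / temperature
    scaled = scaled - scaled.max()   # max-subtract for numerical stability
    e = np.exp(scaled)
    return e / e.sum()


z = np.array([2.0, 1.0, 0.5, -1.0])  # cat, dog, bird, fish
for T in (0.25, 1.0, 5.0):
    p = softmax_t(z, T)
    print(f"T={T:<4}: {p.round(3)}  max prob = {p.max():.3f}")
```

As $T$ grows the largest probability falls toward $1/4$ (uniform over four classes); as $T$ shrinks it climbs toward 1. The ranking of the classes never changes, because dividing by a positive $T$ preserves the order of the logits.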

[Interactive Softmax Visualizer: adjust the logits with sliders and watch the probabilities change in real time. Default logits: cat $z = 2.00$, dog $z = 1.00$, bird $z = 0.50$, fish $z = -1.00$.]

Note: logits can be any real number, positive or negative. They don't sum to 1 and aren't interpretable as probabilities until softmax is applied:

$\text{softmax}(z_i) = e^{z_i} / \sum_j e^{z_j}$

Cross-Entropy: Comparing Two Distributions

With $\mathbf{p}$ the predicted distribution and $\mathbf{q}$ the target distribution (one-hot for hard labels), the cross-entropy is

$$H(\mathbf{q}, \mathbf{p}) \;=\; -\sum_{c=1}^{C} q_c \log p_c.$$

For one-hot $\mathbf{q}$ with the correct class at index $y$, this collapses to

$$H = -\log p_y.$$

Three intuitions. (a) If $p_y = 1$ (perfect): the loss is $-\log 1 = 0$. (b) If $p_y = 0.5$: the loss is $\log 2 \approx 0.693$. (c) If $p_y \to 0$: the loss $\to \infty$. Confidence on the wrong answer is punished without bound.
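A small NumPy check that the general formula collapses to $-\log p_y$ for a one-hot target, followed by the three intuitions plus one tinier $p_y$. The prediction vector and the helper name cross_entropy_dist are made up for illustration:

```python
import numpy as np


def cross_entropy_dist(q: np.ndarray, p: np.ndarray) -> float:
    """H(q, p) = -sum_c q_c log p_c.
    Only classes with q_c > 0 are summed, so 0 * log(0) never produces nan."""
    mask = q > 0
    return float(-(q[mask] * np.log(p[mask])).sum())


p = np.array([0.5, 0.3, 0.2])  # illustrative prediction
q = np.array([1.0, 0.0, 0.0])  # one-hot target, correct class y = 0

print(cross_entropy_dist(q, p), float(-np.log(p[0])))  # both equal log 2, about 0.693

for p_y in (1.0, 0.5, 0.01, 1e-6):
    # log(1 / p_y) equals -log p_y; written this way so p_y = 1 prints 0.000, not -0.000
    print(f"p_y = {p_y:<6} loss = {np.log(1.0 / p_y):.3f}")
# losses: 0.000, 0.693, 4.605, 13.816
```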

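Why log and not squared error? Paired with softmax, cross-entropy has the gradient $\partial H / \partial \mathbf{z} = \mathbf{p} - \mathbf{q}$ with respect to the logits, which stays large exactly when the model is confidently wrong. MSE on probabilities multiplies the error by softmax's own derivative, which vanishes once the wrong class saturates, leaving flat regions where learning stalls. Cross-entropy is also the maximum-likelihood loss for multi-class classification: minimising it maximises the log-likelihood of the labels under independent samples.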
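Both gradient claims can be checked numerically. The NumPy sketch below (helper names are illustrative) compares the analytic cross-entropy gradient $\mathbf{p} - \mathbf{q}$ with central finite differences, then measures the MSE-on-probabilities gradient at the same confidently wrong point, where the model puts almost all its mass on class 1 but the label is class 0:

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())
    return e / e.sum()


def ce_loss(z: np.ndarray, y: int) -> float:
    """Cross-entropy -log p_y."""
    return float(-np.log(softmax(z)[y]))


def mse_loss(z: np.ndarray, y: int) -> float:
    """Squared error between the softmax output and the one-hot target."""
    q = np.eye(len(z))[y]
    return float(((softmax(z) - q) ** 2).sum())


def numerical_grad(f, z: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Central finite differences, perturbing one logit at a time."""
    g = np.zeros_like(z)
    for i in range(len(z)):
        dz = np.zeros_like(z)
        dz[i] = eps
        g[i] = (f(z + dz) - f(z - dz)) / (2 * eps)
    return g


y = 0
z_wrong = np.array([-5.0, 5.0, 0.0])   # confidently wrong: nearly all mass on class 1
q = np.eye(3)[y]

g_ce = numerical_grad(lambda z: ce_loss(z, y), z_wrong)
g_mse = numerical_grad(lambda z: mse_loss(z, y), z_wrong)

print("analytic p - q:", (softmax(z_wrong) - q).round(5))
print("numerical CE  :", g_ce.round(5))
print("numerical MSE :", g_mse.round(5))
```

The cross-entropy gradient on the true class's logit is close to $-1$, so gradient descent pushes that logit up hard; the MSE gradient there is of order $10^{-4}$, so the same mistake barely moves the weights.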

Interactive: The Cross-Entropy Loss

Watch the loss spike as the model becomes confidently wrong. The gradient flows back through log and softmax to update the logits.

[Interactive Cross-Entropy Loss widget: see how the loss penalizes wrong predictions. Example state: the model assigns cat 90.0%, dog 5.0%, bird 3.0%, fish 2.0%; the true label is cat (one-hot [1, 0, 0, 0]); $P(\text{correct}) = 0.90$, $\log P = -0.1054$, so the cross-entropy loss is $-\log 0.90 = 0.1054$ (low loss: confident and correct).]

Key Insight: Why $-\log p$?

The negative log penalizes confident wrong predictions severely. If the model assigns 1% to the correct answer, the loss is $-\log 0.01 \approx 4.6$; if it assigns 90%, the loss is $-\log 0.9 \approx 0.1$. This penalty pushes the model to be both correct and confident.

Cross-Entropy Loss Formula:
$L = -\sum_i y_i \log p_i = -\log p_{\text{correct}}$

Since $\mathbf{y}$ is one-hot encoded, only the probability of the correct class matters.

Numerical Stability: The Max-Subtract Trick

Naive softmax is a numerical landmine. $e^{1000}$ overflows to inf; $e^{-1000}$ underflows to 0. The fix is to exploit a softmax invariant:

$$\text{softmax}(\mathbf{z}) = \text{softmax}(\mathbf{z} - c) \quad \text{for any constant } c.$$
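The invariant takes one line to verify: subtracting the same constant $c$ from every logit multiplies each numerator and the denominator by the same factor $e^{-c}$, which cancels:

```latex
\frac{e^{z_j - c}}{\sum_{k=1}^{C} e^{z_k - c}}
  \;=\; \frac{e^{-c}\, e^{z_j}}{e^{-c} \sum_{k=1}^{C} e^{z_k}}
  \;=\; \frac{e^{z_j}}{\sum_{k=1}^{C} e^{z_k}}
  \;=\; \text{softmax}(\mathbf{z})_j .
```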

Choose $c = \max_k z_k$. After subtracting, every $z_k - c \le 0$, so $e^{z_k - c} \le 1$ and nothing can overflow. The biggest logit becomes 0 and $e^{0} = 1$, so the denominator is at least 1. Terms far below the max may still underflow to 0, but that is harmless: there is no overflow, no division by zero, and the result is mathematically identical.
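The failure and the fix side by side, as a minimal NumPy sketch with made-up logits near 1000. np.errstate only silences the overflow and invalid-value warnings that the naive path triggers:

```python
import numpy as np

z = np.array([1000.0, 999.0, 998.0])   # illustrative large logits

with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(z) / np.exp(z).sum()  # exp(1000) = inf, and inf / inf = nan

shifted = z - z.max()                    # [0, -1, -2]
stable = np.exp(shifted) / np.exp(shifted).sum()

print(naive)           # [nan nan nan]
print(stable.round(4)) # same as softmax([2, 1, 0]), roughly [0.665, 0.245, 0.090]
```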

Production rule. Never write softmax yourself without the max-subtract. PyTorch's F.softmax and F.cross_entropy handle it; if you build a custom layer, replicate the trick.
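For a custom layer, the same trick also gives a stable log-softmax through the log-sum-exp identity $\log p_c = (z_c - m) - \log \sum_k e^{z_k - m}$ with $m = \max_k z_k$, which lets cross-entropy skip the epsilon entirely. A minimal NumPy sketch; the function names are illustrative:

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise log-softmax via log-sum-exp: never takes the log of a probability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def cross_entropy_stable(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean -log p_y computed directly from log-probabilities; no epsilon needed."""
    logp = log_softmax(logits)
    return float(-logp[np.arange(len(targets)), targets].mean())


# A logit gap of 1000 makes p_y underflow to exactly 0.0 in float64.
logits = np.array([[1000.0, 0.0]])
targets = np.array([1])
print(cross_entropy_stable(logits, targets))   # 1000.0, the exact loss

p = np.exp(log_softmax(logits))                # p[0, 1] is 0.0 after underflow
capped = float(-np.log(p[0, 1] + 1e-12))
print(capped)                                  # about 27.63: the epsilon caps the loss
```

The epsilon version silently reports a loss of about 27.6 where the true value is 1000, which also flattens the gradient; working in log space avoids both problems.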

Python: Softmax and CE From Scratch

A short NumPy script on a four-engine, three-class C-MAPSS-style batch. The numerical values match what PyTorch's F.cross_entropy returns to four decimal places.

Softmax + cross-entropy from first principles
softmax_ce_numpy.py

```python
import numpy as np

# ----- Numerically-stable softmax -----
def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax. Subtracts max for numerical stability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    expd    = np.exp(shifted)
    return expd / expd.sum(axis=-1, keepdims=True)


# ----- Cross-entropy for integer class-ID targets -----
def cross_entropy(logits: np.ndarray, targets: np.ndarray) -> float:
    """logits: (B, C). targets: (B,) integer class IDs.
    Returns the mean -log(p_correct) across the batch.
    """
    probs = softmax(logits)
    # Pick the probability of the correct class for each row
    p_correct = probs[np.arange(len(targets)), targets]
    return float(-np.log(p_correct + 1e-12).mean())


# ----- Run on a tiny C-MAPSS-style health-state batch -----
# Three classes: 0 = normal, 1 = degrading, 2 = critical
logits = np.array([
    [2.0, 1.0, 0.5],       # leans normal
    [0.1, 3.0, 0.2],       # strong degrading
    [0.5, 0.5, 4.0],       # very critical
    [1.0, 1.5, 1.2],       # uncertain - leaning degrading
])                         # float64 (NumPy's default), so rounded values print cleanly
targets = np.array([0, 1, 2, 1])     # ground-truth health states

probs = softmax(logits)
print("probs[0]   :", probs[0].round(4).tolist())
# probs[0]   : [0.6285, 0.2312, 0.1402]
print("CE per row :", [
    round(-float(np.log(probs[i, targets[i]])), 4)
    for i in range(len(targets))
])
# CE per row : [0.4644, 0.1096, 0.0586, 0.8533]
print("avg CE     :", round(cross_entropy(logits, targets), 4))
# avg CE     : 0.3715
```

Line-by-line notes

  • import numpy as np: only np.exp, np.log and basic array arithmetic are needed.
  • softmax(logits): numerically stable row-wise softmax. Input: (B, C) real-valued scores from the network. Output: (B, C) probabilities, each row summing to 1.
  • shifted = logits - logits.max(axis=-1, keepdims=True): subtract the per-row max, so every row's max becomes 0 and np.exp of the non-positive result is at most 1. axis=-1 operates along the last axis (the classes within each row); keepdims=True keeps the max as shape (B, 1) so it broadcasts against (B, C). This relies on the invariant softmax(x) = softmax(x - c) for any constant c. Without it, np.exp(1000) = inf.
  • expd = np.exp(shifted): element-wise exponential; with shifted ≤ 0, every entry lies in (0, 1].
  • return expd / expd.sum(axis=-1, keepdims=True): normalise each row to sum to 1.
  • cross_entropy(logits, targets): mean negative log-likelihood of the correct class. targets is a (B,) array of integer class IDs in [0, C). Returns the single scalar the optimiser minimises.
  • probs = softmax(logits): converts logits to probabilities; here probs.shape = (B, C) = (4, 3).
  • p_correct = probs[np.arange(len(targets)), targets]: integer-array (fancy) indexing pairs each row index in np.arange(len(targets)) = [0, 1, 2, 3] with that row's target, giving [probs[0, targets[0]], probs[1, targets[1]], ...], the probability the model assigned to the CORRECT class in each row. Here p_correct ≈ [0.6285, 0.8962, 0.9430, 0.4260].
  • return float(-np.log(p_correct + 1e-12).mean()): negative log of each correct-class probability, averaged. The 1e-12 epsilon prevents log(0) = -inf if a probability underflows to exactly zero.
  • logits: a (4, 3) batch, one row per engine, columns normal / degrading / critical. Row 0 [2.0, 1.0, 0.5] leans normal, row 1 [0.1, 3.0, 0.2] votes strongly for degrading, row 2 [0.5, 0.5, 4.0] is very critical, row 3 [1.0, 1.5, 1.2] is uncertain but leans degrading.
  • targets = [0, 1, 2, 1]: normal / degrading / critical / degrading. The model's top class is correct on every row: confidently on rows 1 and 2, moderately on row 0, and only weakly on row 3.
  • probs[0] = [0.6285, 0.2312, 0.1402]: sums to 1 (up to rounding); 63% on the right class is moderately confident.
  • CE per row = [0.4644, 0.1096, 0.0586, 0.8533]: row 0 is -log(0.6285) ≈ 0.464; row 2 is -log(0.9430) ≈ 0.059, near zero because the model is very confident on the right class; row 3 is -log(0.4260) ≈ 0.853, much larger because the model is uncertain.
  • avg CE = 0.3715: the batch mean, the scalar the optimiser minimises. For reference, a model that outputs uniform probabilities scores -log(1/3) ≈ 1.099, so 0.37 is much better.

Reading the per-row losses

Row 2 (logits $[0.5, 0.5, 4.0]$, target 2) produces softmax probability 0.94 on the correct class, for a loss of 0.059. Row 3 (logits $[1.0, 1.5, 1.2]$, target 1) is much less confident: probability 0.43, loss 0.85. The average (0.37) is what the optimiser drives toward 0 during training.

PyTorch: F.cross_entropy in One Line

Production code never writes the softmax + log + NLL pipeline by hand — F.cross_entropy fuses them in a single, numerically-stable, GPU-friendly call. The CRITICAL convention is that you pass raw logits, not pre-softmaxed probabilities.

F.cross_entropy on the same toy batch
cross_entropy_torch.py

```python
import torch
import torch.nn.functional as F

# ----- F.cross_entropy: log_softmax + NLL fused -----
torch.manual_seed(0)   # no randomness below, but a good habit

logits = torch.tensor([
    [2.0, 1.0, 0.5],
    [0.1, 3.0, 0.2],
    [0.5, 0.5, 4.0],
    [1.0, 1.5, 1.2],
])
targets = torch.tensor([0, 1, 2, 1])      # int64 class IDs

loss = F.cross_entropy(logits, targets)   # pass RAW logits, not probabilities
probs = F.softmax(logits, dim=-1)         # for inspection only

print("logits.shape :", tuple(logits.shape))      # (4, 3)
print("probs[0]     :", [round(p, 4) for p in probs[0].tolist()])
print("loss         :", round(loss.item(), 4))
# probs[0]     : [0.6285, 0.2312, 0.1402]
# loss         : 0.3715
```

Line-by-line notes

  • import torch.nn.functional as F: F.cross_entropy and F.softmax live here.
  • torch.manual_seed(0): a determinism habit; no RNG is used in this snippet.
  • logits: the same numbers as the NumPy block, shape (4, 3). F.cross_entropy expects raw logits and applies log_softmax internally; passing pre-softmaxed probabilities silently gives the wrong loss.
  • targets: integer class IDs (an int64 tensor) of shape (B,), the standard target form for F.cross_entropy.
  • loss = F.cross_entropy(logits, targets): the one-line equivalent of the NumPy `cross_entropy` function. Internally it computes log_softmax(logits) followed by the negative log-likelihood, which is numerically stable and GPU-friendly. The default is reduction='mean'; reduction='none' returns per-sample losses for debugging. The result is a 0-D scalar (torch.Size([])). In real training the logits come from the model, so loss.requires_grad is True; in this snippet they are a plain tensor, so it is False.
  • probs = F.softmax(logits, dim=-1): standalone softmax for inspection only; dim=-1 is the class axis. The loss does not use it, because F.cross_entropy fuses the softmax for stability.
  • Outputs: logits.shape is (4, 3); probs[0] matches the NumPy run to float32 precision; the loss is the same 0.3715 as the NumPy run, confirming that the two implementations agree.
The most common cross_entropy bug in production. Calling F.cross_entropy(F.softmax(logits, dim=-1), targets) feeds probabilities where logits are expected. The model still trains, but the loss is wrong: F.cross_entropy applies log_softmax to its input, so you get log_softmax of probabilities (softmax applied twice) instead of log_softmax of logits. Always pass raw logits.
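The bug can be reproduced without PyTorch. The NumPy sketch below uses a hypothetical helper ce_from_logits that mirrors what F.cross_entropy computes (the mean negative log-likelihood of log_softmax of its input), then feeds it probabilities instead of logits:

```python
import numpy as np


def log_softmax(x: np.ndarray) -> np.ndarray:
    s = x - x.max(axis=-1, keepdims=True)
    return s - np.log(np.exp(s).sum(axis=-1, keepdims=True))


def ce_from_logits(x: np.ndarray, targets: np.ndarray) -> float:
    """Mean NLL of log_softmax(x): treats whatever it receives as logits."""
    return float(-log_softmax(x)[np.arange(len(targets)), targets].mean())


logits = np.array([[2.0, 1.0, 0.5], [0.1, 3.0, 0.2], [0.5, 0.5, 4.0], [1.0, 1.5, 1.2]])
targets = np.array([0, 1, 2, 1])
probs = np.exp(log_softmax(logits))

correct = ce_from_logits(logits, targets)   # raw logits: the intended loss
buggy = ce_from_logits(probs, targets)      # probabilities: softmax applied twice
print(round(correct, 4), round(buggy, 4))   # correct is 0.3715; buggy is larger

# Even a perfect one-hot "prediction" cannot drive the buggy loss to zero:
# inputs squeezed into [0, 1] give at most e / (e + 2) on the right class.
floor = ce_from_logits(np.eye(3), np.arange(3))
print(round(floor, 4))                      # 0.5514 = log(1 + 2/e)
```

Because probabilities live in $[0, 1]$, the second softmax can never become confident, so the reported loss has a floor and the gradients are shrunk.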

Cross-Entropy Beyond RUL

| Domain | Classes | Famous use |
|---|---|---|
| RUL health classification (this book) | 3 (normal / degrading / critical) | Auxiliary head in the dual-task model |
| Image classification | 1000 (ImageNet) / 10 (MNIST) | Standard ResNet / ViT loss |
| Language modelling | ~30,000-50,000 (vocabulary size) | GPT next-token and BERT masked-token prediction |
| Speech recognition | ~30 characters | DeepSpeech (its CTC loss generalises CE to unaligned sequences) |
| Medical diagnosis | ~10-100 conditions | Cancer subtype classification |
| Recommendation | Item catalogue | Sampled softmax in retrieval models |
| Genomics | 4 nucleotides | DNA base prediction |
| Self-driving | ~20 object classes | Pedestrian / cyclist / sign detection |

The mathematics never changes. What changes is $C$, the number of classes, and the dataset.

The Three Pitfalls

Pitfall 1: Logits vs probabilities. F.cross_entropy(logits, targets) is correct. F.cross_entropy(F.softmax(logits, dim=-1), targets) is silently wrong. PyTorch's naming is arguably brittle, because nothing in the name says the input must be logits; TensorFlow makes the convention explicit with tf.nn.softmax_cross_entropy_with_logits.
Pitfall 2: Class imbalance. If 95% of training cycles are labelled “normal”, vanilla cross-entropy can reach a low loss by always predicting normal. Pass F.cross_entropy(logits, targets, weight=class_weights) to up-weight rare classes, or use focal loss (Section 11.3), which down-weights well-classified examples automatically.
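Pitfall 3: Targets as one-hot. Some tutorials build one-hot target tensors. The standard input to F.cross_entropy is integer class IDs (a long tensor): pass torch.tensor([0, 1, 2, 1]), not torch.eye(3)[[0, 1, 2, 1]]. Since PyTorch 1.10, a float (B, C) target is also accepted and interpreted as class probabilities, so the float one-hot version gives the same loss on recent releases, but it fails on older ones and is never needed for hard labels.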
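Pitfalls 2 and 3 can be sketched in NumPy. The class counts below are hypothetical; the inverse-frequency weights follow the same n_samples / (n_classes * count) formula as scikit-learn's "balanced" class weights, and the weighted mean is normalised by the summed weights of the targets, which is how PyTorch documents weight= combined with reduction='mean'. The helper name weighted_cross_entropy is illustrative:

```python
import numpy as np


def weighted_cross_entropy(logits, targets, class_weights):
    """Class-weighted mean cross-entropy.
    Each sample's -log p_y is scaled by the weight of its true class, and the sum
    is divided by the total weight of the targets (not by the batch size)."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    logp = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    nll = -logp[np.arange(len(targets)), targets]
    w = class_weights[targets]
    return float((w * nll).sum() / w.sum())


# Hypothetical imbalanced training set: counts of normal / degrading / critical cycles.
counts = np.array([950, 40, 10])
class_weights = counts.sum() / (len(counts) * counts)   # rare classes get larger weights

logits = np.array([[2.0, 1.0, 0.5], [0.1, 3.0, 0.2], [0.5, 0.5, 4.0], [1.0, 1.5, 1.2]])

# Pitfall 3: if labels arrive one-hot, recover integer class IDs with argmax.
one_hot = np.eye(3)[[0, 1, 2, 1]]
targets = one_hot.argmax(axis=1)                        # array([0, 1, 2, 1])

unweighted = weighted_cross_entropy(logits, targets, np.ones(3))
weighted = weighted_cross_entropy(logits, targets, class_weights)
print("class weights :", class_weights.round(3))
print("unweighted CE :", round(unweighted, 4))          # 0.3715, same as plain CE
print("weighted CE   :", round(weighted, 4))
```

With all weights equal to 1 the function reduces to the plain batch mean, which is a convenient sanity check before switching to real class weights.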
The point. Softmax is the differentiable bridge from real-valued logits to a probability distribution. Cross-entropy is the loss that punishes confident mistakes. Together they form the classification head used in every dual-task model from Chapter 11 onward.

Takeaway

  • Softmax = soft argmax. Converts logits to a probability distribution that sums to 1.
  • Cross-entropy = surprise. $-\log p_y$: zero when the model nails the right class, blows up when it is confidently wrong.
  • Always max-subtract before exp. Otherwise $e^{1000}$ overflows. The result is mathematically identical because softmax is shift-invariant.
  • F.cross_entropy expects raw logits. Never pre-softmax the input; the function fuses log_softmax + NLL internally for stability.
  • Targets are integer class IDs. Not one-hot: that is PyTorch's standard form (float class-probability targets are only accepted from version 1.10 onward).