Chapter 10

Health Loss Component: Cross-Entropy

AMNL: The Novel Loss Function

Learning Objectives

By the end of this section, you will:

  1. Understand the health classification task as an auxiliary objective
  2. Derive the cross-entropy loss for 3-class classification
  3. Explain the regularization role of health classification
  4. Handle class imbalance appropriately
  5. Implement the health loss in PyTorch
Why This Matters: The health classification task is not merely an additional output—it is the regularizer that enables AMNL to achieve state-of-the-art RUL prediction. By forcing the encoder to learn discrete degradation stages, it prevents overfitting to noisy RUL labels.

The Health Classification Task

Health classification discretizes the degradation trajectory into three states.

Class Definitions

| Class | Name | RUL Range | Interpretation |
|---|---|---|---|
| 0 | Healthy | RUL > 125 | Normal operation |
| 1 | Degrading | 50 < RUL ≤ 125 | Degradation detected |
| 2 | Critical | RUL ≤ 50 | Imminent failure |

Label Generation

Health labels are derived automatically from RUL:

$$c_i = \begin{cases} 0 & \text{if } y_i > 125 \\ 1 & \text{if } 50 < y_i \leq 125 \\ 2 & \text{if } y_i \leq 50 \end{cases}$$

Why This Discretization?

  • Boundary at 125: Aligns with piecewise linear RUL cap
  • Boundary at 50: Defines "critical" zone requiring action
  • Three classes: Sufficient granularity without over-complication
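The piecewise mapping above can be implemented directly. The helper below is an illustrative sketch (not part of the chapter's code), written so later assignments overwrite earlier ones:

```python
import torch

def rul_to_health_class(rul: torch.Tensor) -> torch.Tensor:
    """Map RUL values to health classes 0/1/2 via the 125 and 50 thresholds."""
    classes = torch.zeros_like(rul, dtype=torch.long)  # 0: Healthy (RUL > 125)
    classes[rul <= 125] = 1                            # 1: Degrading (50 < RUL <= 125)
    classes[rul <= 50] = 2                             # 2: Critical (RUL <= 50)
    return classes

rul = torch.tensor([150.0, 125.0, 80.0, 50.0, 15.0])
print(rul_to_health_class(rul))  # tensor([0, 1, 1, 2, 2])
```

Note that the boundary values themselves (RUL = 125, RUL = 50) fall into the more degraded class, matching the ≤ signs in the case definition.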

Cross-Entropy Loss

We use standard cross-entropy for health classification.

Mathematical Formulation

$$\mathcal{L}_{\text{health}} = -\frac{1}{N}\sum_{i=1}^{N} \log p(c_i \mid x_i)$$

Where:

  • $c_i \in \{0, 1, 2\}$: True health class for sample $i$
  • $p(c_i \mid x_i)$: Predicted probability of the correct class
  • $N$: Batch size

Expanded Form

With softmax probabilities:

$$p(c \mid x) = \frac{\exp(z_c)}{\sum_{k=0}^{2} \exp(z_k)}$$

The loss becomes:

$$\mathcal{L}_{\text{health}} = -\frac{1}{N}\sum_{i=1}^{N} \left[ z_{c_i} - \log\sum_{k=0}^{2} \exp(z_k) \right]$$
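As a quick numerical sanity check (a sketch, not from the chapter), the expanded log-sum-exp form matches PyTorch's built-in `F.cross_entropy` on a small batch:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1.5, 0.8, -0.5],
                       [0.2, 1.9, 0.4]])
target = torch.tensor([1, 1])

# Expanded form: mean over the batch of log-sum-exp(z) - z_{c_i}
z_c = logits.gather(1, target.unsqueeze(1)).squeeze(1)  # logit of the true class
manual = (torch.logsumexp(logits, dim=1) - z_c).mean()

builtin = F.cross_entropy(logits, target)  # applies log_softmax internally
print(torch.allclose(manual, builtin))  # True
```

Computing the loss via `logsumexp` rather than an explicit softmax is also what makes the built-in numerically stable for large logits.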

Regularization Role

Health classification acts as a powerful regularizer for RUL prediction.

How Regularization Works

```text
Without health task:
  Encoder ─→ RUL Head ─→ RUL Prediction
       └── May overfit to exact cycle numbers

With health task:
  Encoder ─┬→ RUL Head    ─→ RUL Prediction
           └→ Health Head ─→ Health Classification
       └── Must learn features useful for BOTH tasks
```

Why This Helps RUL

  • Coarse supervision: Health classes provide "checkpoints" along degradation
  • Noise reduction: Discrete classes are less noisy than exact RUL
  • Feature regularization: Encoder must learn generalizable features
  • Gradient diversity: Different loss gradients stabilize training
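The shared-encoder, dual-head layout in the diagram can be sketched as a minimal PyTorch module. Dimensions and layer choices here are illustrative placeholders, not AMNL's actual architecture:

```python
import torch
import torch.nn as nn

class DualHeadModel(nn.Module):
    """Shared encoder feeding both a RUL head and a health head."""
    def __init__(self, in_dim: int = 24, hidden: int = 64):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.rul_head = nn.Linear(hidden, 1)     # regression output
        self.health_head = nn.Linear(hidden, 3)  # logits for 3 health classes

    def forward(self, x):
        h = self.encoder(x)  # shared features; gradients from BOTH heads flow here
        return self.rul_head(h).squeeze(-1), self.health_head(h)

model = DualHeadModel()
rul_pred, health_logits = model(torch.randn(8, 24))
print(rul_pred.shape, health_logits.shape)  # torch.Size([8]) torch.Size([8, 3])
```

Because both losses backpropagate through `self.encoder`, its features must serve both tasks, which is exactly the regularization effect described above.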

Empirical Evidence

Removing health classification dramatically hurts RUL performance:

| Configuration | FD002 Score | Degradation |
|---|---|---|
| AMNL (dual-task) | 1,102 | baseline |
| RUL only (single-task) | 4,453 | +304% |

Critical Finding

Without the health classification auxiliary task, RUL prediction performance degrades by over 300%. The health task is not optional—it is essential for achieving state-of-the-art results.


Implementation

The health loss uses PyTorch's built-in cross-entropy.

Basic Implementation

Health Classification Loss (`losses/health_loss.py`)

**Logits input.** Raw network outputs before softmax. Each row has 3 values, one per health class (Healthy, Degrading, Critical):

```python
# BEFORE softmax: raw logits from the health head
logits = torch.tensor([[1.5, 0.8, -0.5],   # Sample 1
                       [0.2, 1.9, 0.4],    # Sample 2
                       [-0.3, 0.1, 2.1],   # Sample 3
                       [2.0, 0.5, -1.0]])  # Sample 4
# Shape: (4, 3) = (batch, num_classes)

# F.cross_entropy applies softmax internally
```

**Target classes.** Integer class labels (0 = Healthy, 1 = Degrading, 2 = Critical), derived from the RUL thresholds:

```python
# Target health classes:
target = torch.tensor([1, 1, 2, 0])

# Corresponding RUL values were:
# Sample 1: RUL=80  → class 1 (50 < RUL ≤ 125)
# Sample 2: RUL=65  → class 1 (50 < RUL ≤ 125)
# Sample 3: RUL=15  → class 2 (RUL ≤ 50)
# Sample 4: RUL=150 → class 0 (RUL > 125)
```

**Cross-entropy computation.** `F.cross_entropy` combines `log_softmax` and `NLLLoss`; it computes -log(probability of the correct class) for each sample:

```python
# For sample 1: logits = [1.5, 0.8, -0.5], target = 1
# Step 1: softmax
#   exp([1.5, 0.8, -0.5]) / sum = [4.48, 2.23, 0.61] / 7.31
#                               = [0.613, 0.304, 0.083]
# Step 2: negative log of the correct class (index 1)
#   loss_1 = -log(0.304) ≈ 1.19

# Per-sample losses: [1.19, 0.34, 0.20, 0.24]
# Averaged over the batch:
#   loss = mean([1.19, 0.34, 0.20, 0.24]) ≈ 0.49
```

The full implementation:

```python
import torch
import torch.nn.functional as F


def health_loss(
    logits: torch.Tensor,
    target: torch.Tensor
) -> torch.Tensor:
    """
    Cross-entropy loss for health classification.

    Args:
        logits: Health head output (batch, 3)
        target: True health class (batch,)

    Returns:
        Cross-entropy loss (scalar)
    """
    return F.cross_entropy(logits, target)
```

With Class Weights (Optional)

For handling class imbalance:

Weighted Health Classification Loss (`losses/health_loss.py`)

**Default class weights.** Weights roughly inverse to class frequency, so rarer classes count more; the Critical class is the rarest and gets the highest weight:

```python
# Typical class distribution in training:
#   Class 0 (Healthy):   40% of samples
#   Class 1 (Degrading): 35% of samples
#   Class 2 (Critical):  25% of samples

# Weights roughly inverse to frequency:
class_weights = torch.tensor([1.0,   # Healthy
                              1.2,   # Degrading
                              1.5])  # Critical

# Effect: Critical misclassifications are penalized 1.5×
# relative to Healthy misclassifications
```

**Weighted cross-entropy.** PyTorch's `weight` parameter scales each sample's loss by the weight of its true class, then divides by the sum of those weights rather than by the batch size:

```python
# Without weights (uniform), per-sample losses from the example above:
#   losses = [1.19, 0.34, 0.20, 0.24]
#   loss = mean(losses) ≈ 0.49

# With weights = [1.0, 1.2, 1.5] and targets = [1, 1, 2, 0],
# each sample is scaled by the weight of its true class:
#   per-sample weights = [1.2, 1.2, 1.5, 1.0]
#   weighted_losses    = [1.43, 0.41, 0.31, 0.24]
#   loss = sum(weighted_losses) / sum([1.2, 1.2, 1.5, 1.0])
#        = 2.39 / 4.9 ≈ 0.49
```

**Device transfer.** `class_weights` must live on the same device as `logits` (CPU or GPU); mismatched devices raise an error:

```python
# If logits are on GPU:
#   logits.device → device(type='cuda', index=0)

# Weights must also be moved to the GPU:
#   class_weights.to(logits.device)
#   → tensor([1.0000, 1.2000, 1.5000], device='cuda:0')

# Without .to(), you'd get:
#   RuntimeError: Expected all tensors to be on the same device
```

The full implementation:

```python
from typing import Optional


def health_loss_weighted(
    logits: torch.Tensor,
    target: torch.Tensor,
    class_weights: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Weighted cross-entropy for imbalanced classes.

    Args:
        logits: Health head output (batch, 3)
        target: True health class (batch,)
        class_weights: Weight per class (3,)

    Returns:
        Weighted cross-entropy loss (scalar)
    """
    if class_weights is None:
        # Typical class-distribution weights
        class_weights = torch.tensor([1.0, 1.2, 1.5])

    return F.cross_entropy(
        logits, target,
        weight=class_weights.to(logits.device)
    )
```
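Rather than hard-coding weights, they can be derived from the training-label counts. The helper below is one common recipe (inverse-frequency weighting, normalized so the most frequent class has weight 1.0); it is an illustrative sketch, not the chapter's code:

```python
import torch

def inverse_frequency_weights(labels: torch.Tensor, num_classes: int = 3) -> torch.Tensor:
    """Class weights inversely proportional to observed class frequency."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    weights = counts.sum() / (num_classes * counts)  # inverse frequency
    return weights / weights.min()                   # most frequent class -> 1.0

# 40% / 35% / 25% split, as in the example distribution above
labels = torch.tensor([0] * 40 + [1] * 35 + [2] * 25)
print(inverse_frequency_weights(labels))
```

For this split it yields weights of about [1.00, 1.14, 1.60], in the same range as the hand-set defaults [1.0, 1.2, 1.5] used above.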

Typical Loss Values

| Training Stage | Typical Value | Interpretation |
|---|---|---|
| Early (epochs 1-10) | 1.0-1.5 | Near random guessing |
| Mid (epochs 10-50) | 0.4-0.8 | Learning class boundaries |
| Late (epochs 50+) | 0.2-0.4 | Good classification |
| Converged | 0.1-0.3 | Reliable predictions |

For reference, uniform guessing over 3 classes gives a loss of $-\log(1/3) \approx 1.10$, which is why early-training values sit near 1.1.

Loss Monitoring

Health loss should decrease smoothly during training. Sudden increases may indicate learning rate issues or label noise.


Summary

In this section, we examined the health classification loss:

  1. Task: 3-class classification (Healthy, Degrading, Critical)
  2. Loss: Standard cross-entropy
  3. Role: Regularizer for RUL prediction
  4. Impact: Removing it degrades RUL by 304%
  5. Implementation: F.cross_entropy in PyTorch
| Property | Value |
|---|---|
| Number of classes | 3 |
| Class boundaries | 125, 50 cycles |
| Typical converged loss | 0.2-0.4 |
| Performance impact | Essential for SOTA |
Looking Ahead: We have defined both loss components. The next section explains why the 0.5/0.5 split provides superior regularization—the theoretical justification for AMNL's equal weighting strategy.

With both loss components defined, we analyze why equal weighting is optimal.