Chapter 8

Health Classification Head (3 Classes)

Dual Task Prediction Heads

Learning Objectives

By the end of this section, you will:

  1. Define the three health states and their RUL boundaries
  2. Design the classification head architecture
  3. Understand softmax output as probability distribution
  4. Address class imbalance in health state labels
  5. Connect classification to cross-entropy loss
Why This Matters: Health classification provides discrete checkpoints along the degradation trajectory. By predicting whether an engine is Healthy, Degrading, or Critical, the model learns to recognize qualitative state transitions, knowledge that transfers to improved RUL prediction.

Health State Definition

We discretize the continuous RUL target into three health states based on how much life remains.

State Definitions

| Class | State | RUL Range | Interpretation |
|-------|-------|-----------|----------------|
| 0 | Healthy | RUL > 125 | Normal operation, no immediate concern |
| 1 | Degrading | 50 < RUL ≤ 125 | Degradation detected, plan maintenance |
| 2 | Critical | RUL ≤ 50 | Imminent failure, urgent maintenance |

Visual Representation

πŸ“text
1RUL Timeline:
2  ∞ ←────────────────────────────────────────→ 0
3      β”‚         β”‚              β”‚              β”‚
4      β”‚  Healthy β”‚   Degrading  β”‚   Critical  β”‚ FAILURE
5      β”‚  (Class 0)β”‚   (Class 1)  β”‚   (Class 2) β”‚
6      β”‚         β”‚              β”‚              β”‚
7     125       125            50              0
8
9Health State Transition:
10  Healthy β†’ Degrading β†’ Critical β†’ Failure
11     β”‚           β”‚           β”‚
12    Long       Medium      Short
13   horizon    horizon     horizon

Why These Boundaries?

  • 125 cycles: Aligns with the piecewise linear RUL cap; beyond this, degradation is minimal
  • 50 cycles: Approximately one standard deviation of typical failure prediction error, a reasonable "danger zone"
  • Three classes: Sufficient granularity without over-complicating the auxiliary task

Label Generation

Health state labels are derived from RUL labels automatically. No additional annotation is needed:

🐍python
def get_health_label(rul: float) -> int:
    if rul > 125:
        return 0  # Healthy
    elif rul > 50:
        return 1  # Degrading
    else:
        return 2  # Critical
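The same rule can label every cycle of a run-to-failure trajectory, since RUL at each cycle is simply cycles-to-failure. A minimal pure-Python sketch (the 200-cycle trajectory length is an illustrative assumption, not a C-MAPSS constant):

```python
def health_labels_for_trajectory(total_cycles: int) -> list:
    """Label every cycle of a run-to-failure trajectory with its health state."""
    labels = []
    for cycle in range(1, total_cycles + 1):
        rul = total_cycles - cycle  # RUL reaches 0 at the final cycle
        if rul > 125:
            labels.append(0)  # Healthy
        elif rul > 50:
            labels.append(1)  # Degrading
        else:
            labels.append(2)  # Critical
    return labels

labels = health_labels_for_trajectory(200)
# RUL runs 199 → 0, so: Healthy for RUL 199..126 (74 cycles),
# Degrading for 125..51 (75 cycles), Critical for 50..0 (51 cycles)
```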

Classification Head Architecture

The health classification head transforms the shared representation into class logits.

Architecture Overview

πŸ“text
1Input: z ∈ ℝ²⁡⁢ (encoder output)
2            ↓
3β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
4β”‚  Linear(256, 64)                β”‚
5β”‚  256 Γ— 64 + 64 = 16,448 params  β”‚
6β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
7            ↓
8β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
9β”‚  ReLU                           β”‚
10β”‚  0 parameters                   β”‚
11β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
12            ↓
13β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
14β”‚  Dropout(p=0.3)                 β”‚
15β”‚  0 parameters                   β”‚
16β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
17            ↓
18β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
19β”‚  Linear(64, 3)                  β”‚
20β”‚  64 Γ— 3 + 3 = 195 params        β”‚
21β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
22            ↓
23Output: logits ∈ ℝ³ (class scores)
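The parameter counts in the diagram are easy to verify: each nn.Linear stores a weight matrix plus a bias vector, while ReLU and Dropout add nothing. A quick arithmetic check:

```python
def linear_params(in_dim, out_dim):
    """Parameters in a fully connected layer: weight matrix plus bias vector."""
    return in_dim * out_dim + out_dim

first = linear_params(256, 64)   # 16,448
second = linear_params(64, 3)    # 195
total = first + second           # 16,643 (ReLU and Dropout contribute none)
```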

PyTorch Implementation

Health Classification Head
🐍health_head.py
Class Definition

Health classification head inherits from nn.Module. Maps 256-dim encoder output to 3 class logits.

Constructor Parameters

Configure head dimensions: input from encoder (256), hidden layer (64), and output classes (3).

EXAMPLE
# Default configuration:
# input_dim=256  (from encoder)
# hidden_dim=64  (smaller than RUL head)
# num_classes=3  (Healthy, Degrading, Critical)
# dropout=0.3    (regularization)
Sequential Network

nn.Sequential chains layers together. Input flows through each layer in order.

EXAMPLE
# Data flow through Sequential:
# Input: z.shape = (32, 256)
#   ↓ Linear(256, 64)
# Shape: (32, 64)
#   ↓ ReLU()
# Shape: (32, 64)  # negative values → 0
#   ↓ Dropout(0.3)
# Shape: (32, 64)  # 30% zeroed during training
#   ↓ Linear(64, 3)
# Output: (32, 3)  # 3 class logits
First Linear Layer

Projects 256-dim input to 64-dim hidden space. Weight shape: (64, 256), Bias shape: (64).

EXAMPLE
# BEFORE: z.shape = (32, 256)
# Linear transformation: output = z @ W.T + b

nn.Linear(256, 64)
# Parameters: 256 × 64 + 64 = 16,448

# AFTER: shape = (32, 64)
ReLU Activation

ReLU(x) = max(0, x). Introduces non-linearity by zeroing negative values.

EXAMPLE
# BEFORE: tensor([-0.5, 0.3, -0.1, 0.8])
nn.ReLU()
# AFTER:  tensor([ 0.0, 0.3,  0.0, 0.8])

# Negative values become 0
# Positive values unchanged
Dropout Layer

During training, randomly zeros 30% of values. Prevents overfitting. Disabled during evaluation.

EXAMPLE
# Training mode (model.train()):
# BEFORE: tensor([0.5, 0.3, 0.7, 0.2, 0.8])
nn.Dropout(0.3)
# AFTER:  tensor([0.71, 0.0, 1.0, 0.29, 0.0])
#         ↑ scaled up      ↑ dropped

# Eval mode (model.eval()): No change
Output Linear Layer

Projects 64-dim hidden to 3 class logits. Each output corresponds to one health class.

EXAMPLE
# BEFORE: shape = (32, 64)
nn.Linear(64, 3)
# Parameters: 64 × 3 + 3 = 195

# AFTER: shape = (32, 3)
# logits = tensor([[2.1, 0.8, -0.5],  # Sample 1
#                  [0.2, 1.5,  0.3],  # Sample 2
#                  ...])
# Column 0: Healthy score
# Column 1: Degrading score
# Column 2: Critical score
Forward Method

Called when you do health_head(z). Returns raw logits, NOT probabilities.

EXAMPLE
# Usage:
z = encoder(x)        # shape: (32, 256)
logits = health_head(z)  # shape: (32, 3)

# Example output:
# logits[0] = [2.1, 0.8, -0.5]
# Highest score (2.1) = Healthy
# Prediction: Class 0 (Healthy)
import torch
import torch.nn as nn


class HealthHead(nn.Module):
    """
    Health Classification Head: 256 → 64 → 3

    Two-layer MLP that predicts health state class.
    Output is raw logits (softmax applied in loss function).
    """

    def __init__(
        self,
        input_dim: int = 256,
        hidden_dim: int = 64,
        num_classes: int = 3,
        dropout: float = 0.3
    ):
        super().__init__()

        self.head = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),    # 256 → 64
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)   # 64 → 3
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass.

        Args:
            z: Encoder output (batch, 256)

        Returns:
            Class logits (batch, 3)
        """
        return self.head(z)

Design Rationale

| Choice | Rationale |
|--------|-----------|
| Hidden dim 64 | Simpler task needs smaller head than RUL |
| Output dim 3 | Three health classes |
| No output activation | Softmax applied in CrossEntropyLoss |
| Smaller than RUL head | Classification is auxiliary, needs less capacity |

Softmax and Probability Output

The head outputs raw logits; softmax converts these to probabilities.

Logits to Probabilities

$$p_i = \frac{e^{z_i}}{\sum_{j=1}^{3} e^{z_j}}$$

Where:

  • $z_i$: Logit (raw score) for class $i$
  • $p_i$: Probability of class $i$
  • $\sum_i p_i = 1$: Probabilities sum to 1

Example Computation
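Plugging the sample logits [2.1, 0.8, -0.5] (used throughout this section) into the formula, in plain Python:

```python
import math

logits = [2.1, 0.8, -0.5]
exps = [math.exp(z) for z in logits]   # ≈ [8.17, 2.23, 0.61]
denom = sum(exps)                      # ≈ 11.0
probs = [e / denom for e in exps]      # ≈ [0.74, 0.20, 0.06], sums to 1
```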

Why Not Apply Softmax in the Head?

PyTorch's CrossEntropyLoss expects raw logits, not probabilities:

  • Numerical stability: CrossEntropyLoss uses log-sum-exp trick internally
  • Efficiency: Avoids computing softmax twice (once in head, once in loss)
  • Convention: Standard practice in PyTorch classification
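To see the numerical-stability point concretely, here is a sketch of log-softmax via the log-sum-exp trick (stable_log_softmax is an illustrative helper, not a PyTorch API): subtracting the max before exponentiating means even very large logits never overflow.

```python
import math

def stable_log_softmax(logits):
    """Log-softmax computed via the log-sum-exp trick: shift by the max first."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]

# A naive math.exp(1000.0) overflows a float; the shifted version is fine.
log_probs = stable_log_softmax([1000.0, 0.0, -1000.0])
```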
Training vs Inference
🐍inference.py
Get Raw Logits

Forward pass through health head. Output is raw scores, not probabilities.

EXAMPLE
# z.shape = (32, 256)  from encoder
logits = health_head(z)
# logits.shape = (32, 3)

# Example:
# logits[0] = tensor([2.1, 0.8, -0.5])
#             Healthy  Degrading  Critical
Cross-Entropy Loss

Computes loss between logits and true labels. Internally applies softmax + negative log likelihood.

EXAMPLE
# logits = tensor([[2.1, 0.8, -0.5],
#                   [0.2, 1.5,  0.3]])
# labels = tensor([0, 1])  # True classes

loss = F.cross_entropy(logits, labels)
# loss ≈ 0.38  (scalar)

# Internally computes:
# 1. softmax(logits) → probabilities
# 2. -log(prob of correct class)
# 3. average over batch
Softmax to Probabilities

Convert logits to probabilities that sum to 1. dim=-1 applies softmax over the last dimension (classes).

EXAMPLE
# logits[0] = tensor([2.1, 0.8, -0.5])

probs = F.softmax(logits, dim=-1)
# probs[0] = tensor([0.74, 0.20, 0.06])
#                    ↑ sum = 1.0

# Interpretation:
# 74% confident → Healthy
# 20% confident → Degrading
#  6% confident → Critical
Get Predicted Class

argmax returns the index of the maximum value. This is the predicted class.

EXAMPLE
# probs = tensor([[0.74, 0.20, 0.06],
#                  [0.15, 0.70, 0.15]])

predicted_class = probs.argmax(dim=-1)
# predicted_class = tensor([0, 1])

# Sample 0: max at index 0 → Healthy
# Sample 1: max at index 1 → Degrading
# Training: use raw logits
logits = health_head(z)                 # (batch, 3)
loss = F.cross_entropy(logits, labels)

# Inference: apply softmax for probabilities
probs = F.softmax(logits, dim=-1)       # (batch, 3)
predicted_class = probs.argmax(dim=-1)  # (batch,)
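As a sanity check, the three internal steps of cross-entropy can be reproduced in plain Python on the sample logits and labels above (cross_entropy_by_hand is an illustrative helper, not a PyTorch API):

```python
import math

def cross_entropy_by_hand(logit_batch, labels):
    """Mean negative log-probability of the true class for each sample."""
    total = 0.0
    for logits, label in zip(logit_batch, labels):
        m = max(logits)  # log-sum-exp trick for numerical stability
        lse = m + math.log(sum(math.exp(z - m) for z in logits))
        total += lse - logits[label]  # equals -log(softmax(logits)[label])
    return total / len(labels)

loss = cross_entropy_by_hand([[2.1, 0.8, -0.5], [0.2, 1.5, 0.3]], [0, 1])
# ≈ 0.376
```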

Handling Class Imbalance

The three health classes are not equally represented in the data.

Class Distribution

In a typical C-MAPSS training set, the distribution is approximately:

| Class | State | Approx. Fraction | Challenge |
|-------|-------|------------------|-----------|
| 0 | Healthy | ~45% | Most common (early life) |
| 1 | Degrading | ~30% | Moderate representation |
| 2 | Critical | ~25% | Less common but most important |

Imbalance Mitigation Strategies

We use focal loss (Chapter 11) to address class imbalance:

$$\mathcal{L}_{\text{focal}} = -\alpha_t (1 - p_t)^\gamma \log(p_t)$$

Where:

  • $\alpha_t$: Class weight for class $t$
  • $(1 - p_t)^\gamma$: Focusing factor (down-weights easy examples)
  • $\gamma = 2$: Typical focusing parameter

Focal Loss Intuition

Standard cross-entropy treats all misclassifications equally. Focal loss reduces the loss contribution from easy, well-classified examples (high $p_t$), focusing training on hard examples that are misclassified or borderline.
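That intuition can be checked numerically. A sketch of the per-example focal term with $\alpha_t = 1$ (the confidences 0.95 and 0.30 are illustrative values, not from the model):

```python
import math

def focal_term(p_t, alpha_t=1.0, gamma=2.0):
    """Focal loss for one example; reduces to cross-entropy when gamma = 0."""
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

easy = focal_term(0.95)  # well classified: loss is heavily down-weighted
hard = focal_term(0.30)  # misclassified: loss stays close to cross-entropy
# Cross-entropy equivalents: -log(0.95) ≈ 0.051, -log(0.30) ≈ 1.204
```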

Class Weights

Alternatively, class weights can balance the loss:

🐍python
import torch
import torch.nn.functional as F

# Inverse frequency weighting
class_counts = [4500, 3000, 2500]  # Example counts
total = sum(class_counts)
class_weights = torch.tensor([total / c for c in class_counts])
class_weights = class_weights / class_weights.sum() * 3  # Normalize to mean 1

# Use in loss
loss = F.cross_entropy(logits, labels, weight=class_weights)
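The weighting arithmetic itself needs no torch; a quick check that the normalization keeps the mean weight at 1 (so the overall loss scale is unchanged) while Critical receives the largest weight:

```python
counts = [4500, 3000, 2500]          # Healthy, Degrading, Critical
total = sum(counts)
raw = [total / c for c in counts]    # inverse frequency: ≈ [2.22, 3.33, 4.0]
scale = sum(raw)
weights = [w / scale * 3 for w in raw]
# Critical gets the largest weight; the three weights average to 1.0
```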

Critical Class Importance

The Critical class (Class 2) is the most important for maintenance decisions; missing a critical state has severe consequences. Our loss design ensures the model does not ignore this minority class despite its lower frequency.


Summary

In this section, we designed the health classification head:

  1. Three health states: Healthy (RUL > 125), Degrading (50 < RUL ≤ 125), Critical (RUL ≤ 50)
  2. Architecture: Two-layer MLP (256 β†’ 64 β†’ 3)
  3. Output: Raw logits (softmax in loss function)
  4. Class imbalance: Addressed via focal loss or class weights
  5. Parameters: ~17K
| Property | Value |
|----------|-------|
| Input dimension | 256 |
| Hidden dimension | 64 |
| Output dimension | 3 (classes) |
| Class 0 (Healthy) | RUL > 125 |
| Class 1 (Degrading) | 50 < RUL ≤ 125 |
| Class 2 (Critical) | RUL ≤ 50 |
| Total parameters | 16,643 |
Looking Ahead: We have designed both prediction heads. The next section assembles the complete model, connecting encoder and heads and showing the full forward pass from raw sensors to dual predictions.

With both heads designed, we now assemble the complete AMNL model.