Learning Objectives
By the end of this section, you will:
- Understand health state class imbalance in C-MAPSS data
- Derive focal loss and its focusing mechanism
- Compare focal loss with class weighting
- Tune the focusing parameter γ for optimal performance
- Implement focal cross-entropy in PyTorch
Why This Matters: Health classification suffers from class imbalance—most samples are healthy or mildly degraded, while critical samples are rare. Standard cross-entropy is dominated by the majority class. Focal loss down-weights easy (well-classified) examples to focus learning on hard (minority) cases.
Health State Class Imbalance
The health classification task has inherent class imbalance.
Class Distribution
Recall the health class definitions:
| Class | Name | RUL Range | Typical Distribution |
|---|---|---|---|
| 0 | Healthy | RUL > 125 | ~35-40% |
| 1 | Degrading | 50 < RUL ≤ 125 | ~35-40% |
| 2 | Critical | RUL ≤ 50 | ~20-30% |
Dataset-Specific Distributions
| Dataset | Healthy (%) | Degrading (%) | Critical (%) |
|---|---|---|---|
| FD001 | 38.2 | 36.5 | 25.3 |
| FD002 | 35.8 | 38.1 | 26.1 |
| FD003 | 40.1 | 35.2 | 24.7 |
| FD004 | 36.4 | 37.8 | 25.8 |
While not severely imbalanced, the critical class (most important for maintenance decisions) is underrepresented.
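The RUL thresholds in the table above map directly to class labels. A minimal sketch (`rul_to_health_class` is an illustrative helper, not part of any library):

```python
import numpy as np

def rul_to_health_class(rul: np.ndarray) -> np.ndarray:
    """Map RUL values to health classes per the thresholds above.

    0 = Healthy (RUL > 125), 1 = Degrading (50 < RUL <= 125),
    2 = Critical (RUL <= 50).
    """
    classes = np.zeros_like(rul, dtype=np.int64)   # default: healthy
    classes[(rul <= 125) & (rul > 50)] = 1         # degrading band
    classes[rul <= 50] = 2                         # critical band
    return classes

rul = np.array([180, 125, 100, 50, 10])
print(rul_to_health_class(rul))  # [0 1 1 2 2]
```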
Impact on Standard Cross-Entropy
Because standard cross-entropy weights every sample equally, the total loss (and gradient) is dominated by the abundant healthy and degrading classes; the critical class, which matters most for maintenance decisions, receives the least learning signal.
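A back-of-the-envelope check with the approximate FD001 counts from the table above makes this concrete: when per-sample losses are comparable across classes, each class's share of the total loss simply tracks its frequency.

```python
import numpy as np

# Approximate FD001 class counts (healthy, degrading, critical)
counts = np.array([38200, 36500, 25300])

# With comparable per-sample loss, each class's share of the total
# cross-entropy (and hence of the gradient) tracks its frequency
share = counts / counts.sum()
print(share)            # [0.382 0.365 0.253]
print(share[:2].sum())  # ≈ 0.75 of the signal comes from non-critical classes
```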
Focal Loss Theory
Focal loss (Lin et al., 2017) addresses class imbalance by down-weighting easy examples.
Standard Cross-Entropy

CE(p_t) = -log(p_t)

Where p_t is the predicted probability for the correct class y.
Focal Loss Formulation

FL(p_t) = -(1 - p_t)^γ · log(p_t)

Where:
- γ ≥ 0: Focusing parameter (γ = 0 recovers standard cross-entropy)
- (1 - p_t)^γ: Modulating factor
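Plugging a couple of values into the modulating factor shows how strongly γ suppresses confident predictions (a small illustrative check):

```python
import torch

# Modulating factor for an easy (p_t = 0.95) vs. a hard (p_t = 0.2) example
p_t = torch.tensor([0.95, 0.2])
for gamma in [0.0, 1.0, 2.0]:
    weight = (1 - p_t) ** gamma
    ratio = (weight[1] / weight[0]).item()  # hard-to-easy weight ratio
    print(f"gamma={gamma}: ratio={ratio:.0f}x")
# gamma=0 treats both equally; gamma=2 weights the hard example ~256x more
```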
How the Modulating Factor Works
For well-classified examples (p_t → 1), the modulating factor approaches 0 and the loss contribution vanishes; for misclassified examples (p_t small), the factor stays near 1 and the loss is nearly unchanged.
Modulating Factor Visualization

```
Modulating Factor (1 - p)^γ vs. Confidence p

Factor
1.0 ─┤●
     │ ╲
0.8 ─┤  ╲   γ=1
     │   ╲    ╲
0.6 ─┤    ╲     ╲
     │     ╲      ╲   γ=2
0.4 ─┤      ╲       ╲
     │       ╲        ╲   γ=3
0.2 ─┤        ╲___      ╲
     │            ╲╲╲    ╲
0.0 ─┼─────────────────────●
     └───┬───┬───┬───┬───┬──
       0.0 0.2 0.4 0.6 0.8 1.0
              Confidence (p)
```

Higher γ → stronger down-weighting of easy examples.

Class Weighting Strategies
Focal loss can be combined with class balancing weights.
Inverse Frequency Weighting

w_c = N / (C · n_c)

Where:
- N: Total number of samples
- C: Number of classes
- n_c: Number of samples in class c
Computing Class Weights for C-MAPSS

```python
# Example: FD001 class distribution
class_counts = {0: 38200, 1: 36500, 2: 25300}  # Approximate
total = sum(class_counts.values())
n_classes = 3

# Inverse frequency weights: w_c = total / (n_classes * count)
weights = {
    c: total / (n_classes * count)
    for c, count in class_counts.items()
}

# Result: {0: 0.87, 1: 0.91, 2: 1.32}
# Critical class gets ~1.5× weight vs. healthy
```

Combined Focal + Class Weights
FL(p_t) = -α_t (1 - p_t)^γ · log(p_t)

This combines the focusing mechanism (down-weight easy examples) with class balancing (boost minority classes), where α_t is the weight assigned to the target class.
| Strategy | Effect | When to Use |
|---|---|---|
| Focal only (α=1) | Down-weight easy examples | Moderate imbalance |
| Class weights only (γ=0) | Up-weight minority classes | Severe imbalance |
| Combined | Both effects | Imbalance + easy/hard split |
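To see how the two effects compound, compare the effective per-sample weight α_t · (1 - p_t)^γ for a confidently classified healthy sample against an uncertain critical one (the α values are the inverse-frequency weights computed above for FD001):

```python
import torch

# Inverse-frequency weights for FD001 (healthy, degrading, critical)
alpha = torch.tensor([0.87, 0.91, 1.32])
gamma = 2.0

# Effective weight alpha_t * (1 - p_t)^gamma for a confident healthy
# prediction vs. an uncertain critical one
for cls, p_t in [(0, 0.95), (2, 0.40)]:
    w = alpha[cls] * (1 - p_t) ** gamma
    print(f"class {cls}, p_t={p_t}: effective weight {w:.4f}")
# ~0.0022 vs. ~0.4752: roughly a 220x difference in gradient contribution
```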
Implementation
Complete PyTorch implementation of focal cross-entropy.
Basic Focal Loss
```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """
    Focal Loss for multi-class classification.

    Reduces the contribution of easy examples to focus learning
    on hard, misclassified examples.

    Args:
        gamma: Focusing parameter (default 2.0)
        alpha: Optional class weights (default None = uniform)
        reduction: 'mean', 'sum', or 'none'
    """

    def __init__(
        self,
        gamma: float = 2.0,
        alpha: Optional[torch.Tensor] = None,
        reduction: str = "mean"
    ):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute focal loss.

        Args:
            logits: Raw predictions, shape (batch, num_classes)
            targets: Class indices, shape (batch,)

        Returns:
            Focal loss value
        """
        # Compute softmax probabilities
        probs = F.softmax(logits, dim=1)

        # Get probability of correct class for each sample
        # Shape: (batch,)
        p_t = probs.gather(1, targets.unsqueeze(1)).squeeze(1)

        # Compute focal weight: (1 - p_t)^gamma
        focal_weight = (1 - p_t) ** self.gamma

        # Compute cross-entropy: -log(p_t), with eps for stability
        ce_loss = -torch.log(p_t + 1e-8)

        # Apply focal weight
        focal_loss = focal_weight * ce_loss

        # Apply class weights if provided
        if self.alpha is not None:
            alpha_t = self.alpha.to(logits.device).gather(0, targets)
            focal_loss = alpha_t * focal_loss

        # Reduction
        if self.reduction == "mean":
            return focal_loss.mean()
        elif self.reduction == "sum":
            return focal_loss.sum()
        else:
            return focal_loss
```

With Automatic Class Weight Computation
```python
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class HealthFocalLoss(nn.Module):
    """
    Focal loss specialized for health classification.

    Automatically computes class weights from training data
    and applies focal modulation.

    Args:
        gamma: Focusing parameter
        class_counts: Number of samples per class
            [healthy, degrading, critical]
    """

    def __init__(
        self,
        gamma: float = 2.0,
        class_counts: Optional[List[int]] = None
    ):
        super().__init__()
        self.gamma = gamma

        # Compute inverse-frequency class weights if counts provided
        if class_counts is not None:
            total = sum(class_counts)
            n_classes = len(class_counts)
            weights = [
                total / (n_classes * count)
                for count in class_counts
            ]
            self.register_buffer(
                "alpha",
                torch.tensor(weights, dtype=torch.float32)
            )
        else:
            # Default weights for typical C-MAPSS distribution
            self.register_buffer(
                "alpha",
                torch.tensor([1.0, 1.1, 1.4], dtype=torch.float32)
            )

    def forward(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor
    ) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        p_t = probs.gather(1, targets.unsqueeze(1)).squeeze(1)

        # Focal weight
        focal_weight = (1 - p_t) ** self.gamma

        # Cross-entropy
        ce_loss = -torch.log(p_t + 1e-8)

        # Class weight for each target
        alpha_t = self.alpha.gather(0, targets)

        # Combined loss
        loss = alpha_t * focal_weight * ce_loss

        return loss.mean()
```

Choosing γ
| γ Value | Behavior | Recommendation |
|---|---|---|
| 0 | Standard cross-entropy | Baseline |
| 0.5 | Mild focusing | Light imbalance |
| 1.0 | Moderate focusing | Moderate imbalance |
| 2.0 | Standard focal loss | General use (recommended) |
| 3.0 | Strong focusing | Severe imbalance |
| 5.0 | Very strong focusing | Extreme cases only |
Hyperparameter Selection
Start with γ = 2.0 (the original paper value). If critical class recall is too low, increase γ. If training becomes unstable, decrease γ or add class weights.
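A quick way to see the effect of γ is a standalone functional sketch (`focal_ce` is illustrative, not part of the classes above; it uses `log_softmax` for numerical stability). Note that the raw loss value shrinks as γ grows, so γ should be compared on critical-class recall, not on loss magnitude:

```python
import torch
import torch.nn.functional as F

def focal_ce(logits, targets, gamma):
    """Functional focal cross-entropy: mean of (1 - p_t)^gamma * -log(p_t)."""
    log_p_t = F.log_softmax(logits, dim=1).gather(
        1, targets.unsqueeze(1)
    ).squeeze(1)
    p_t = log_p_t.exp()
    return ((1 - p_t) ** gamma * -log_p_t).mean()

torch.manual_seed(0)
logits = torch.randn(256, 3)
targets = torch.randint(0, 3, (256,))

# gamma=0 recovers plain cross-entropy; larger gamma down-weights
# well-classified samples, lowering the aggregate loss value
for gamma in [0.0, 0.5, 1.0, 2.0, 3.0]:
    print(f"gamma={gamma}: loss={focal_ce(logits, targets, gamma).item():.3f}")
```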
Summary
In this section, we covered focal loss for health classification:
- Problem: Critical class underrepresented (20-25%)
- Focal loss: Down-weights easy examples with the modulating factor (1 - p_t)^γ
- Effect: Shifts learning focus to hard/minority cases
- Combined: Focal + class weights for dual effect
- Default: γ = 2.0 with inverse-frequency class weights
| Component | Value |
|---|---|
| Focusing parameter γ | 2.0 |
| Healthy class weight | ~1.0 |
| Degrading class weight | ~1.1 |
| Critical class weight | ~1.4 |
Looking Ahead: We have specialized losses for both RUL (weighted MSE, asymmetric) and health (focal). The next section shows how to combine multiple RUL loss components into a unified objective.