Chapter 11

Focal Loss for Imbalanced Health States

Advanced Loss Components

Learning Objectives

By the end of this section, you will:

  1. Understand health state class imbalance in C-MAPSS data
  2. Derive focal loss and its focusing mechanism
  3. Compare focal loss with class weighting
  4. Tune the focusing parameter γ for optimal performance
  5. Implement focal cross-entropy in PyTorch
Why This Matters: Health classification suffers from class imbalance—most samples are healthy or mildly degraded, while critical samples are rare. Standard cross-entropy is dominated by the majority class. Focal loss down-weights easy (well-classified) examples to focus learning on hard (minority) cases.

Health State Class Imbalance

The health classification task has inherent class imbalance.

Class Distribution

Recall the health class definitions:

| Class | Name | RUL Range | Typical Distribution |
|---|---|---|---|
| 0 | Healthy | RUL > 125 | ~35-40% |
| 1 | Degrading | 50 < RUL ≤ 125 | ~35-40% |
| 2 | Critical | RUL ≤ 50 | ~20-30% |

Dataset-Specific Distributions

| Dataset | Healthy (%) | Degrading (%) | Critical (%) |
|---|---|---|---|
| FD001 | 38.2 | 36.5 | 25.3 |
| FD002 | 35.8 | 38.1 | 26.1 |
| FD003 | 40.1 | 35.2 | 24.7 |
| FD004 | 36.4 | 37.8 | 25.8 |

While the imbalance is not severe, the critical class (the one that matters most for maintenance decisions) is underrepresented.

Impact on Standard Cross-Entropy

With standard cross-entropy, every sample contributes to the loss in proportion to how poorly it is classified, not how important it is. Because well-classified healthy and degrading samples are far more numerous, their accumulated loss and gradient can swamp the signal from the rare, hard critical samples.

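To make this concrete, here is a small sketch with hypothetical per-sample probabilities (0.9 for easy samples, 0.4 for hard ones; the counts are illustrative, not from C-MAPSS): the aggregate cross-entropy from abundant, already well-classified samples rivals that of the rare hard ones.

```python
import math

# Hypothetical batch: 900 well-classified samples (p_t = 0.9)
# and 100 hard, mostly-critical samples (p_t = 0.4)
ce_easy = -math.log(0.9)   # ≈ 0.105 per sample
ce_hard = -math.log(0.4)   # ≈ 0.916 per sample

total_easy = 900 * ce_easy  # ≈ 94.8
total_hard = 100 * ce_hard  # ≈ 91.6

# The 900 easy samples contribute as much total loss (and gradient)
# as the 100 hard ones, so training keeps polishing what it already knows.
print(total_easy, total_hard)
```
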
Focal Loss Theory

Focal loss (Lin et al., 2017) addresses class imbalance by down-weighting easy examples.

Standard Cross-Entropy

$$\text{CE}(p, y) = -\log(p_y)$$

where $p_y$ is the predicted probability for the correct class $y$.

Focal Loss Formulation

$$\text{FL}(p, y) = -(1 - p_y)^\gamma \log(p_y)$$

Where:

  • $\gamma \geq 0$: Focusing parameter
  • $(1 - p_y)^\gamma$: Modulating factor

How the Modulating Factor Works

Modulating Factor Visualization

```text
Modulating Factor (1 - p)^γ vs. Confidence p:

Factor
  1.0 ─┤●
      │ ╲
  0.8 ─┤  ╲ γ=1
      │   ╲  ╲
  0.6 ─┤    ╲  ╲
      │     ╲  ╲ γ=2
  0.4 ─┤      ╲  ╲
      │       ╲  ╲ γ=3
  0.2 ─┤        ╲__╲
      │           ╲╲
  0.0 ─┼──────────────●
      └──┬──┬──┬──┬──┬──
       0.0 0.2 0.4 0.6 0.8 1.0
              Confidence (p)

Higher γ → Stronger down-weighting of easy examples
```

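The curve above can be checked numerically. A quick sketch (with illustrative probabilities 0.9 for an easy prediction and 0.4 for a hard one) shows how fast the easy example's weight collapses as γ grows:

```python
# Down-weighting of an easy (p_t = 0.9) vs. a hard (p_t = 0.4) prediction
weights = {}
for gamma in [0.5, 1.0, 2.0, 3.0]:
    easy = (1 - 0.9) ** gamma  # modulating factor for the easy example
    hard = (1 - 0.4) ** gamma  # modulating factor for the hard example
    weights[gamma] = (easy, hard)
    print(f"γ={gamma}: easy {easy:.4f}, hard {hard:.4f}, "
          f"ratio {hard / easy:.0f}x")
```

At γ = 2, the hard example already carries roughly 36× the weight of the easy one, and the gap widens with every increase in γ.
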
Class Weighting Strategies

Focal loss can be combined with class balancing weights.

Inverse Frequency Weighting

$$\alpha_c = \frac{N}{C \cdot N_c}$$

Where:

  • $N$: Total number of samples
  • $C$: Number of classes
  • $N_c$: Number of samples in class $c$

Computing Class Weights for C-MAPSS

```python
# Example: FD001 class distribution
class_counts = {0: 38200, 1: 36500, 2: 25300}  # Approximate
total = sum(class_counts.values())
n_classes = 3

# Inverse frequency weights
weights = {
    c: total / (n_classes * count)
    for c, count in class_counts.items()
}

# Result: {0: 0.87, 1: 0.91, 2: 1.32}
# Critical class gets ~1.5× weight vs. healthy
```

Combined Focal + Class Weights

$$\text{FL}_{\alpha}(p, y) = -\alpha_y (1 - p_y)^\gamma \log(p_y)$$

This combines the focusing mechanism (reduce easy examples) with class balancing (boost minority classes).
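The dual effect is easiest to see on two hypothetical samples: a confidently classified healthy one ($p_t = 0.95$) and a misclassified critical one ($p_t = 0.40$), using the FD001 inverse-frequency weights computed earlier (these probabilities are illustrative assumptions):

```python
import math

# Hypothetical probabilities; α values from the FD001 weight example
gamma = 2.0
alpha_healthy, alpha_critical = 0.87, 1.32

def alpha_focal(alpha, p):
    """Per-sample α-weighted focal loss: -α (1 - p)^γ log(p)."""
    return alpha * (1 - p) ** gamma * -math.log(p)

easy_healthy = alpha_focal(alpha_healthy, 0.95)    # confident majority sample
hard_critical = alpha_focal(alpha_critical, 0.40)  # misclassified minority sample
print(easy_healthy, hard_critical)
```

The hard critical sample ends up contributing several thousand times more loss than the easy healthy one, whereas under plain cross-entropy the ratio would be under 20×.
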

| Strategy | Effect | When to Use |
|---|---|---|
| Focal only (α=1) | Down-weight easy examples | Moderate imbalance |
| Class weights only (γ=0) | Up-weight minority classes | Severe imbalance |
| Combined | Both effects | Imbalance + easy/hard split |

Implementation

Complete PyTorch implementation of focal cross-entropy.

Basic Focal Loss

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """
    Focal Loss for multi-class classification.

    Reduces the contribution of easy examples to focus learning
    on hard, misclassified examples.

    Args:
        gamma: Focusing parameter (default 2.0)
        alpha: Optional class weights (default None = uniform)
        reduction: 'mean', 'sum', or 'none'
    """

    def __init__(
        self,
        gamma: float = 2.0,
        alpha: Optional[torch.Tensor] = None,
        reduction: str = "mean",
    ):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute focal loss.

        Args:
            logits: Raw predictions, shape (batch, num_classes)
            targets: Class indices, shape (batch,)

        Returns:
            Focal loss value
        """
        # Log-probabilities via log_softmax (numerically stable,
        # avoids the need for an epsilon inside the log)
        log_probs = F.log_softmax(logits, dim=1)

        # Log-probability of the correct class for each sample
        # Shape: (batch,)
        log_p_t = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        p_t = log_p_t.exp()

        # Focal weight (1 - p_t)^gamma applied to cross-entropy -log(p_t)
        focal_loss = (1 - p_t) ** self.gamma * -log_p_t

        # Apply class weights if provided
        if self.alpha is not None:
            alpha_t = self.alpha.to(logits.device).gather(0, targets)
            focal_loss = alpha_t * focal_loss

        # Reduction
        if self.reduction == "mean":
            return focal_loss.mean()
        elif self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss
```

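As a quick sanity check (a sketch not in the original text, computing the focal formula inline rather than via the class above): with γ = 0 and no class weights, the modulating factor is 1 everywhere, so focal loss should reduce exactly to standard cross-entropy.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 3)
targets = torch.randint(0, 3, (8,))

# Inline focal computation with gamma = 0
log_probs = F.log_softmax(logits, dim=1)
log_p_t = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
p_t = log_p_t.exp()
gamma = 0.0
focal = ((1 - p_t) ** gamma * -log_p_t).mean()

# Should match PyTorch's built-in cross-entropy
ce = F.cross_entropy(logits, targets)
assert torch.allclose(focal, ce, atol=1e-6)
```
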
With Automatic Class Weight Computation

```python
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class HealthFocalLoss(nn.Module):
    """
    Focal loss specialized for health classification.

    Automatically computes class weights from training data
    and applies focal modulation.

    Args:
        gamma: Focusing parameter
        class_counts: Number of samples per class
            [healthy, degrading, critical]
    """

    def __init__(
        self,
        gamma: float = 2.0,
        class_counts: Optional[List[int]] = None,
    ):
        super().__init__()
        self.gamma = gamma

        # Compute inverse-frequency class weights if counts provided
        if class_counts is not None:
            total = sum(class_counts)
            n_classes = len(class_counts)
            weights = [
                total / (n_classes * count)
                for count in class_counts
            ]
            self.register_buffer(
                "alpha",
                torch.tensor(weights, dtype=torch.float32),
            )
        else:
            # Default weights for a typical C-MAPSS distribution
            self.register_buffer(
                "alpha",
                torch.tensor([1.0, 1.1, 1.4], dtype=torch.float32),
            )

    def forward(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        # Log-probability of the correct class, shape (batch,)
        log_probs = F.log_softmax(logits, dim=1)
        log_p_t = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        p_t = log_p_t.exp()

        # Focal weight × class weight × cross-entropy
        focal_weight = (1 - p_t) ** self.gamma
        alpha_t = self.alpha.gather(0, targets)
        loss = alpha_t * focal_weight * -log_p_t

        return loss.mean()
```

Choosing γ

| γ Value | Behavior | Recommendation |
|---|---|---|
| 0 | Standard cross-entropy | Baseline |
| 0.5 | Mild focusing | Light imbalance |
| 1.0 | Moderate focusing | Moderate imbalance |
| 2.0 | Standard focal loss | General use (recommended) |
| 3.0 | Strong focusing | Severe imbalance |
| 5.0 | Very strong focusing | Extreme cases only |

Hyperparameter Selection

Start with γ = 2.0 (the original paper value). If critical class recall is too low, increase γ. If training becomes unstable, decrease γ or add class weights.


Summary

In this section, we covered focal loss for health classification:

  1. Problem: Critical class underrepresented (20-25%)
  2. Focal loss: Down-weights easy examples with $(1 - p)^\gamma$
  3. Effect: Shifts learning focus to hard/minority cases
  4. Combined: Focal + class weights for dual effect
  5. Default: γ = 2.0 with inverse-frequency class weights

| Component | Value |
|---|---|
| Focusing parameter γ | 2.0 |
| Healthy class weight | ~1.0 |
| Degrading class weight | ~1.1 |
| Critical class weight | ~1.4 |

Looking Ahead: We have specialized losses for both RUL (weighted MSE, asymmetric) and health (focal). The next section shows how to combine multiple RUL loss components into a unified objective.

With focal loss in place, we turn next to combining multiple loss components for RUL.