Chapter 10
Section 49 of 104

RUL Loss Component: Weighted MSE

AMNL: The Novel Loss Function

Learning Objectives

By the end of this section, you will:

  1. Understand why standard MSE is suboptimal for RUL
  2. Design the weight function that emphasizes critical samples
  3. Derive the weighted MSE formula
  4. Analyze gradient behavior with sample weighting
  5. Implement weighted MSE in PyTorch
Why This Matters: Not all RUL prediction errors are equal. An error of 10 cycles when true RUL is 100 is less critical than an error of 10 cycles when true RUL is 20. Weighted MSE ensures the model focuses on the samples that matter most—those approaching failure.

Why Weighted MSE?

Standard MSE treats all samples equally, but RUL prediction has asymmetric importance.

The Problem with Standard MSE

\mathcal{L}_{\text{MSE}} = \frac{1}{N}\sum_{i=1}^{N} (y_i - \hat{y}_i)^2

Standard MSE weights all prediction errors equally, regardless of the true RUL value.

Real-World Importance

| True RUL | Prediction Error | Consequence |
|---|---|---|
| 100 cycles | ±10 cycles | Minor schedule adjustment |
| 50 cycles | ±10 cycles | Significant planning impact |
| 20 cycles | ±10 cycles | Critical safety concern |
| 5 cycles | ±10 cycles | Potential failure miss or false alarm |

Errors at low RUL have far greater practical consequences than errors at high RUL.

The Solution

Weight samples inversely to their RUL: low-RUL samples get higher weights.

\mathcal{L}_{\text{weighted}} = \frac{1}{\sum_i w_i}\sum_{i=1}^{N} w_i (y_i - \hat{y}_i)^2
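To make the difference concrete, here is a small illustrative comparison of the two losses on a toy batch (the values are chosen for illustration; R_max = 125 as used throughout this chapter):

```python
import torch

target = torch.tensor([100.0, 50.0, 20.0, 125.0])  # true RUL
pred = torch.tensor([95.0, 55.0, 30.0, 120.0])     # predictions

# Standard MSE: every error counts equally
mse = torch.mean((pred - target) ** 2)

# Weighted MSE: w = 2 - y / R_max, normalized by the weight sum
weights = 2.0 - target / 125.0
wmse = torch.sum(weights * (pred - target) ** 2) / torch.sum(weights)

print(mse.item())             # 43.75
print(round(wmse.item(), 2))  # 49.47
```

The weighted loss is larger here because the biggest error (10 cycles at true RUL 20) falls on the highest-weighted sample.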

The Weight Function

We use a linear decay weight function based on true RUL.

Linear Decay Weights

w_i = 1 + \frac{R_{\max} - y_i}{R_{\max}} = 2 - \frac{y_i}{R_{\max}}

Where:

  • y_i: True RUL for sample i
  • R_max = 125: Maximum RUL (cap)
  • w_i ∈ [1, 2]: Weight range

Weight Distribution

Weight vs. True RUL:

Weight
  2.0 ─┤●
      │ ╲
  1.8 ─┤  ╲
      │   ╲
  1.6 ─┤    ╲
      │     ╲
  1.4 ─┤      ╲
      │       ╲
  1.2 ─┤        ╲
      │         ╲
  1.0 ─┤          ╲────────────●
      └──┬──┬──┬──┬──┬──┬──┬──┬
         0  15 30 45 60 75 90 105 125
                  True RUL

Key points:
  RUL = 0:     w = 2.0 (maximum weight)
  RUL = 62.5:  w = 1.5 (midpoint)
  RUL = 125:   w = 1.0 (minimum weight)
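These key points can be checked directly; a quick sketch of the weight function in plain Python, with R_max = 125:

```python
def weight(rul: float, r_max: float = 125.0) -> float:
    """Linear decay weight: w = 2 - min(rul, r_max) / r_max."""
    return 2.0 - min(rul, r_max) / r_max

print(weight(0.0))    # 2.0  -> maximum weight at failure
print(weight(62.5))   # 1.5  -> midpoint
print(weight(125.0))  # 1.0  -> minimum weight at the cap
print(weight(200.0))  # 1.0  -> RUL above the cap is clamped
```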

Why Linear Decay?

Linear decay keeps the weights bounded in [1, 2], so the emphasis on critical samples is capped at 2×: near-failure samples consistently receive a stronger learning signal, but no single sample can dominate a batch. It is also simple, requiring no hyperparameters beyond R_max, whereas steeper schemes (e.g., exponential decay) introduce a decay rate to tune and can destabilize training by concentrating the gradient on a few near-failure samples.
Mathematical Formulation

The complete weighted MSE loss for RUL prediction.

Full Formula

\mathcal{L}_{\text{RUL}} = \frac{1}{\sum_{i=1}^{N} w_i}\sum_{i=1}^{N} w_i (y_i - \hat{y}_i)^2

With weight function:

w_i = 2 - \frac{\min(y_i, R_{\max})}{R_{\max}}

RUL Capping

We use min(y_i, R_max) to handle samples with RUL > 125. These samples are in the healthy phase, where the exact RUL is less meaningful, so we cap them at 125.

Implementation

Weighted MSE Loss
losses/weighted_mse.py
Clamp Target RUL

Cap target values at r_max (125). Samples with RUL > 125 are treated as RUL=125 for weight calculation, preventing negative weights.

EXAMPLE
# BEFORE clamp: target with outliers
target = tensor([100.0, 50.0, 20.0, 150.0, 200.0])

# AFTER torch.clamp(target, max=125.0):
capped_target = tensor([100.0, 50.0, 20.0, 125.0, 125.0])
#                                          ↑       ↑
#                                       capped at 125
Linear Decay Weights

Compute weights using formula w = 2 - y/R_max. Low RUL gets weight ≈ 2, high RUL gets weight ≈ 1.

EXAMPLE
# With capped_target = [100.0, 50.0, 20.0, 125.0]
weights = 2.0 - capped_target / 125.0

# Step by step:
# 2.0 - 100/125 = 2.0 - 0.80 = 1.20
# 2.0 - 50/125  = 2.0 - 0.40 = 1.60
# 2.0 - 20/125  = 2.0 - 0.16 = 1.84
# 2.0 - 125/125 = 2.0 - 1.00 = 1.00

weights = tensor([1.20, 1.60, 1.84, 1.00])
# Critical samples (low RUL) → higher weight
Squared Errors

Element-wise squared difference between predictions and targets. Standard MSE component before weighting.

EXAMPLE
# Predictions and targets:
pred   = tensor([95.0, 55.0, 30.0, 120.0])
target = tensor([100.0, 50.0, 20.0, 125.0])

# Squared errors:
squared_errors = (pred - target) ** 2
               = ([-5, 5, 10, -5]) ** 2
               = tensor([25.0, 25.0, 100.0, 25.0])
Apply Weights

Multiply squared errors by sample weights. Critical samples (high weight) contribute more to total loss.

EXAMPLE
# Weights and squared errors:
weights = tensor([1.20, 1.60, 1.84, 1.00])
squared_errors = tensor([25.0, 25.0, 100.0, 25.0])

# Weighted errors:
weighted_errors = weights * squared_errors
                = tensor([30.0, 40.0, 184.0, 25.0])
#                              ↑
#                    Low RUL sample dominates!
Normalize by Weight Sum

Divide by sum of weights (not count) for proper weighted average. This prevents weight magnitude from affecting loss scale.

EXAMPLE
# Weighted errors and weights:
weighted_errors = tensor([30.0, 40.0, 184.0, 25.0])
weights = tensor([1.20, 1.60, 1.84, 1.00])

# Normalized loss:
loss = weighted_errors.sum() / weights.sum()
     = 279.0 / 5.64
     ≈ 49.47  # scalar

# Compare to simple mean: 279.0 / 4 = 69.75
# Weight normalization adjusts for weight magnitudes
import torch


def weighted_mse_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    r_max: float = 125.0
) -> torch.Tensor:
    """
    Weighted MSE loss with linear decay weights.

    Args:
        pred: Predicted RUL (batch,)
        target: True RUL (batch,)
        r_max: Maximum RUL for weight computation

    Returns:
        Weighted MSE loss (scalar)
    """
    # Compute weights: w = 2 - y/R_max
    capped_target = torch.clamp(target, max=r_max)
    weights = 2.0 - capped_target / r_max

    # Weighted squared errors
    squared_errors = (pred - target) ** 2
    weighted_errors = weights * squared_errors

    # Normalize by sum of weights
    loss = weighted_errors.sum() / weights.sum()

    return loss

Numerical Example
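The hand-worked steps above can be reproduced end to end; this is a minimal sketch using the same illustrative batch, with R_max = 125:

```python
import torch

# Same illustrative batch as the step-by-step examples above
pred = torch.tensor([95.0, 55.0, 30.0, 120.0])
target = torch.tensor([100.0, 50.0, 20.0, 125.0])

capped = torch.clamp(target, max=125.0)  # cap RUL at R_max
weights = 2.0 - capped / 125.0           # linear decay weights: [1.20, 1.60, 1.84, 1.00]
squared = (pred - target) ** 2           # squared errors: [25, 25, 100, 25]
weighted = weights * squared             # weighted errors: [30, 40, 184, 25]
loss = weighted.sum() / weights.sum()    # 279.0 / 5.64

print(round(loss.item(), 2))  # 49.47
```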


Gradient Analysis

Understanding the gradient helps us see how weighting affects learning.

Gradient Derivation

The gradient of weighted MSE with respect to prediction:

\frac{\partial \mathcal{L}_{\text{RUL}}}{\partial \hat{y}_i} = \frac{2 w_i (\hat{y}_i - y_i)}{\sum_j w_j}

Key observation:

  • Gradient magnitude is proportional to the weight w_i
  • Low-RUL samples (high weight) produce larger gradients
  • The model receives a stronger learning signal from critical cases
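This gradient expression can be verified numerically by comparing autograd's result against the closed-form formula; a small sketch, assuming R_max = 125:

```python
import torch

pred = torch.tensor([95.0, 55.0, 30.0, 120.0], requires_grad=True)
target = torch.tensor([100.0, 50.0, 20.0, 125.0])

# Weighted MSE as defined above
weights = 2.0 - torch.clamp(target, max=125.0) / 125.0
loss = (weights * (pred - target) ** 2).sum() / weights.sum()
loss.backward()

# Analytic gradient: 2 * w_i * (pred_i - y_i) / sum_j w_j
analytic = 2.0 * weights * (pred.detach() - target) / weights.sum()

print(torch.allclose(pred.grad, analytic))  # True
```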

Gradient Comparison

For the same prediction error (e.g., 10 cycles off):

True RULWeightRelative Gradient
1251.001.0×
1001.201.2×
501.601.6×
251.801.8×
02.002.0×

Low-RUL samples contribute up to 2× larger gradients, ensuring the model prioritizes learning to predict accurately near failure.


Summary

In this section, we designed the weighted MSE loss for RUL:

  1. Motivation: Low-RUL errors are more critical than high-RUL errors
  2. Weight function: w = 2 - y/R_max
  3. Weight range: [1.0, 2.0] (maximum 2× emphasis)
  4. Effect: Larger gradients for critical samples
  5. Normalization: Divide by sum of weights

| Parameter | Value |
|---|---|
| R_max | 125 cycles |
| Weight at RUL=0 | 2.0 |
| Weight at RUL=125 | 1.0 |
| Gradient amplification | Up to 2× |
Looking Ahead: With the RUL loss component defined, the next section describes the health classification loss—the cross-entropy component that provides regularization for the RUL prediction task.

Having defined the RUL loss, we now examine the health classification component.