Chapter 9

Uncertainty Weighting (Kendall et al.)

Learning Objectives

By the end of this section, you will:

  1. Understand homoscedastic uncertainty in multi-task learning
  2. Derive task weights from maximum likelihood
  3. Implement uncertainty weighting with learnable parameters
  4. Recognize the connection between uncertainty and loss scale
  5. Identify limitations of this approach
Why This Matters: Kendall et al. (2018) introduced a principled approach to multi-task weighting based on task uncertainty. Instead of manually tuning weights, the model learns them by estimating how "noisy" each task is. This was a major advance over fixed weights and influenced many subsequent methods.

Uncertainty-Based Motivation

The key insight is that task weights should relate to task uncertainty.

Types of Uncertainty

| Type | Also Called | Source | Reducible? |
|---|---|---|---|
| Aleatoric | Data uncertainty | Inherent noise in data | No |
| Epistemic | Model uncertainty | Limited training data | Yes |
| Homoscedastic | Task uncertainty | Constant across inputs | No |
| Heteroscedastic | Input-dependent uncertainty | Varies with input | No |

Homoscedastic Uncertainty

Kendall et al. focus on homoscedastic uncertainty—task-level uncertainty that is constant across all inputs.

  • For RUL: The inherent unpredictability of exact failure time
  • For Health: The inherent ambiguity at state boundaries

The intuition: if a task has high intrinsic uncertainty, we should weight it less (its predictions are inherently noisy). If uncertainty is low, we should weight it more (its predictions are reliable).

Probabilistic Framing

Instead of directly predicting outputs, we model the likelihood of observations:

p(y \mid f^W(x)) = \mathcal{N}(f^W(x), \sigma^2)

Where σ² is the homoscedastic variance: a learnable parameter representing task uncertainty.


Mathematical Derivation

The uncertainty-weighted loss emerges from maximum likelihood estimation.

Single Task (Regression)

For a regression task with Gaussian likelihood:

p(y \mid f^W(x), \sigma) = \frac{1}{\sqrt{2\pi}\,\sigma} \exp\left(-\frac{(y - f^W(x))^2}{2\sigma^2}\right)

The negative log-likelihood is:

-\log p(y \mid f^W(x), \sigma) = \frac{(y - f^W(x))^2}{2\sigma^2} + \log \sigma + \text{const}
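As a quick numerical sanity check (the residuals below are toy values, not C-MAPSS data), minimizing this negative log-likelihood over σ recovers the familiar result that the optimal σ² equals the mean squared residual:

```python
import numpy as np

# Toy residuals y - f^W(x); illustrative values only
residuals = np.array([1.0, -2.0, 0.5, 3.0])

def total_nll(sigma):
    # Sum of per-sample NLL terms, constants dropped:
    # r^2 / (2 sigma^2) + log(sigma)
    return np.sum(residuals**2 / (2 * sigma**2) + np.log(sigma))

# Brute-force minimization over a fine grid of sigma values
sigmas = np.linspace(0.1, 5.0, 5000)
best_sigma = sigmas[np.argmin([total_nll(s) for s in sigmas])]

# The minimizer matches the root-mean-square residual
print(best_sigma, np.sqrt(np.mean(residuals**2)))
```

This is the mechanism the multi-task loss exploits: each σᵢ is free to settle at the scale of its own task's errors.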

Multi-Task Likelihood

For two tasks with independent observations:

p(y_1, y_2 \mid f^W(x)) = p(y_1 \mid f^W(x)) \cdot p(y_2 \mid f^W(x))

Taking the negative log:

\mathcal{L} = \frac{1}{2\sigma_1^2}\mathcal{L}_1 + \frac{1}{2\sigma_2^2}\mathcal{L}_2 + \log \sigma_1 + \log \sigma_2

Practical Formulation

For numerical stability, we parameterize using s = log σ²:

\mathcal{L} = \frac{1}{2}e^{-s_1}\mathcal{L}_1 + \frac{1}{2}e^{-s_2}\mathcal{L}_2 + \frac{1}{2}s_1 + \frac{1}{2}s_2

Where:

  • sᵢ = log σᵢ²: Log-variance (learnable parameter)
  • e^(−sᵢ) = 1/σᵢ²: Precision (inverse variance)
  • ½ sᵢ: Regularization term
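The two parameterizations are algebraically identical; a quick check with illustrative numbers (an RUL-scale MSE, a cross-entropy, and two log-variances, all assumed for demonstration) confirms they yield the same total loss, while e^(−s) keeps the precision positive for any real s:

```python
import math

# Illustrative values: an RUL MSE, a health cross-entropy, and learned log-variances
loss_rul, loss_health = 400.0, 0.36
s1, s2 = 5.2, -0.3

# s-parameterization: (1/2) e^{-s} L + (1/2) s per task
total_s = (0.5 * math.exp(-s1) * loss_rul + 0.5 * s1 +
           0.5 * math.exp(-s2) * loss_health + 0.5 * s2)

# sigma-parameterization: L / (2 sigma^2) + log(sigma), with sigma^2 = e^s
sig1_sq, sig2_sq = math.exp(s1), math.exp(s2)
total_sigma = (loss_rul / (2 * sig1_sq) + 0.5 * math.log(sig1_sq) +
               loss_health / (2 * sig2_sq) + 0.5 * math.log(sig2_sq))

print(total_s, total_sigma)  # identical, ≈ 3.80
```

The practical win of the s form is that gradient descent can never drive the variance to zero or negative values, which the raw σ parameterization would permit.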

Implementation

The uncertainty-weighted loss is straightforward to implement.

PyTorch Implementation

Uncertainty Weighted Loss (losses/uncertainty_loss.py)

Learnable Log-Variance (RUL)

Initialized to 0, meaning σ² = exp(0) = 1.0 initially. This parameter learns the inherent uncertainty of RUL predictions.

EXAMPLE
# Initialization: log_var_rul = tensor([0.0])
σ²_rul = exp(0.0) = 1.0  # Initial variance
precision_rul = 1/σ² = 1.0  # Initial precision (weight)

# After training: log_var_rul might become 5.2
σ²_rul = exp(5.2) ≈ 181  # High variance learned
precision_rul = exp(-5.2) ≈ 0.0055  # Low weight on RUL
Learnable Log-Variance (Health)

Same initialization for health task. Lower variance learned here typically means higher weight on health classification.

EXAMPLE
# After training: log_var_health might become -0.3
σ²_health = exp(-0.3) ≈ 0.74  # Low variance
precision_health = exp(0.3) ≈ 1.35  # Higher weight on health
Squeeze RUL Predictions

Removes extra dimensions from predictions to match target shape for MSE computation.

EXAMPLE
# BEFORE squeeze: rul_pred.shape = (32, 1)
rul_pred = [[47.3], [92.1], ..., [15.8]]

# AFTER squeeze: rul_pred.shape = (32,)
rul_pred = [47.3, 92.1, ..., 15.8]

# rul_target.shape = (32,)
rul_target = [50.0, 89.0, ..., 12.0]
Precision from Log-Variance

Convert log-variance to precision (inverse variance) for weighting. The negative sign in exp(-s) inverts the scale.

EXAMPLE
# If log_var_rul = 5.2 (learned high uncertainty)
precision_rul = exp(-5.2) = 0.0055

# If log_var_rul = -0.5 (learned low uncertainty)
precision_rul = exp(0.5) = 1.65

# Precision ∝ 1/uncertainty → high uncertainty = low weight
Weighted Loss Combination

Final loss combines precision-weighted task losses plus regularization terms that prevent variance from exploding.

EXAMPLE
# Example computation:
loss_rul = 400.0  # MSE on RUL
loss_health = 0.36  # Cross-entropy

precision_rul = 0.0055  # Low (high variance)
precision_health = 1.35  # High (low variance)
log_var_rul = 5.2
log_var_health = -0.3

total = 0.5 * 0.0055 * 400.0   +  # 1.1 (RUL contrib)
        0.5 * 5.2               +  # 2.6 (RUL regularizer)
        0.5 * 1.35 * 0.36       +  # 0.24 (Health contrib)
        0.5 * (-0.3)               # -0.15 (Health regularizer)
      = 3.79
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class UncertaintyWeightedLoss(nn.Module):
    """
    Uncertainty weighting from Kendall et al. (2018).

    Learns task-specific log-variances that weight the losses.
    Based on "Multi-Task Learning Using Uncertainty to Weigh
    Losses for Scene Geometry and Semantics".
    """

    def __init__(self):
        super().__init__()
        # Learnable log-variances (initialized to 0 → σ² = 1)
        self.log_var_rul = nn.Parameter(torch.zeros(1))
        self.log_var_health = nn.Parameter(torch.zeros(1))

    def forward(
        self,
        rul_pred: torch.Tensor,
        rul_target: torch.Tensor,
        health_logits: torch.Tensor,
        health_target: torch.Tensor
    ) -> torch.Tensor:
        # Individual losses; squeeze(-1) maps (B, 1) → (B) without
        # collapsing a batch of size 1
        loss_rul = F.mse_loss(rul_pred.squeeze(-1), rul_target)
        loss_health = F.cross_entropy(health_logits, health_target)

        # Precision-weighted losses with regularization:
        # L = (1/2σ²)·L_task + (1/2)·log(σ²)
        #   = (1/2)·exp(-s)·L_task + (1/2)·s   where s = log(σ²)
        precision_rul = torch.exp(-self.log_var_rul)
        precision_health = torch.exp(-self.log_var_health)

        total_loss = (
            0.5 * precision_rul * loss_rul +
            0.5 * self.log_var_rul +
            0.5 * precision_health * loss_health +
            0.5 * self.log_var_health
        )

        return total_loss
```
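One practical detail worth emphasizing: the log-variances live in the loss module, not the model, so they must be registered with the optimizer or they will never update. A minimal usage sketch, with a compact stand-in for the loss module so it runs on its own (the backbone, shapes, and hyperparameters are illustrative, not from this book's pipeline):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Compact stand-in for UncertaintyWeightedLoss, repeated so this sketch is self-contained
class UWLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.log_var_rul = nn.Parameter(torch.zeros(1))
        self.log_var_health = nn.Parameter(torch.zeros(1))

    def forward(self, rul_pred, rul_target, health_logits, health_target):
        loss_rul = F.mse_loss(rul_pred.squeeze(-1), rul_target)
        loss_health = F.cross_entropy(health_logits, health_target)
        return (0.5 * torch.exp(-self.log_var_rul) * loss_rul + 0.5 * self.log_var_rul +
                0.5 * torch.exp(-self.log_var_health) * loss_health + 0.5 * self.log_var_health)

torch.manual_seed(0)
model = nn.Linear(8, 5)  # illustrative backbone: 1 RUL output + 4 health logits
criterion = UWLoss()

# Crucial: include criterion.parameters() so the log-variances are optimized too
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(criterion.parameters()), lr=1e-2
)

x = torch.randn(32, 8)
rul_target = torch.randn(32)
health_target = torch.randint(0, 4, (32,))

out = model(x)
loss = criterion(out[:, :1], rul_target, out[:, 1:], health_target)

optimizer.zero_grad()
loss.backward()
optimizer.step()

# After one step, the log-variances have moved away from their initial 0
print(criterion.log_var_rul.item(), criterion.log_var_health.item())
```

Forgetting the `criterion.parameters()` term silently reduces the method to fixed equal weights of 0.5, a common implementation mistake.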

Training Dynamics

During training, the log-variances adjust automatically:

```text
Training progression:

Epoch 1:  s_rul = 0.0,  s_health = 0.0
          weights: (1.0, 1.0)

Epoch 50: s_rul = 5.2,  s_health = -0.3
          weights: (0.005, 1.35)

Interpretation:
  - RUL has high uncertainty (σ²_rul ≈ 180) → low weight
  - Health has low uncertainty (σ²_health ≈ 0.74) → high weight
  - Model learned that health predictions are more reliable
```

No Manual Tuning

The key advantage: we did not manually set λ₁ = 0.005 and λ₂ = 1.35. These values emerged from optimizing the likelihood, automatically adapting to the loss scales.


Limitations

Despite its principled foundation, uncertainty weighting has limitations.

Assumption Violations

  • Gaussian assumption: RUL errors may not be Gaussian (heavy-tailed, asymmetric)
  • Homoscedasticity: Uncertainty may actually vary with RUL level (heteroscedastic)
  • Independence: Task errors may be correlated (both depend on degradation)

Optimization Issues

| Issue | Description |
|---|---|
| Local minima | Log-variances can get stuck at poor values |
| Initialization sensitivity | Starting values affect convergence |
| Slow adaptation | Log-variances change slowly via gradient descent |
| Regularization trade-off | The log σ term may be too weak or too strong |
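The regularization trade-off has a precise form: for a fixed task loss L, the per-task objective ½e^(−s)L + ½s has its stationary point at s* = log L (set the derivative −½e^(−s)L + ½ to zero). So the learned "uncertainty" σ² = e^(s*) simply equals the current loss value, which is why the method tends to conflate loss scale with genuine task uncertainty. A small numeric check (the loss value is illustrative):

```python
import torch

# For a fixed task loss L, minimize f(s) = 0.5 * exp(-s) * L + 0.5 * s.
# f'(s) = -0.5 * exp(-s) * L + 0.5 = 0  gives  s* = log(L),
# i.e. the learned variance sigma^2 = exp(s*) equals the raw loss value.
L_task = torch.tensor(400.0)  # illustrative RUL-scale MSE
s = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([s], lr=0.1)

for _ in range(3000):
    opt.zero_grad()
    obj = 0.5 * torch.exp(-s) * L_task + 0.5 * s
    obj.backward()
    opt.step()

print(s.item(), torch.log(L_task).item())  # both ≈ 5.99
```

An MSE measured in cycles² is simply large, independent of how predictable failure actually is, and the optimum above will absorb that scale into σ².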

RUL-Specific Problems

Empirical Results on C-MAPSS

Our experiments show uncertainty weighting underperforms:

| Method | FD001 RMSE | FD002 RMSE |
|---|---|---|
| Fixed weights (tuned) | 11.8 | 16.2 |
| Uncertainty weighting | 12.4 | 17.1 |
| AMNL (our method) | 10.8 | 13.9 |

Summary

In this section, we examined uncertainty weighting:

  1. Core idea: Weight tasks by inverse uncertainty
  2. Derivation: Maximum likelihood with learnable variance
  3. Formula: \mathcal{L} = \frac{1}{2\sigma_1^2}\mathcal{L}_1 + \frac{1}{2\sigma_2^2}\mathcal{L}_2 + \log \sigma_1 + \log \sigma_2
  4. Advantage: No manual weight tuning
  5. Limitation: Confuses loss scale with uncertainty
| Aspect | Value |
|---|---|
| Learnable parameters | 2 (log-variances) |
| Tuning required | None (learns from data) |
| Assumption | Gaussian, homoscedastic errors |
| RUL suitability | Limited (scale/uncertainty confusion) |
Looking Ahead: Uncertainty weighting addresses loss scale automatically but has issues with changing loss dynamics. The next section introduces GradNorm—an approach that directly balances gradient magnitudes rather than loss values.

With uncertainty weighting understood, we examine gradient-based balancing methods.