Chapter 9

Dynamic Weight Average (DWA)

Learning Objectives

By the end of this section, you will:

  1. Understand DWA's loss-rate approach to task weighting
  2. Implement the softmax-based weight computation
  3. Tune the temperature hyperparameter for weight sensitivity
  4. Compare DWA with GradNorm in terms of cost and performance
  5. Identify DWA's limitations for RUL prediction
Why This Matters: Dynamic Weight Average (Liu et al., 2019) offers a computationally cheap alternative to GradNorm. Instead of computing per-task gradients, DWA adjusts weights based on how quickly each task's loss is decreasing—a proxy for learning difficulty that requires no additional backward passes.

DWA Motivation

DWA is motivated by a simple observation: tasks that are improving slowly likely need more focus.

The Core Intuition

Track the rate of loss decrease for each task:

w_i(t) \propto \frac{\mathcal{L}_i(t-1)}{\mathcal{L}_i(t-2)}

  • If loss decreased (ratio < 1): task is learning well → lower weight
  • If loss increased (ratio > 1): task is struggling → higher weight
  • If loss unchanged (ratio = 1): maintain current focus
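As a concrete sketch of this intuition (the task names and loss values below are hypothetical, chosen to echo the worked examples later in this section):

```python
# Hypothetical per-task average losses from the last two epochs
prev = {'rul': 800.0, 'health': 1.10}   # epoch t-2
curr = {'rul': 600.0, 'health': 0.95}   # epoch t-1

# Descent ratio per task: < 1 means improving, > 1 means struggling
ratios = {task: curr[task] / prev[task] for task in prev}
# rul: 0.75 (fast improvement), health: ~0.86 (slower improvement)
# -> DWA will give the slower 'health' task the larger weight
```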

Comparison with GradNorm

| Aspect | GradNorm | DWA |
|---|---|---|
| Signal | Gradient magnitudes | Loss change rates |
| Computation | Per-task gradients | Loss values only |
| Overhead | ~K× backward passes | Negligible |
| Memory | O(K × params) | O(K) |
| Sensitivity | Per-step | Smoothed over epochs |

No Gradient Computation

DWA's key advantage: it only needs loss values from consecutive epochs, not per-task gradients. This makes it as fast as standard training while still adapting weights dynamically.


Mathematical Formulation

DWA computes weights using softmax over loss ratios.

Loss Rate Computation

Define the relative descending rate for task i:

r_i(t) = \frac{\mathcal{L}_i(t-1)}{\mathcal{L}_i(t-2)}

Where:

  • \mathcal{L}_i(t-1): Loss at the previous epoch
  • \mathcal{L}_i(t-2): Loss two epochs ago
  • r_i(t) < 1: Loss decreased (task improving)
  • r_i(t) > 1: Loss increased (task struggling)

Softmax Weighting

Convert rates to weights using softmax with temperature T:

w_i(t) = \frac{K \exp(r_i(t) / T)}{\sum_{j=1}^{K} \exp(r_j(t) / T)}

Where:

  • K: Number of tasks (the K factor ensures weights sum to K)
  • T: Temperature hyperparameter
  • Large T → weights approach uniform (less sensitive)
  • Small T → weights more extreme (more sensitive)
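A minimal plain-Python sketch of this formula for K = 2 tasks at the default T = 2.0 (the rates are illustrative values used in this section's worked examples):

```python
import math

# DWA softmax weighting for K = 2 tasks, temperature T = 2.0
K, T = 2, 2.0
rates = [0.75, 0.86]  # descent rates: task 0 improving faster than task 1

exp_rates = [math.exp(r / T) for r in rates]
weights = [K * e / sum(exp_rates) for e in exp_rates]
# weights ≈ [0.97, 1.03]: they sum to K and the slower task gets more weight
```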

Implementation

DWA is straightforward to implement.

PyTorch Implementation

Dynamic Weight Average Loss
losses/dwa_loss.py (the listing assumes import torch and import torch.nn as nn at the top of the file)
Line 19: Loss History Storage

Tracks average task losses from the two previous epochs. The history is seeded with 1.0 to avoid division by zero at the start; because those seeded values are not real losses, the original paper keeps all weights at 1 for the first two epochs before trusting the rates.

EXAMPLE
# Initialization:
loss_history = {
    't-2': [1.0, 1.0],  # Epoch (t-2) losses
    't-1': [1.0, 1.0]   # Epoch (t-1) losses
}

# After epoch 5:
loss_history = {
    't-2': [800.0, 1.1],   # Epoch 4 losses
    't-1': [600.0, 0.95]   # Epoch 5 losses
}
Line 32: Descent Rate Computation

Rate = current loss / previous loss. Values below 1 indicate improvement; values above 1 indicate degradation.

EXAMPLE
# Loss history:
't-2': [800.0, 1.1]   # Previous epoch
't-1': [600.0, 0.95]  # Current epoch

# Compute rates:
rate_rul = 600.0 / 800.0 = 0.75    # 25% improvement
rate_health = 0.95 / 1.1 = 0.86   # 14% improvement

# RUL improving faster → gets lower weight
# Health improving slower → gets higher weight
Line 39: Softmax Temperature Scaling

Temperature controls sensitivity. High T → smoother weights, Low T → more extreme weights.

EXAMPLE
# With rates = [0.75, 0.86]:

# Temperature = 2.0 (default, smooth):
exp_rates = exp([0.75/2, 0.86/2])
          = exp([0.375, 0.43])
          = [1.455, 1.537]
weights = 2 * [1.455, 1.537] / sum
        = [0.97, 1.03]  # Modest difference

# Temperature = 0.5 (aggressive):
weights ≈ [0.89, 1.11]  # More extreme
Line 41: Weight Normalization

Weights sum to num_tasks (=2), not 1. This preserves the original loss scale during training.

EXAMPLE
# Softmax normalization:
exp_rates = [1.455, 1.537]
exp_sum = 1.455 + 1.537 = 2.992

# Multiply by num_tasks to sum to 2:
weights = 2 * exp_rates / exp_sum
        = 2 * [0.486, 0.514]
        = [0.97, 1.03]

# Sum: 0.97 + 1.03 = 2.0 ✓
Line 45: History Update (Sliding Window)

Shift loss history forward: current becomes previous, new losses will be recorded next epoch.

EXAMPLE
# Before update:
loss_history = {
    't-2': [800.0, 1.1],
    't-1': [600.0, 0.95]
}
current_losses = [500.0, 0.85]

# After update:
loss_history = {
    't-2': [600.0, 0.95],   # Was t-1
    't-1': [500.0, 0.85]    # New current
}
Line 60: Weighted Loss Computation

Apply computed weights to task losses. No gradient computation needed, just scalar multiplication.

EXAMPLE
# Inputs:
losses = [tensor(500.0), tensor(0.85)]
weights = [0.97, 1.03]

# Weighted sum:
total_loss = 0.97 * 500.0 + 1.03 * 0.85
           = 485.0 + 0.88
           = 485.88

# Compare to equal weights (1.0, 1.0):
# 1.0 * 500.0 + 1.0 * 0.85 = 500.85
 1  class DWALoss(nn.Module):
 2      """
 3      Dynamic Weight Average (Liu et al., 2019)
 4
 5      Adjusts task weights based on loss descent rates.
 6      No additional gradient computation required.
 7      """
 8
 9      def __init__(
10          self,
11          num_tasks: int = 2,
12          temperature: float = 2.0
13      ):
14          super().__init__()
15          self.num_tasks = num_tasks
16          self.temperature = temperature
17          self.num_updates = 0  # paper keeps w_i = 1 for first two epochs
18          # Loss history (last two epochs), seeded with 1.0
19          self.loss_history = {
20              't-2': [1.0] * num_tasks,
21              't-1': [1.0] * num_tasks
22          }
23          self.weights = [1.0] * num_tasks
24
25      def update_weights(self, current_losses: list):
26          """
27          Update weights based on loss descent rates.
28          Call once per epoch after computing average losses.
29          """
30          # Compute descent rates
31          rates = []
32          for i in range(self.num_tasks):
33              rate = self.loss_history['t-1'][i] / (
34                  self.loss_history['t-2'][i] + 1e-8
35              )
36              rates.append(rate)
37          self.num_updates += 1
38          # Softmax with temperature (uniform during warm-up)
39          rates = torch.tensor(rates)
40          exp_rates = torch.exp(rates / self.temperature)
41          self.weights = (
42              self.num_tasks * exp_rates / exp_rates.sum()
43          ).tolist()
44          if self.num_updates <= 2: self.weights = [1.0] * self.num_tasks
45          # Update history
46          self.loss_history['t-2'] = self.loss_history['t-1'].copy()
47          self.loss_history['t-1'] = [l.item() for l in current_losses]
48
49      def forward(
50          self,
51          losses: list
52      ) -> torch.Tensor:
53          """
54          Compute weighted loss.
55
56          Args:
57              losses: List of individual task losses
58
59          Returns:
60              Weighted sum of losses
61          """
62          total_loss = sum(
63              w * l for w, l in zip(self.weights, losses)
64          )
65          return total_loss
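A rough sketch of how the per-epoch call pattern fits into a training loop, using plain floats in place of tensors. The epoch losses and variable names are illustrative, not from the chapter's training code; the first two epochs use uniform weights, as in the original paper:

```python
import math

num_tasks, temperature = 2, 2.0
history = []                    # per-epoch average losses, newest last
weights = [1.0] * num_tasks     # uniform during the two-epoch warm-up

for epoch_avg_losses in ([800.0, 1.10], [600.0, 0.95], [500.0, 0.85]):
    # During the epoch: each batch combines task losses with current weights
    total = sum(w * l for w, l in zip(weights, epoch_avg_losses))

    # End of epoch: record losses, then refresh weights once two epochs
    # of real history are available
    history.append(epoch_avg_losses)
    if len(history) >= 2:
        rates = [history[-1][i] / history[-2][i] for i in range(num_tasks)]
        exps = [math.exp(r / temperature) for r in rates]
        weights = [num_tasks * e / sum(exps) for e in exps]

# After epoch 3, rates = [500/600, 0.85/0.95] -> weights ≈ [0.98, 1.02]
```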

Temperature Selection

The temperature T controls weight sensitivity:

| Temperature | Behavior | Use Case |
|---|---|---|
| T → 0 | Winner-take-all (focus on hardest task) | Aggressive balancing |
| T = 1 | Weights proportional to rates | Moderate adaptation |
| T = 2 | Smoothed weights (recommended) | Stable training |
| T → ∞ | Uniform weights | No adaptation |

Default Choice

T = 2 is typically used as it provides good balance between responsiveness and stability. Lower values can cause oscillating weights; higher values reduce the benefit of adaptation.
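The table's qualitative trend can be checked numerically. The sketch below (plain Python, illustrative rates) measures the gap between the two task weights as T varies:

```python
import math

def dwa_weights(rates, temperature):
    """Softmax over descent rates, rescaled to sum to the task count."""
    k = len(rates)
    exps = [math.exp(r / temperature) for r in rates]
    return [k * e / sum(exps) for e in exps]

# Same rates as the worked example; only the temperature varies
spreads = {}
for T in (0.5, 1.0, 2.0, 10.0):
    w = dwa_weights([0.75, 0.86], T)
    spreads[T] = w[1] - w[0]  # gap between the two task weights
# The gap shrinks as T grows: from about 0.22 at T=0.5 to about 0.01 at T=10
```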


Comparison with Other Methods

How does DWA compare to the methods we have studied?

Method Overview

| Method | Weights From | Learnable | Overhead | Hyperparams |
|---|---|---|---|---|
| Fixed | Manual | No | None | One weight per task |
| Uncertainty | Log-variance | Yes | Minimal | None |
| GradNorm | Gradients | Yes | ~K× backward passes | α, learning rate |
| DWA | Loss rates | No* | None | T (temperature) |

DWA Weights

*DWA weights are computed from loss history, not learned via backpropagation. They adapt dynamically but are not gradient-updated parameters.

Pros and Cons

| Aspect | Pro | Con |
|---|---|---|
| Simplicity | Easy to implement | Less principled than alternatives |
| Cost | No overhead | - |
| Adaptation | Responds to training dynamics | Lags by two epochs |
| Sensitivity | Temperature is tunable | Can be too smooth or too noisy |

Empirical Results on C-MAPSS

| Method | FD001 | FD002 | FD003 | FD004 | Time |
|---|---|---|---|---|---|
| Fixed | 11.8 | 16.2 | 12.5 | 19.8 | 1.0× |
| Uncertainty | 12.4 | 17.1 | 13.1 | 20.5 | 1.0× |
| GradNorm | 11.5 | 15.8 | 12.1 | 19.2 | 2.8× |
| DWA | 11.6 | 15.9 | 12.3 | 19.3 | 1.0× |
| AMNL | 10.8 | 13.9 | 11.2 | 17.4 | 1.0× |

DWA performs similarly to GradNorm but without the computational cost. However, both are significantly outperformed by AMNL.


Summary

In this section, we examined Dynamic Weight Average:

  1. Core idea: Weight tasks by their loss descent rate, so slowly improving tasks receive more weight
  2. Computation: rᵢ = Lᵢ(t-1) / Lᵢ(t-2)
  3. Weights: Softmax over rates with temperature
  4. Advantage: No computational overhead
  5. Limitation: Two-epoch lag, temperature sensitivity
| Property | Value |
|---|---|
| Weight update | Once per epoch |
| Memory | O(K) for loss history |
| Hyperparameter | Temperature T |
| Default T | 2.0 |
| RUL suitability | Moderate (lag and noise issues) |
Looking Ahead: We have surveyed four multi-task weighting methods: fixed, uncertainty, GradNorm, and DWA. The final section synthesizes why all of these methods fail for RUL prediction—setting the stage for AMNL's novel approach in Chapter 10.

With DWA understood, we analyze why existing methods fall short for RUL.