Chapter 9

GradNorm: Gradient Normalization


Learning Objectives

By the end of this section, you will:

  1. Understand gradient-based task balancing as an alternative to loss weighting
  2. Learn the GradNorm algorithm for equalizing gradient norms
  3. Implement training rate balancing to ensure equal learning progress
  4. Identify computational overhead of gradient-based methods
  5. Recognize limitations when applied to RUL prediction

Why This Matters: GradNorm (Chen et al., 2018) introduced a key insight: what matters for multi-task learning is not loss magnitude but gradient magnitude. By directly balancing gradients, we ensure each task contributes equally to parameter updates, regardless of loss scale.

Gradient-Based Motivation

The fundamental insight is that optimization happens through gradients, not losses.

Why Gradients Matter More Than Losses

Parameters are updated via gradient descent:

$$\theta_{t+1} = \theta_t - \eta \nabla_\theta \mathcal{L}$$

If one task has gradients 100× larger than another, it receives 100× more influence on the update—regardless of how we weight the losses.

Gradient Imbalance Example
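A minimal sketch of the effect, using hypothetical loss scales (the 200× factor that emerges below is illustrative, not from the C-MAPSS models):

```python
import torch

# One shared parameter vector, standing in for a shared encoder layer
shared = torch.ones(4, requires_grad=True)

# Two task losses at very different scales (hypothetical values)
loss_rul = 100.0 * (shared ** 2).sum()    # large-scale regression loss
loss_health = 0.5 * (shared ** 2).sum()   # small-scale classification loss

g_rul = torch.autograd.grad(loss_rul, shared, retain_graph=True)[0]
g_health = torch.autograd.grad(loss_health, shared)[0]

# The RUL task contributes 200× more to the shared-parameter update
print(g_rul.norm() / g_health.norm())  # → 200× larger
```

Multiplying either loss by a fixed weight simply rescales its gradient, which is why GradNorm makes the weights learnable and targets the gradient norms directly.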


GradNorm Algorithm

GradNorm adjusts task weights to equalize the gradient norms across tasks.

Core Idea

Define task weights $w_i$ such that:

$$\|w_i \nabla_\theta \mathcal{L}_i\| \approx \bar{G}$$

where $\bar{G}$ is the average gradient norm across all tasks. This ensures each task contributes equally to parameter updates.

The GradNorm Objective

GradNorm learns weights by minimizing:

$$\mathcal{L}_{\text{grad}} = \sum_i \left| G_i^W - \bar{G} \times r_i^{\alpha} \right|$$

Where:

  • $G_i^W = \|w_i \nabla_W \mathcal{L}_i\|$: weighted gradient norm for task $i$ ($W$ denotes the last shared layer's weights)
  • $\bar{G} = \mathbb{E}[G_i^W]$: average gradient norm across tasks
  • $r_i = \tilde{\mathcal{L}}_i(t) / \tilde{\mathcal{L}}_i(0)$: relative inverse training rate for task $i$
  • $\alpha$: hyperparameter controlling the strength of training-rate balancing (typically 1.5)

Algorithm Steps

```text
GradNorm Algorithm:

For each training step:
1. Compute task losses: L_1, L_2, ..., L_T
2. Compute weighted loss: L = Σ w_i L_i
3. Backward pass: compute gradients ∇_θ L

4. For gradient normalization update:
   a. Compute G_i = ||w_i ∇_W L_i|| for each task
      (W is the last shared layer weights)
   b. Compute average: G̅ = mean(G_i)
   c. Compute training rates: r_i = L̃_i(t) / L̃_i(0)
   d. Compute target gradients: G_target_i = G̅ × (r_i)^α
   e. Compute GradNorm loss: L_grad = Σ |G_i - G_target_i|
   f. Update weights: w_i ← w_i - η_w ∇_{w_i} L_grad

5. Normalize weights: w_i ← T × w_i / Σ w_j
```

Implementation

GradNorm Loss
The implementation below comes from `losses/gradnorm_loss.py`.

Learnable Task Weights

Initialized to [1.0, 1.0] for 2 tasks. These weights are updated to balance gradient magnitudes across tasks.

EXAMPLE

```python
# Initialization: weights = tensor([1.0, 1.0])

# After training adjustments:
weights = tensor([0.3, 1.7])  # RUL weight down, health weight up

# This means the health task needs more focus
# (its gradients were too small relative to RUL's)
```
Weighted Losses

Multiply each task loss by its learnable weight. This creates the weighted objectives for gradient computation.

EXAMPLE

```python
# Inputs:
losses = [tensor(400.0), tensor(0.36)]  # [RUL, health]
weights = tensor([0.3, 1.7])

# Weighted losses:
weighted_losses = [0.3 * 400.0, 1.7 * 0.36]
                # = [120.0, 0.61]

# Total loss for backward pass:
total_loss = 120.0 + 0.61  # = 120.61
```
Per-Task Gradient Computation

Compute gradients separately for each weighted loss. This requires `retain_graph=True` since the computation graph is reused across tasks.

EXAMPLE

```python
# For each task, compute the gradient w.r.t. the shared params
# (shared_params: last encoder layer, 256 × 128)

# Task 1 (RUL):
grad_rul = autograd.grad(weighted_losses[0], shared_params)
grad_norms[0] = grad_rul.norm()  # e.g., 45.2

# Task 2 (Health):
grad_health = autograd.grad(weighted_losses[1], shared_params)
grad_norms[1] = grad_health.norm()  # e.g., 2.3
```
Stack and Average Gradient Norms

Combine the gradient norms into a tensor and compute their average. This average is the baseline target for balancing.

EXAMPLE

```python
# Gradient norms from each task:
grad_norms = torch.stack([tensor(45.2), tensor(2.3)])
           # = tensor([45.2, 2.3])

# Average gradient norm (target baseline):
avg_grad_norm = (45.2 + 2.3) / 2  # = 23.75

# Goal: bring both task gradients close to 23.75
```
Training Rate Computation

The ratio of current loss to initial loss shows how fast each task is learning. Slower tasks get boosted.

EXAMPLE

```python
# Initial losses (epoch 0):
initial_losses = [2000.0, 1.5]

# Current losses (epoch 50):
losses = [500.0, 1.2]

# Training rates (lower = faster learning):
training_rates = [500 / 2000, 1.2 / 1.5]
               # = [0.25, 0.80]

# RUL's loss has fallen to 25% of its initial value,
# health's only to 80% — RUL is learning much faster
```
Target Gradient Norms

Slow-learning tasks get higher target gradient norms. The exponent α controls sensitivity to rate differences.

EXAMPLE

```python
# With α = 1.5:
inv_rates = [0.25, 0.80] / mean([0.25, 0.80])
          # = [0.48, 1.52]  (normalized)

target_norms = 23.75 * inv_rates ** 1.5
             # = 23.75 * [0.33, 1.87]
             # = [7.8, 44.4]

# Health target (44.4) >> RUL target (7.8)
# Forces increased weight on the health task
```
GradNorm Loss

Minimize the difference between actual and target gradient norms. This loss updates only the task weights.

EXAMPLE

```python
# Actual vs. target gradient norms:
grad_norms = [45.2, 2.3]
target_norms = [7.8, 44.4]

# GradNorm loss:
grad_loss = abs(45.2 - 7.8) + abs(2.3 - 44.4)
          # = 37.4 + 42.1
          # = 79.5

# High loss → weights need significant adjustment
```
The full listing:

```python
import torch
import torch.nn as nn


class GradNormLoss(nn.Module):
    """
    GradNorm: Gradient Normalization for Adaptive Loss Balancing
    Chen et al. (2018)
    """

    def __init__(
        self,
        num_tasks: int = 2,
        alpha: float = 1.5,
        lr_weights: float = 0.025
    ):
        super().__init__()
        # Learnable task weights (initialized to 1)
        self.weights = nn.Parameter(torch.ones(num_tasks))
        self.alpha = alpha
        self.lr_weights = lr_weights  # learning rate for the weight optimizer

        # Track initial losses for training rate
        self.initial_losses = None
        self.num_tasks = num_tasks

    def forward(
        self,
        losses: list,
        shared_params: nn.Parameter
    ) -> tuple[torch.Tensor, dict]:
        """
        Args:
            losses: List of individual task losses
            shared_params: Parameters of last shared layer

        Returns:
            total_loss: Weighted sum of losses
            info: Dict with gradient norms and weights
        """
        # Store initial losses on the first call
        if self.initial_losses is None:
            self.initial_losses = [l.detach().clone() for l in losses]

        # Weighted loss
        weighted_losses = [w * l for w, l in zip(self.weights, losses)]
        total_loss = sum(weighted_losses)

        # Compute gradient norms for each task
        grad_norms = []
        for wl in weighted_losses:
            grad = torch.autograd.grad(
                wl, shared_params,
                retain_graph=True, create_graph=True
            )[0]
            grad_norms.append(grad.norm())

        grad_norms = torch.stack(grad_norms)
        avg_grad_norm = grad_norms.mean()

        # Compute relative training rates
        with torch.no_grad():
            training_rates = torch.tensor([
                losses[i].item() / self.initial_losses[i].item()
                for i in range(self.num_tasks)
            ])
            # Inverse rate (slower → higher target)
            inv_rates = training_rates / training_rates.mean()

        # Target gradient norms (treated as constants, hence detach)
        target_norms = (avg_grad_norm * (inv_rates ** self.alpha)).detach()

        # GradNorm loss (backpropagated only into the task weights)
        grad_loss = (grad_norms - target_norms).abs().sum()

        return total_loss, {
            'grad_norms': grad_norms.detach(),
            'target_norms': target_norms,
            'weights': self.weights.detach(),
            'grad_loss': grad_loss
        }
```

Training Rate Balancing

GradNorm goes beyond gradient norm equalization—it balances training rates.

The Training Rate Concept

The training rate measures how fast each task is learning:

$$r_i(t) = \frac{\mathcal{L}_i(t)}{\mathcal{L}_i(0)}$$

A task with $r_i = 0.1$ has reduced its loss to 10% of the initial value—it is learning quickly. A task with $r_i = 0.8$ has only improved 20%—it is learning slowly.

Balancing Mechanism

The target gradient norm includes a training rate term:

$$G_i^{\text{target}} = \bar{G} \times \left(\frac{r_i}{\bar{r}}\right)^\alpha$$

  • If task $i$ is learning slowly ($r_i > \bar{r}$), its target gradient is higher → its weight increases
  • If task $i$ is learning quickly ($r_i < \bar{r}$), its target gradient is lower → its weight decreases
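The direction of this mechanism can be checked numerically. A minimal sketch reusing the hypothetical rates from the implementation walkthrough (values are approximate):

```python
import torch

alpha = 1.5
avg_grad_norm = 23.75               # G̅ from the running example
rates = torch.tensor([0.25, 0.80])  # hypothetical training rates r_i

inv_rates = rates / rates.mean()    # r_i / r̄
targets = avg_grad_norm * inv_rates ** alpha

# The slow task (r = 0.80) receives the larger target norm,
# pushing its weight up; the fast task's target shrinks
print(targets)  # ≈ tensor([7.8, 44.7])
```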

Limitations for RUL

GradNorm has theoretical appeal but practical issues for RUL prediction.

Computational Overhead

| Aspect | Standard MTL | GradNorm |
| --- | --- | --- |
| Forward passes | 1 | 1 |
| Backward passes | 1 | T+1 (one per task + main) |
| Memory | O(params) | O(T × params) |
| Wall time | 1× | ~3× slower |

Gradient Computation Cost

GradNorm requires computing per-task gradients separately, which means T backward passes instead of 1. For our 2-task setting, this triples training time. The cost grows linearly with the number of tasks.

RUL-Specific Issues

  • Noisy training rates: the RUL loss fluctuates significantly, making $r_i$ unstable
  • Non-monotonic loss: the RUL loss may increase during training (overfitting), confusing the training rate computation
  • Last layer only: GradNorm typically considers gradients only at the last shared layer, missing dynamics in earlier layers
  • Hyperparameter sensitivity: $\alpha$ and the learning rate for the weights require tuning

Empirical Results

On C-MAPSS datasets:

| Method | FD001 RMSE | FD002 RMSE | Training Time |
| --- | --- | --- | --- |
| Fixed weights | 11.8 | 16.2 | 1.0× |
| Uncertainty weighting | 12.4 | 17.1 | 1.0× |
| GradNorm | 11.5 | 15.8 | 2.8× |
| AMNL | 10.8 | 13.9 | 1.0× |

GradNorm shows modest improvement over simpler methods but at significant computational cost. AMNL achieves better results with no additional overhead.


Summary

In this section, we examined GradNorm:

  1. Core insight: Balance gradient magnitudes, not loss values
  2. Algorithm: Minimize difference between actual and target gradient norms
  3. Training rates: Slow-learning tasks get higher target gradients
  4. Overhead: ~3× training time for 2 tasks
  5. RUL limitations: Noisy training rates, high cost

| Aspect | Value |
| --- | --- |
| Learnable parameters | T weights |
| Key hyperparameter | α (typically 1.5) |
| Computational overhead | ~T× more backward passes |
| Balancing mechanism | Gradient norm + training rate |
| RUL suitability | Moderate (high cost, noisy dynamics) |

Looking Ahead: GradNorm is computationally expensive and sensitive to training noise. The next section examines Dynamic Weight Average (DWA)—a simpler approach that adjusts weights based on loss change rates without computing per-task gradients.

Having seen gradient-based balancing, we now examine loss-rate-based methods.