Learning Objectives
By the end of this section, you will:
- Understand gradient-based task balancing as an alternative to loss weighting
- Learn the GradNorm algorithm for equalizing gradient norms
- Implement training rate balancing to ensure equal learning progress
- Identify computational overhead of gradient-based methods
- Recognize limitations when applied to RUL prediction
Why This Matters: GradNorm (Chen et al., 2018) introduced a key insight: what matters for multi-task learning is not loss magnitude but gradient magnitude. By directly balancing gradients, we ensure each task contributes equally to parameter updates, regardless of loss scale.
Gradient-Based Motivation
The fundamental insight is that optimization happens through gradients, not losses.
Why Gradients Matter More Than Losses
Parameters are updated via gradient descent:

$$\theta \leftarrow \theta - \eta \, \nabla_\theta \sum_i w_i L_i$$

If one task produces gradients 100× larger than another's, it exerts 100× more influence on the update direction—even when the losses themselves are weighted equally.
Gradient Imbalance Example
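A small numeric sketch of this effect (all values are invented for illustration): two tasks contribute gradients of very different scale to a shared parameter vector, and the combined update direction is almost entirely determined by the large-gradient task.

```python
import numpy as np

np.random.seed(0)
grad_task_a = np.random.randn(10) * 100.0  # e.g. MSE on raw RUL values
grad_task_b = np.random.randn(10) * 1.0    # e.g. a bounded auxiliary loss

# Equal loss weights w_a = w_b = 1: the shared update is the raw sum.
combined = grad_task_a + grad_task_b

ratio = np.linalg.norm(grad_task_a) / np.linalg.norm(grad_task_b)
print(f"gradient norm ratio (A/B): {ratio:.0f}x")

# Cosine similarity shows the combined gradient nearly coincides with task A's.
cos_a = combined @ grad_task_a / (
    np.linalg.norm(combined) * np.linalg.norm(grad_task_a)
)
print(f"cosine(combined, task A) = {cos_a:.3f}")
```

Reweighting the losses cannot fix this reliably unless the weights track the gradient scales themselves, which is exactly what GradNorm does.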
GradNorm Algorithm
GradNorm adjusts task weights to equalize the gradient norms across tasks.
Core Idea
Define task weights $w_i$ such that:

$$w_i \, \|\nabla_W L_i\| = \bar{G} \quad \text{for all } i$$

where $\bar{G}$ is the average gradient norm across all tasks. This ensures each task contributes equally to parameter updates.
The GradNorm Objective
GradNorm learns weights by minimizing:

$$L_{\text{grad}}(t) = \sum_i \left| \, G_i(t) - \bar{G}(t) \cdot [r_i(t)]^{\alpha} \, \right|$$

Where:
- $G_i(t) = \|\nabla_W \, w_i(t) L_i(t)\|$: Weighted gradient norm for task $i$ at the last shared layer $W$
- $\bar{G}(t) = \mathbb{E}_i[G_i(t)]$: Average gradient norm across tasks
- $r_i(t) = \tilde{L}_i(t) \,/\, \mathbb{E}_j[\tilde{L}_j(t)]$, with $\tilde{L}_i(t) = L_i(t)/L_i(0)$: Relative inverse training rate
- $\alpha$: Hyperparameter controlling the strength of training rate balancing (typically 1.5)
Algorithm Steps
```
GradNorm Algorithm:

For each training step:
1. Compute task losses: L_1, L_2, ..., L_T
2. Compute weighted loss: L = Σ w_i L_i
3. Backward pass: compute gradients ∇_θ L

4. For gradient normalization update:
   a. Compute G_i = ||w_i ∇_W L_i|| for each task
      (W is the last shared layer weights)
   b. Compute average: G̅ = mean(G_i)
   c. Compute training rates: L̃_i = L_i(t) / L_i(0),  r_i = L̃_i / mean(L̃_j)
   d. Compute target gradients: G_target_i = G̅ × (r_i)^α
   e. Compute GradNorm loss: L_grad = Σ |G_i − G_target_i|
   f. Update weights: w_i ← w_i − η_w ∇_{w_i} L_grad

5. Normalize weights: w_i ← T × w_i / Σ w_j
```
Implementation
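The steps above can be sketched in PyTorch for a two-task network. This is a minimal single-step illustration under assumed shapes and toy data—the layer names (`shared`, `head_a`, `head_b`), the weight learning rate, and the `L0` placeholder (which in a real training loop would store the losses from step 0) are all invented for the example:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

shared = nn.Linear(8, 16)               # last shared layer W
head_a = nn.Linear(16, 1)               # RUL regression head
head_b = nn.Linear(16, 1)               # auxiliary head
w = torch.ones(2, requires_grad=True)   # learnable task weights
alpha = 1.5
opt_w = torch.optim.SGD([w], lr=0.025)

x = torch.randn(32, 8)
y_a, y_b = torch.randn(32, 1), torch.randn(32, 1)

h = torch.relu(shared(x))
losses = torch.stack([
    nn.functional.mse_loss(head_a(h), y_a),
    nn.functional.mse_loss(head_b(h), y_b),
])

# Step (a): per-task weighted gradient norms at the last shared layer.
# create_graph=True makes G differentiable w.r.t. the task weights.
G = torch.stack([
    torch.norm(torch.autograd.grad(w[i] * losses[i], shared.weight,
                                   retain_graph=True, create_graph=True)[0])
    for i in range(2)
])

# Steps (b)-(d): average norm, training rates, target norms.
L0 = losses.detach()              # placeholder: losses at step 0
L_tilde = losses.detach() / L0    # L̃_i = L_i(t) / L_i(0)
r = L_tilde / L_tilde.mean()      # relative inverse training rate
target = (G.mean() * r ** alpha).detach()

# Steps (e)-(f): GradNorm loss and weight update.
L_grad = (G - target).abs().sum()
opt_w.zero_grad()
L_grad.backward()
opt_w.step()

# Step 5: renormalize so the weights sum to T = 2.
with torch.no_grad():
    w.data = 2 * w.data / w.data.sum()

print(w.detach())
```

The per-task `torch.autograd.grad` calls inside the loop are the source of the extra backward passes discussed below.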
Training Rate Balancing
GradNorm goes beyond gradient norm equalization—it balances training rates.
The Training Rate Concept
The training rate measures how fast each task is learning:

$$\tilde{L}_i(t) = \frac{L_i(t)}{L_i(0)}$$

A task with $\tilde{L}_i = 0.1$ has reduced its loss to 10% of the initial value—it is learning quickly. A task with $\tilde{L}_i = 0.8$ has only improved by 20%—it is learning slowly.
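As a concrete example (the loss values are hypothetical), suppose the RUL regression loss has fallen from 1200 to 540 while an auxiliary health-state loss has fallen from 0.69 to 0.14:

```python
# Hypothetical initial and current per-task losses.
initial = {"rul": 1200.0, "health_state": 0.69}
current = {"rul": 540.0, "health_state": 0.14}

# L̃_i = L_i(t) / L_i(0): fraction of the initial loss remaining.
L_tilde = {k: current[k] / initial[k] for k in initial}

# r_i: training rate relative to the average task.
mean_rate = sum(L_tilde.values()) / len(L_tilde)
r = {k: v / mean_rate for k, v in L_tilde.items()}

for k in r:
    print(f"{k}: L_tilde={L_tilde[k]:.2f}, r={r[k]:.2f}")
```

Here the RUL task retains 45% of its initial loss while the auxiliary task retains only about 20%, so the RUL task gets $r > 1$ (learning slower than average) and the auxiliary task gets $r < 1$.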
Balancing Mechanism
The target gradient norm includes a training rate term:

$$G_i^{\text{target}}(t) = \bar{G}(t) \cdot [r_i(t)]^{\alpha}$$

- If task $i$ is learning slowly ($r_i > 1$), its target gradient is higher → its weight increases
- If task $i$ is learning quickly ($r_i < 1$), its target gradient is lower → its weight decreases
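A quick numeric check of this mechanism, with invented gradient norms and training rates:

```python
alpha = 1.5
G = {"rul": 4.0, "health_state": 4.0}    # current weighted gradient norms (invented)
r = {"rul": 1.38, "health_state": 0.62}  # hypothetical relative training rates

G_bar = sum(G.values()) / len(G)

# Slow task (r > 1) gets a larger target norm, fast task a smaller one,
# so minimizing |G_i - target_i| pushes w_rul up and w_health_state down.
target = {k: G_bar * r[k] ** alpha for k in G}
print(target)
```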
Limitations for RUL
GradNorm has theoretical appeal but practical issues for RUL prediction.
Computational Overhead
| Aspect | Standard MTL | GradNorm |
|---|---|---|
| Forward passes | 1 | 1 |
| Backward passes | 1 | T+1 (one per task + main) |
| Memory | O(params) | O(T × params) |
| Wall time | 1× | ~3× slower |
Gradient Computation Cost
GradNorm requires computing per-task gradients separately, which means T backward passes instead of 1. For our 2-task setting, this triples training time. The cost grows linearly with the number of tasks.
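The extra passes can be counted directly with a gradient hook on a shared parameter (a toy setup with invented tensors; the hook fires each time a gradient is computed for that parameter):

```python
import torch

calls = {"backward": 0}

def count(grad):
    calls["backward"] += 1
    return grad

w_shared = torch.randn(4, requires_grad=True)
w_shared.register_hook(count)

x = torch.randn(4)
loss_a = ((w_shared * x) ** 2).sum()
loss_b = (w_shared * x).abs().sum()

# Standard MTL: one backward pass over the combined loss.
(loss_a + loss_b).backward(retain_graph=True)

# GradNorm: one additional pass per task to obtain each G_i separately.
torch.autograd.grad(loss_a, w_shared, retain_graph=True)
torch.autograd.grad(loss_b, w_shared)

print(calls["backward"])  # T + 1 passes for T = 2 tasks
```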
RUL-Specific Issues
- Noisy training rates: RUL loss fluctuates significantly, making r_i unstable
- Non-monotonic loss: RUL loss may increase during training (overfitting), confusing training rate computation
- Last layer only: GradNorm typically only considers gradients at the last shared layer, missing dynamics in earlier layers
- Hyperparameter sensitivity: α and learning rate for weights require tuning
Empirical Results
On C-MAPSS datasets:
| Method | FD001 RMSE | FD002 RMSE | Training Time |
|---|---|---|---|
| Fixed weights | 11.8 | 16.2 | 1.0× |
| Uncertainty weighting | 12.4 | 17.1 | 1.0× |
| GradNorm | 11.5 | 15.8 | 2.8× |
| AMNL | 10.8 | 13.9 | 1.0× |
GradNorm shows modest improvement over simpler methods but at significant computational cost. AMNL achieves better results with no additional overhead.
Summary
In this section, we examined GradNorm:
- Core insight: Balance gradient magnitudes, not loss values
- Algorithm: Minimize difference between actual and target gradient norms
- Training rates: Slow-learning tasks get higher target gradients
- Overhead: ~3× training time for 2 tasks
- RUL limitations: Noisy training rates, high cost
| Aspect | Value |
|---|---|
| Learnable parameters | T weights |
| Key hyperparameter | α (typically 1.5) |
| Computational overhead | ~T× more backward passes |
| Balancing mechanism | Gradient norm + training rate |
| RUL suitability | Moderate (high cost, noisy dynamics) |
Looking Ahead: GradNorm is computationally expensive and sensitive to training noise. The next section examines Dynamic Weight Average (DWA)—a simpler approach that adjusts weights based on loss change rates without computing per-task gradients.
Having seen gradient-based balancing, we now examine loss-rate-based methods.