Chapter 12

Learning Rate Warmup

Optimization Strategy

Learning Objectives

By the end of this section, you will:

  1. Understand why warmup prevents early training failure
  2. Compare linear and exponential warmup strategies
  3. Choose appropriate warmup duration for your model
  4. Implement warmup schedulers in PyTorch
  5. Diagnose warmup-related training issues
Why This Matters: Large learning rates at the start of training can cause gradient explosion, NaN losses, or convergence to poor local minima. Warmup starts with a tiny learning rate and gradually increases it, giving the optimizer time to estimate gradient statistics before taking large steps.

Why Warmup is Necessary

Several factors make early training particularly fragile.

The Cold Start Problem

At initialization, Adam's moment estimates are unreliable:

  • First moment (m): Initialized to 0, takes ~10 steps to stabilize
  • Second moment (v): Initialized to 0, takes ~100 steps to stabilize
  • Bias correction: Helps but doesn't fully compensate

Large learning rates with inaccurate moment estimates cause erratic updates.
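These timescales follow directly from the exponential moving averages. A quick pure-Python sketch (assuming the standard Adam defaults beta1 = 0.9, beta2 = 0.999; the helper name is just illustrative) shows how much of each moment estimate actually comes from observed gradients rather than from the zero initialization:

```python
# Sketch: fraction of Adam's moment estimates that reflects observed
# gradients after t steps (the remainder is leftover zero initialization).
# Assumes the standard defaults beta1 = 0.9, beta2 = 0.999.

def ema_mass(beta: float, t: int) -> float:
    """Weight an exponential moving average gives to real data after t steps."""
    return 1.0 - beta ** t

for t in (1, 10, 100):
    print(f"step {t:3d}: first moment {ema_mass(0.9, t):.3f}, "
          f"second moment {ema_mass(0.999, t):.3f}")
```

Bias correction divides by exactly this factor, which rescales the estimates but cannot remove the noise from averaging over so few effective samples, which is why the first ~100 steps are fragile.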

Gradient Magnitude at Initialization

At the start of training, the loss is far from its minimum, so gradients are typically large, and with randomly initialized weights their directions are poorly aligned with any good solution. Applying a full-size learning rate to these large, noisy gradients can push the weights into regions from which recovery is slow or impossible.

When Warmup is Most Important

| Scenario | Warmup Need | Reason |
|---|---|---|
| Large batch size | High | Gradient variance lower, can overshoot |
| Deep networks | High | Gradient flow unstable early |
| Transformers/Attention | Critical | Attention weights highly sensitive |
| Small batch size | Moderate | High variance provides implicit regularization |
| Transfer learning | Low | Weights already reasonable |

AMNL Uses Warmup

Our CNN-BiLSTM-Attention model includes attention layers, making warmup important. We use 5-10 epochs of warmup for stable training.


Warmup Strategies

There are several ways to increase the learning rate during warmup.

Linear Warmup

The most common approach—linearly increase from 0 to target:

$$\eta_t = \eta_{\text{target}} \cdot \frac{t}{T_{\text{warmup}}}, \quad t \leq T_{\text{warmup}}$$

Where:

  • $\eta_t$: Learning rate at step t
  • $\eta_{\text{target}}$: Target learning rate (e.g., 1e-3)
  • $T_{\text{warmup}}$: Warmup duration (steps or epochs)
```text
Linear Warmup Schedule:

LR
  η ─┤                    ────────────
    │                  ╱
    │                ╱
    │              ╱
    │            ╱
    │          ╱
    │        ╱
    │      ╱
    │    ╱
  0 ─┼──╱───────────────────────────────
    └──┬────────┬────────────────────
       0    T_warmup            Epochs

Linear: η(t) = η_target × t / T_warmup
```
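The linear schedule can be computed directly; this is a minimal sketch, with the function name chosen just for illustration:

```python
def linear_warmup_lr(step: int, target_lr: float, warmup_steps: int) -> float:
    """eta_t = eta_target * t / T_warmup during warmup, then eta_target."""
    if step >= warmup_steps:
        return target_lr
    return target_lr * step / warmup_steps

# Halfway through a 10-step warmup toward 1e-3:
print(linear_warmup_lr(5, 1e-3, 10))  # 0.0005
```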

Exponential Warmup

Start very small and grow exponentially:

$$\eta_t = \eta_{\text{target}} \cdot \left(\frac{\eta_{\text{start}}}{\eta_{\text{target}}}\right)^{1 - t/T_{\text{warmup}}}$$

Exponential warmup spends more time at low learning rates, which can be beneficial for very sensitive models.

Gradual Warmup

Start from a small non-zero value:

$$\eta_t = \eta_{\text{start}} + (\eta_{\text{target}} - \eta_{\text{start}}) \cdot \frac{t}{T_{\text{warmup}}}$$

With $\eta_{\text{start}} = \eta_{\text{target}} / 10$, this provides a gentler start than pure linear warmup.
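The two variants can be sketched side by side; both follow the formulas above, and the function names are illustrative:

```python
def exponential_warmup(t: float, T: float, target: float, start: float) -> float:
    """eta_t = eta_target * (eta_start / eta_target) ** (1 - t / T)"""
    return target * (start / target) ** (1.0 - t / T)

def gradual_warmup(t: float, T: float, target: float, start: float) -> float:
    """eta_t = eta_start + (eta_target - eta_start) * t / T"""
    return start + (target - start) * t / T

# Both start at eta_start and end at eta_target, but the exponential
# variant stays low for longer:
for t in (0, 5, 10):
    print(t, exponential_warmup(t, 10, 1e-3, 1e-5),
          gradual_warmup(t, 10, 1e-3, 1e-4))
```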

Comparison

| Strategy | Formula | Best For |
|---|---|---|
| Linear | η · t / T | Most cases (recommended) |
| Exponential | η · (η₀/η)^(1−t/T) | Very sensitive models |
| Gradual | η₀ + (η − η₀) · t / T | Balanced approach |
| Constant then switch | η₀ then η | Simple, less smooth |

Warmup Duration Selection

How long should warmup last?

Rules of Thumb

| Training Length | Warmup Duration | Warmup % |
|---|---|---|
| 50 epochs | 3-5 epochs | 6-10% |
| 100 epochs | 5-10 epochs | 5-10% |
| 200 epochs | 10-20 epochs | 5-10% |

Batch-Based vs. Epoch-Based

Warmup can be specified in steps (batches) or epochs:

| Method | Advantage | Disadvantage |
|---|---|---|
| Steps | Consistent across batch sizes | Must recalculate for different datasets |
| Epochs | Intuitive, dataset-independent | Varies with batch size |

Epoch-Based for Simplicity

We recommend epoch-based warmup for its simplicity. 5-10% of total epochs is a good starting point. For 100 epochs, use 5-10 epochs of warmup.
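The rule of thumb can be captured in a small helper; this is an illustrative sketch, and the 7.5% default is simply a midpoint within the 5-10% range, not a setting from our configuration:

```python
def suggest_warmup_epochs(total_epochs: int, frac: float = 0.075) -> int:
    """Pick a warmup duration of roughly 5-10% of total training epochs.

    frac defaults to 7.5%, the midpoint of the 5-10% rule of thumb.
    """
    return max(1, round(frac * total_epochs))

print(suggest_warmup_epochs(100))  # warmup epochs for a 100-epoch run
```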

AMNL Warmup Configuration

| Parameter | Value |
|---|---|
| Total epochs | 100 |
| Warmup epochs | 10 |
| Warmup percentage | 10% |
| Warmup strategy | Linear |
| Start LR | 1e-4 (10% of target) |
| Target LR | 1e-3 |

Implementation

Our research implementation uses a simple but effective warmup approach that integrates cleanly with the training loop.

AMNL Research Implementation

Learning Rate Warmup Factor

From `enhanced_train_nasa_cmapss_sota_v7.py`:

  • Function signature: Returns a multiplicative factor in [0.1, 1.0] that scales the base learning rate based on the current epoch.
  • V5 success: This warmup strategy was introduced in V5 of our training approach and significantly improved early training stability. We kept it in V7.
  • Key insight: We start at 10% of the target LR (not 0%) to ensure some learning happens from epoch 0. A pure zero start can waste early epochs. Example: epoch 0 → factor 0.1, epoch 5 → factor 0.55, epoch 10 → factor 1.0.
  • Warmup check: Warmup only applies during the first warmup_epochs; after that, the full learning rate is used.
  • Linear ramp: Linear interpolation from 0.1 to 1.0, starting at 0.1 at epoch 0 and reaching 1.0 at epoch warmup_epochs. Example: 0.1 + 0.9 × (5/10) = 0.1 + 0.45 = 0.55.
  • Post-warmup: Once warmup completes, the function returns 1.0 (full learning rate); the scheduler then controls LR decay.
```python
def get_lr_warmup_factor(epoch: int, warmup_epochs: int = 10) -> float:
    """
    Learning rate warmup (from V5 - it worked!)

    Linearly ramp up learning rate from 0.1x to 1.0x over warmup_epochs.
    """
    if epoch < warmup_epochs:
        # Linear warmup: 0.1 → 1.0 over warmup_epochs
        return 0.1 + 0.9 * (epoch / warmup_epochs)
    return 1.0
```

Integration with Training Loop

Warmup in Training Loop

From `enhanced_train_nasa_cmapss_sota_v7.py`:

  • Compute warmup factor: Get the multiplicative factor for the current epoch; during warmup this lies in [0.1, 1.0].
  • Warmup phase check: Only modify the learning rate during warmup epochs; afterwards, ReduceLROnPlateau controls it.
  • Apply warmup LR: Multiply the base learning rate by the warmup factor, overriding the optimizer's current LR.
  • Scheduler after warmup: ReduceLROnPlateau only starts stepping after warmup completes, preventing interference between warmup and plateau detection.
```python
# Training loop with warmup
warmup_epochs = 10
for epoch in range(epochs):
    # Apply learning rate warmup
    warmup_factor = get_lr_warmup_factor(epoch, warmup_epochs=warmup_epochs)
    if epoch < warmup_epochs:
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate * warmup_factor

    # Training phase
    model.train()
    for batch_idx, (sequences, targets) in enumerate(train_loader):
        # ... training code ...
        pass

    # ReduceLROnPlateau scheduler step (only after warmup)
    if epoch >= warmup_epochs:
        scheduler.step(rmse_last)
```

Warmup + ReduceLROnPlateau

The key insight is that ReduceLROnPlateau should only start monitoring after warmup completes. During warmup, learning rate changes are expected and should not trigger plateau detection.
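To make the interaction concrete, here is a minimal pure-Python stand-in for plateau detection (a simplified sketch for illustration, not PyTorch's ReduceLROnPlateau). Gating its step() behind `epoch >= warmup_epochs`, as in the loop above, keeps warmup's deliberate LR changes out of the plateau logic:

```python
class SimplePlateauScheduler:
    """Toy plateau scheduler: halve the LR when the monitored metric
    fails to improve for more than `patience` consecutive epochs.

    A simplified stand-in for ReduceLROnPlateau, for illustration only.
    """

    def __init__(self, lr: float, patience: int = 2, factor: float = 0.5):
        self.lr = lr
        self.patience = patience
        self.factor = factor
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, metric: float) -> float:
        if metric < self.best:
            # Metric improved: reset the patience counter
            self.best = metric
            self.bad_epochs = 0
        else:
            # No improvement: reduce LR once patience is exhausted
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr

# Gated usage inside the epoch loop:
#   if epoch >= warmup_epochs:
#       scheduler.step(val_rmse)
```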

Combined Schedule Visualization

```python
# Visualize the combined warmup + plateau schedule
import matplotlib.pyplot as plt

epochs = range(100)
lrs = []
base_lr = 1e-3
warmup_epochs = 10

# Simulate warmup phase
for e in epochs:
    if e < warmup_epochs:
        factor = 0.1 + 0.9 * (e / warmup_epochs)
        lrs.append(base_lr * factor)
    else:
        # After warmup, assume some plateau reductions
        if e < 50:
            lrs.append(base_lr)
        elif e < 70:
            lrs.append(base_lr * 0.5)  # First reduction
        else:
            lrs.append(base_lr * 0.25)  # Second reduction

plt.figure(figsize=(10, 5))
plt.plot(epochs, lrs, 'b-', linewidth=2)
plt.axvline(x=10, color='red', linestyle='--', alpha=0.5, label='Warmup ends')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Warmup + ReduceLROnPlateau Schedule')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
```

Summary

In this section, we covered learning rate warmup:

  1. Why needed: Adam's moments need time to stabilize
  2. Linear warmup: $\eta_t = \eta_{\text{target}} \cdot t / T_{\text{warmup}}$ (recommended)
  3. Duration: 5-10% of total epochs
  4. Combined scheduler: Warmup followed by ReduceLROnPlateau
  5. Epoch-level updates: Apply the warmup factor each epoch; step the plateau scheduler only after warmup completes
| Parameter | Value |
|---|---|
| Warmup strategy | Linear |
| Warmup epochs | 10 (of 100) |
| Start LR | 1e-4 (10% of target) |
| Target LR | 1e-3 |
| Post-warmup | ReduceLROnPlateau decay |
Looking Ahead: After warmup, we need a strategy for the main training phase. The next section covers cosine annealing with warm restarts—a schedule that can escape local minima through periodic learning rate increases.

With warmup understood, we explore cosine annealing schedules.