Chapter 10
Section 52 of 75

Learning Rate Scheduling

Training Pipeline

Introduction

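The original Transformer paper (Vaswani et al., 2017, "Attention Is All You Need") introduced a learning rate schedule with a warmup phase, and warmup has since proved important for training Transformers stably. This section covers that warmup schedule and several alternative scheduling strategies.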


Why Warmup Matters

The Problem with Large Learning Rates

๐Ÿ“text
1Without warmup:
2  Step 1: Random weights, high gradients
3          Large LR โ†’ weights change drastically
4          Model becomes unstable!
5
6  Step 2: Attention scores explode or vanish
7          Training diverges or stalls
8
9With warmup:
10  Steps 1-4000: Gradually increase LR from 0 to max
11                Model finds stable region of loss landscape
12
13  Steps 4000+: Decrease LR for fine-grained optimization
14               Model converges to good minimum

Visualization

๐Ÿ“text
1Learning Rate Schedule:
2
3LR
4โ”‚
5โ”‚        โ•ฑโ•ฒ
6โ”‚       โ•ฑ  โ•ฒ
7โ”‚      โ•ฑ    โ•ฒ
8โ”‚     โ•ฑ      โ•ฒ
9โ”‚    โ•ฑ        โ•ฒ__________
10โ”‚   โ•ฑ                    โ•ฒ___
11โ”‚  โ•ฑ                          โ•ฒ____
12โ”‚ โ•ฑ                                 โ•ฒ_____
13โ”‚โ•ฑโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ†’ Steps
14 0   warmup      peak     decay
15
16 โ”‚โ†โ”€โ”€โ”€โ”€โ”€โ†’โ”‚
17  warmup
18  (4000 steps)

Original Transformer Schedule

Formula

$$\text{lr} = d_{\text{model}}^{-0.5} \times \min\left(\text{step}^{-0.5},\; \text{step} \times \text{warmup\_steps}^{-1.5}\right)$$

This creates:

  • Linear warmup for step < warmup_steps
  • Inverse square root decay for step > warmup_steps
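As a worked check of the formula: the two arguments of the min are equal exactly when step = warmup_steps, so the schedule peaks there at d_model^-0.5 × warmup_steps^-0.5. After the peak, the LR halves every time the step count quadruples. The pure-Python sketch below evaluates the formula for d_model = 512 and warmup_steps = 4000; the helper name `noam_lr` is our own label, not a library function.

```python
import math


def noam_lr(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Original Transformer LR for a 1-indexed step (hypothetical helper name)."""
    step = max(step, 1)  # the formula is undefined at step 0 (0 ** -0.5)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


peak = noam_lr(4000)
print(f"peak LR at step 4000: {peak:.3e}")  # 6.988e-04

# Linear rise before step 4000, inverse-square-root decay after it
for s in (1, 1000, 4000, 16000, 100000):
    print(f"step {s:>6}: lr = {noam_lr(s):.3e}")
```

So the original base configuration peaks at roughly 7e-4 around step 4000 and has decayed to about 1.4e-4 by step 100,000. The `factor` argument of the implementation below rescales that whole curve without moving the peak step.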

Implementation

๐Ÿpython
1import torch
2import torch.optim as optim
3from torch.optim.lr_scheduler import LambdaLR
4import math
5from typing import Optional
6
7
8class TransformerScheduler:
9    """
10    Original Transformer learning rate scheduler.
11
12    lr = d_model^(-0.5) ร— min(step^(-0.5), step ร— warmup^(-1.5))
13
14    Args:
15        optimizer: Optimizer instance
16        d_model: Model dimension (used for scaling)
17        warmup_steps: Number of warmup steps
18        factor: Additional scaling factor (default: 1.0)
19
20    Example:
21        >>> scheduler = TransformerScheduler(optimizer, d_model=512, warmup_steps=4000)
22        >>> for step in range(100000):
23        ...     train_step()
24        ...     scheduler.step()
25    """
26
27    def __init__(
28        self,
29        optimizer: optim.Optimizer,
30        d_model: int,
31        warmup_steps: int = 4000,
32        factor: float = 1.0
33    ):
34        self.optimizer = optimizer
35        self.d_model = d_model
36        self.warmup_steps = warmup_steps
37        self.factor = factor
38
39        self._step = 0
40        self._rate = 0
41
42        # Initialize to first step rate
43        self.step()
44
45    def step(self):
46        """Update learning rate."""
47        self._step += 1
48        rate = self._compute_rate()
49        self._rate = rate
50
51        for param_group in self.optimizer.param_groups:
52            param_group['lr'] = rate
53
54    def _compute_rate(self) -> float:
55        """Compute learning rate for current step."""
56        step = self._step
57        warmup = self.warmup_steps
58
59        # Original formula
60        rate = self.factor * (
61            self.d_model ** (-0.5) *
62            min(step ** (-0.5), step * warmup ** (-1.5))
63        )
64
65        return rate
66
67    def get_lr(self) -> float:
68        """Get current learning rate."""
69        return self._rate

Alternative Schedulers

Linear Warmup + Cosine Decay

๐Ÿpython
1class WarmupCosineScheduler:
2    """
3    Linear warmup followed by cosine decay.
4
5    Often used in modern transformer training.
6
7    Args:
8        optimizer: Optimizer instance
9        warmup_steps: Number of warmup steps
10        total_steps: Total training steps
11        min_lr: Minimum learning rate at end
12        max_lr: Maximum learning rate after warmup
13    """
14
15    def __init__(
16        self,
17        optimizer: optim.Optimizer,
18        warmup_steps: int,
19        total_steps: int,
20        max_lr: float = 1e-4,
21        min_lr: float = 1e-6
22    ):
23        self.optimizer = optimizer
24        self.warmup_steps = warmup_steps
25        self.total_steps = total_steps
26        self.max_lr = max_lr
27        self.min_lr = min_lr
28
29        self._step = 0
30        self.step()
31
32    def step(self):
33        """Update learning rate."""
34        self._step += 1
35        lr = self._compute_lr()
36
37        for param_group in self.optimizer.param_groups:
38            param_group['lr'] = lr
39
40    def _compute_lr(self) -> float:
41        """Compute learning rate for current step."""
42        step = self._step
43
44        if step < self.warmup_steps:
45            # Linear warmup
46            return self.max_lr * step / self.warmup_steps
47        else:
48            # Cosine decay
49            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
50            progress = min(1.0, progress)
51            return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + math.cos(math.pi * progress))
52
53    def get_lr(self) -> float:
54        return self._compute_lr()
55
56
57class WarmupLinearScheduler:
58    """
59    Linear warmup followed by linear decay.
60
61    Simple and effective.
62    """
63
64    def __init__(
65        self,
66        optimizer: optim.Optimizer,
67        warmup_steps: int,
68        total_steps: int,
69        max_lr: float = 1e-4,
70        min_lr: float = 0.0
71    ):
72        self.optimizer = optimizer
73        self.warmup_steps = warmup_steps
74        self.total_steps = total_steps
75        self.max_lr = max_lr
76        self.min_lr = min_lr
77
78        self._step = 0
79        self.step()
80
81    def step(self):
82        self._step += 1
83        lr = self._compute_lr()
84
85        for param_group in self.optimizer.param_groups:
86            param_group['lr'] = lr
87
88    def _compute_lr(self) -> float:
89        step = self._step
90
91        if step < self.warmup_steps:
92            return self.max_lr * step / self.warmup_steps
93        else:
94            progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
95            progress = min(1.0, progress)
96            return self.max_lr - (self.max_lr - self.min_lr) * progress
97
98    def get_lr(self) -> float:
99        return self._compute_lr()
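To compare the two decay shapes without an optimizer, the sketch below restates the `_compute_lr` formulas of `WarmupCosineScheduler` and `WarmupLinearScheduler` as free functions (the names `cosine_lr` and `linear_lr` are ours) and uses the same `min_lr` for both. Both start at `max_lr` when warmup ends and finish at `min_lr`; they coincide halfway through the decay, but cosine stays higher during the first half and lower during the second.

```python
import math


def cosine_lr(step, warmup_steps, total_steps, max_lr=1e-4, min_lr=1e-6):
    """Linear warmup then cosine decay (mirrors WarmupCosineScheduler)."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))


def linear_lr(step, warmup_steps, total_steps, max_lr=1e-4, min_lr=1e-6):
    """Linear warmup then linear decay (mirrors WarmupLinearScheduler)."""
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return max_lr - (max_lr - min_lr) * progress


warmup, total = 1_000, 11_000  # decay phase is 10,000 steps long
for frac in (0.0, 0.25, 0.5, 0.75, 1.0):
    s = warmup + int(frac * (total - warmup))
    print(f"decay progress {frac:4.2f}: cosine={cosine_lr(s, warmup, total):.2e} "
          f"linear={linear_lr(s, warmup, total):.2e}")
```

At 25% of the decay the cosine schedule is still at about 85% of the peak while the linear one is at 75%; at 75% of the decay the order reverses.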

Warmup Best Practices

Guidelines

๐Ÿ“text
1WARMUP STEPS:
2โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
3
4Rule of thumb: warmup = 1-10% of total training steps
5
6For translation (original paper):
7  Total steps: ~100,000
8  Warmup: 4,000 (4%)
9
10For language models:
11  Total tokens: ~10B
12  Warmup: ~2000 steps or 1% of total
13
14For fine-tuning:
15  Warmup: ~5-10% of total steps
16  Often shorter than pre-training
17
18
19CHOOSING MAX LEARNING RATE:
20โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
21
22Base rule: Larger models need smaller LR
23
24GPT-style (d_model, LR):
25  256-dim:  ~3e-4
26  512-dim:  ~1e-4 (original Transformer)
27  1024-dim: ~5e-5
28  2048-dim: ~2e-5
29
30With Adam:
31  ฮฒ1 = 0.9, ฮฒ2 = 0.98 (translation)
32  ฮฒ1 = 0.9, ฮฒ2 = 0.999 (language modeling)
33
34
35SIGNS OF WRONG LR:
36โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
37
38LR too high:
39  - Loss spikes/explodes
40  - Training unstable
41  - NaN values
42
43LR too low:
44  - Very slow progress
45  - Never reaches good loss
46  - Stuck in local minimum
47
48Warmup too short:
49  - Early instability
50  - Attention patterns don't form properly
51
52Warmup too long:
53  - Slower convergence
54  - May not reach best performance
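The "LR too high" symptoms can be checked automatically during training. Below is a minimal sketch of a hypothetical `LossSpikeMonitor` (our own helper, not a PyTorch API) that flags non-finite losses and losses that jump well above an exponential moving average. It assumes non-negative losses such as cross-entropy, and the `spike_factor` and `ema_decay` defaults are illustrative assumptions, not tuned values.

```python
import math


class LossSpikeMonitor:
    """Hypothetical helper: flags non-finite losses and sudden loss spikes."""

    def __init__(self, spike_factor: float = 2.0, ema_decay: float = 0.99):
        self.spike_factor = spike_factor  # how far above the EMA counts as a spike
        self.ema_decay = ema_decay
        self.ema = None  # running average of recent (non-spike) losses

    def check(self, loss: float) -> str:
        """Return 'nonfinite', 'spike', or 'ok' for the latest loss value."""
        if not math.isfinite(loss):
            return "nonfinite"
        if self.ema is None:
            self.ema = loss
            return "ok"
        status = "spike" if loss > self.spike_factor * self.ema else "ok"
        # Only fold normal losses into the EMA so one outlier does not mask the next
        if status == "ok":
            self.ema = self.ema_decay * self.ema + (1 - self.ema_decay) * loss
        return status


monitor = LossSpikeMonitor()
for loss in [4.0, 3.9, 3.8, 9.5, float("nan")]:
    print(loss, monitor.check(loss))
```

In a real loop you would pass `loss.item()` after each step. Repeated "spike" results early in training suggest the peak LR is too high or warmup is too short.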

Using with PyTorch Schedulers

LambdaLR Integration

๐Ÿpython
1def get_transformer_scheduler(
2    optimizer: optim.Optimizer,
3    d_model: int,
4    warmup_steps: int,
5    factor: float = 1.0
6) -> LambdaLR:
7    """
8    Create Transformer scheduler using PyTorch's LambdaLR.
9
10    Compatible with PyTorch training loops and checkpointing.
11
12    Args:
13        optimizer: Optimizer instance
14        d_model: Model dimension
15        warmup_steps: Warmup steps
16        factor: Additional scaling
17
18    Returns:
19        LambdaLR scheduler
20    """
21    def lr_lambda(step: int) -> float:
22        # Avoid division by zero at step 0
23        step = max(step, 1)
24        return factor * (
25            d_model ** (-0.5) *
26            min(step ** (-0.5), step * warmup_steps ** (-1.5))
27        )
28
29    return LambdaLR(optimizer, lr_lambda)
30
31
32def get_cosine_scheduler(
33    optimizer: optim.Optimizer,
34    warmup_steps: int,
35    total_steps: int,
36    min_ratio: float = 0.01
37) -> LambdaLR:
38    """
39    Create cosine scheduler with warmup using LambdaLR.
40    """
41    def lr_lambda(step: int) -> float:
42        if step < warmup_steps:
43            return step / warmup_steps
44        else:
45            progress = (step - warmup_steps) / (total_steps - warmup_steps)
46            progress = min(1.0, progress)
47            return min_ratio + 0.5 * (1 - min_ratio) * (1 + math.cos(math.pi * progress))
48
49    return LambdaLR(optimizer, lr_lambda)
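`LambdaLR` sets each group's LR to its base LR (the `lr` the optimizer was constructed with) multiplied by `lr_lambda(step)`, which is why the absolute Transformer formula is paired with `lr=1.0`. The sketch below assumes PyTorch is installed and uses a single dummy parameter in place of a real model; it checks that relationship and then does a checkpoint round trip. The inline `noam` function repeats the `get_transformer_scheduler` formula so the snippet runs on its own, and `make` is a hypothetical convenience helper. Note the restore order: build the optimizer and scheduler first, then load the optimizer state, then the scheduler state.

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def noam(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    step = max(step, 1)  # LambdaLR calls the lambda with step 0 at construction
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


def make(lr: float = 1.0):
    """Hypothetical helper: one dummy parameter, Adam, and a LambdaLR schedule."""
    param = torch.nn.Parameter(torch.zeros(1))  # stand-in for model.parameters()
    opt = torch.optim.Adam([param], lr=lr, betas=(0.9, 0.98), eps=1e-9)
    return param, opt, LambdaLR(opt, noam)


param, opt, sched = make()
for _ in range(10):
    param.grad = torch.ones_like(param)  # fake gradient instead of loss.backward()
    opt.step()
    sched.step()

# With lr=1.0 the optimizer's LR is exactly noam(current step)
assert math.isclose(opt.param_groups[0]["lr"], noam(10))

# Save both states, then restore into freshly built objects
ckpt = {"optimizer": opt.state_dict(), "scheduler": sched.state_dict()}
param2, opt2, sched2 = make()
opt2.load_state_dict(ckpt["optimizer"])
sched2.load_state_dict(ckpt["scheduler"])
print(sched2.last_epoch, opt2.param_groups[0]["lr"])
```

Plain functions and lambdas passed to `LambdaLR` are not serialized into its `state_dict`, so the resumed run must construct the scheduler with the same function before loading the saved state.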

Example Usage

๐Ÿpython
1# Setup
2model = Transformer(...)
3optimizer = torch.optim.Adam(
4    model.parameters(),
5    lr=1.0,  # Will be overridden by scheduler
6    betas=(0.9, 0.98),
7    eps=1e-9
8)
9
10scheduler = get_transformer_scheduler(
11    optimizer,
12    d_model=512,
13    warmup_steps=4000
14)
15
16# Training loop
17for epoch in range(num_epochs):
18    for batch in train_loader:
19        optimizer.zero_grad()
20
21        logits = model(batch['source'], batch['target'])
22        loss = criterion(logits, batch['target'])
23
24        loss.backward()
25        optimizer.step()
26        scheduler.step()  # Update LR after each step!
27
28    # Log learning rate
29    current_lr = optimizer.param_groups[0]['lr']
30    print(f"Epoch {epoch}, LR: {current_lr:.6f}")

Gradient Clipping

Complementary to Scheduling

๐Ÿ“text
1While warmup prevents large LR early,
2gradient clipping prevents large gradients always.
3
4TYPES:
5โ”€โ”€โ”€โ”€โ”€โ”€
6
71. Clip by value:
8   grad = torch.clamp(grad, -max_value, max_value)
9
10   Pros: Simple
11   Cons: Changes gradient direction
12
132. Clip by norm (recommended):
14   total_norm = sqrt(sum(grad**2))
15   if total_norm > max_norm:
16       grad = grad * max_norm / total_norm
17
18   Pros: Preserves gradient direction
19   Cons: Slightly more compute
20
21
22USAGE:
23โ”€โ”€โ”€โ”€โ”€โ”€
24
25# After loss.backward(), before optimizer.step()
26torch.nn.utils.clip_grad_norm_(
27    model.parameters(),
28    max_norm=1.0
29)
30
31# Or clip by value
32torch.nn.utils.clip_grad_value_(
33    model.parameters(),
34    clip_value=1.0
35)
36
37
38TYPICAL VALUES:
39โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
40
41Translation: max_norm = 1.0 - 5.0
42Language modeling: max_norm = 1.0
43Fine-tuning: max_norm = 1.0
44
45Monitor gradient norms during training!
46If frequently clipping, LR may be too high.
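The direction claim can be checked by hand. The pure-Python sketch below applies the clipping arithmetic described above to a toy gradient vector [3, 4] (norm 5), treating all gradients as one flattened list, which is how `clip_grad_norm_` computes a single total norm across parameters. The helper names are ours, and PyTorch's implementation additionally adds a small epsilon to the denominator.

```python
import math


def clip_by_norm(grads, max_norm):
    """Scale all gradients by one factor if their global L2 norm exceeds max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        grads = [g * max_norm / total_norm for g in grads]
    return grads, total_norm  # returning the norm makes it easy to log


def clip_by_value(grads, max_value):
    """Clamp each gradient component independently to [-max_value, max_value]."""
    return [max(-max_value, min(max_value, g)) for g in grads]


def cosine(a, b):
    """Cosine similarity: 1.0 means the two vectors point the same way."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


grad = [3.0, 4.0]
by_norm, norm = clip_by_norm(grad, max_norm=1.0)
by_value = clip_by_value(grad, max_value=1.0)
print(norm, by_norm, cosine(grad, by_norm))  # direction preserved
print(by_value, cosine(grad, by_value))      # direction changed
```

Norm clipping shrinks [3, 4] to [0.6, 0.8] with a cosine similarity of 1.0, while value clipping yields [1, 1] with a cosine similarity of about 0.99: the "preserves direction" versus "changes direction" trade-off listed above.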

Summary

Learning Rate Scheduling Key Points

| Schedule | Formula | Use Case |
|---|---|---|
| Transformer | d⁻⁰·⁵ × min(s⁻⁰·⁵, s × w⁻¹·⁵) | Original paper |
| Cosine | Linear warmup + cosine decay | Modern training |
| Linear | Linear warmup + linear decay | Simple alternative |

(d = d_model, s = step, w = warmup_steps)

Best Practices

  • Use warmup (4000 steps in the original translation setup)
  • Combine with gradient clipping (max_norm=1.0 is a common default)
  • Monitor learning rate and loss closely
  • Save scheduler state in checkpoints

Exercises

Implementation

  • Implement a cyclical learning rate schedule with warmup.
  • Add support for a learning rate finder (LR range test).
  • Implement layer-wise learning rate decay.

Analysis

  • Compare convergence speed of different schedules.
  • Find the optimal warmup length for your dataset size.

In the next section, we'll implement the complete Training Loop with all components integrated: data loading, loss computation, optimization, and logging.
