Introduction
The original Transformer paper introduced a specific learning rate schedule with warmup that is crucial for stable training. This section covers the warmup schedule and other scheduling strategies.
Why Warmup Matters
The Problem with Large Learning Rates
```text
Without warmup:
  Step 1: Random weights, large gradients
          Large LR → weights change drastically
          Model becomes unstable!

  Step 2: Attention scores explode or vanish
          Training diverges or stalls

With warmup:
  Steps 1-4000: Gradually increase LR from 0 to max
                Model settles into a stable region of the loss landscape

  Steps 4000+:  Decrease LR for fine-grained optimization
                Model converges to a good minimum
```
Visualization
```text
Learning Rate Schedule:

LR
│
│          ╱╲
│         ╱  ╲
│        ╱    ╲
│       ╱      ╲
│      ╱        ╲___
│     ╱             ╲____
│    ╱                   ╲_____
│   ╱                          ╲__________
│  ╱
│ ╱
│╱
└──────────────────────────────────────────── Steps
0          peak          decay
└──────────┘
   warmup
(4000 steps)
```
Original Transformer Schedule
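Formula
$$
\text{lr}(\text{step}) = d_{\text{model}}^{-0.5} \cdot \min\left(\text{step}^{-0.5},\ \text{step} \cdot \text{warmup\_steps}^{-1.5}\right)
$$
This creates:
- Linear warmup for step < warmup_steps, where the second term is the smaller one
- Inverse square root decay for step > warmup_steps, where the first term is the smaller one

The two terms are equal at step = warmup_steps, so the learning rate peaks there at $d_{\text{model}}^{-0.5} \cdot \text{warmup\_steps}^{-0.5}$.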
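Plugging numbers into lr = d_model^-0.5 · min(step^-0.5, step · warmup_steps^-1.5) makes the shape of the schedule concrete. The sketch below assumes the original paper's settings (d_model = 512, warmup_steps = 4000); `transformer_lr` and `factor_for_peak` are illustrative helper names, not part of any library. Because the peak equals factor · (d_model · warmup_steps)^-0.5, the same relation can be inverted to pick `factor` for a desired peak LR.

```python
import math


def transformer_lr(step: int, d_model: int = 512, warmup_steps: int = 4000,
                   factor: float = 1.0) -> float:
    """factor * d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)."""
    step = max(step, 1)  # step^-0.5 is undefined at step 0
    return factor * d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


def factor_for_peak(target_peak: float, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Invert peak = factor * (d_model * warmup_steps)^-0.5 for a chosen peak LR."""
    return target_peak * math.sqrt(d_model * warmup_steps)


peak = transformer_lr(4000)  # the two terms meet at step == warmup_steps
for step in (1000, 4000, 16000, 100000):
    lr = transformer_lr(step)
    print(f"step {step:>6}: lr = {lr:.3e}  ({lr / peak:.2f} x peak)")

# Rescale the whole curve so that it peaks at 1e-3 instead
f = factor_for_peak(1e-3)
print(f"factor = {f:.3f}, new peak = {transformer_lr(4000, factor=f):.1e}")
```

With these settings the peak is roughly 7e-4. During warmup the LR grows linearly (step 1,000 gives a quarter of the peak); afterwards it falls as 1/sqrt(step), so four times as many steps halves it, and at step 100,000 it is still 20% of the peak.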
Implementation
```python
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import math


class TransformerScheduler:
    """
    Original Transformer learning rate scheduler.

    lr = factor × d_model^(-0.5) × min(step^(-0.5), step × warmup^(-1.5))

    Args:
        optimizer: Optimizer instance
        d_model: Model dimension (used for scaling)
        warmup_steps: Number of warmup steps
        factor: Additional scaling factor (default: 1.0)

    Example:
        >>> scheduler = TransformerScheduler(optimizer, d_model=512, warmup_steps=4000)
        >>> for step in range(100000):
        ...     train_step()
        ...     scheduler.step()
    """

    def __init__(
        self,
        optimizer: optim.Optimizer,
        d_model: int,
        warmup_steps: int = 4000,
        factor: float = 1.0
    ):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.factor = factor

        self._step = 0
        self._rate = 0.0

        # Apply the step-1 rate right away so the first optimizer.step() already uses it
        self.step()

    def step(self):
        """Advance the step counter and write the new LR into every param group."""
        self._step += 1
        rate = self._compute_rate()
        self._rate = rate

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = rate

    def _compute_rate(self) -> float:
        """Compute learning rate for current step (step >= 1, so step^-0.5 is defined)."""
        step = self._step
        warmup = self.warmup_steps

        # Original formula: the second term is smaller during warmup, the first afterwards
        rate = self.factor * (
            self.d_model ** (-0.5) *
            min(step ** (-0.5), step * warmup ** (-1.5))
        )

        return rate

    def get_lr(self) -> float:
        """Get current learning rate."""
        return self._rate
```
Alternative Schedulers
Linear Warmup + Cosine Decay
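Written out, with s the current step, w = warmup_steps, T = total_steps and p the fraction of the decay phase completed, the warmup-plus-cosine schedule implemented in the class below is:

$$
\text{lr}(s) =
\begin{cases}
\text{max\_lr} \cdot \dfrac{s}{w}, & s < w \\[6pt]
\text{min\_lr} + \tfrac{1}{2}\,(\text{max\_lr} - \text{min\_lr})\,\bigl(1 + \cos(\pi p)\bigr), & s \ge w
\end{cases}
\qquad
p = \min\!\left(1, \frac{s - w}{T - w}\right)
$$

At s = w the cosine factor is 1, so the LR is exactly max_lr; once s reaches T it stays at min_lr. The linear variant replaces the cosine factor with a straight line: lr(s) = max_lr − (max_lr − min_lr) · p.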
```python
import math
import torch.optim as optim


class WarmupCosineScheduler:
    """
    Linear warmup followed by cosine decay.

    Often used in modern transformer training.

    Args:
        optimizer: Optimizer instance
        warmup_steps: Number of warmup steps
        total_steps: Total training steps
        max_lr: Maximum learning rate after warmup
        min_lr: Minimum learning rate at end
    """

    def __init__(
        self,
        optimizer: optim.Optimizer,
        warmup_steps: int,
        total_steps: int,
        max_lr: float = 1e-4,
        min_lr: float = 1e-6
    ):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.max_lr = max_lr
        self.min_lr = min_lr

        self._step = 0
        self.step()

    def step(self):
        """Update learning rate."""
        self._step += 1
        lr = self._compute_lr()

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def _compute_lr(self) -> float:
        """Compute learning rate for current step."""
        step = self._step

        if step < self.warmup_steps:
            # Linear warmup towards max_lr
            return self.max_lr * step / self.warmup_steps

        # Cosine decay from max_lr to min_lr; max(1, ...) avoids division by zero
        # when total_steps == warmup_steps, and progress is capped at 1.0
        progress = (step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)
        progress = min(1.0, progress)
        return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1 + math.cos(math.pi * progress))

    def get_lr(self) -> float:
        return self._compute_lr()


class WarmupLinearScheduler:
    """
    Linear warmup followed by linear decay.

    Simple and effective.
    """

    def __init__(
        self,
        optimizer: optim.Optimizer,
        warmup_steps: int,
        total_steps: int,
        max_lr: float = 1e-4,
        min_lr: float = 0.0
    ):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.max_lr = max_lr
        self.min_lr = min_lr

        self._step = 0
        self.step()

    def step(self):
        """Update learning rate."""
        self._step += 1
        lr = self._compute_lr()

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def _compute_lr(self) -> float:
        step = self._step

        if step < self.warmup_steps:
            # Linear warmup towards max_lr
            return self.max_lr * step / self.warmup_steps

        # Straight line from max_lr down to min_lr, held at min_lr after total_steps
        progress = (step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)
        progress = min(1.0, progress)
        return self.max_lr - (self.max_lr - self.min_lr) * progress

    def get_lr(self) -> float:
        return self._compute_lr()
```
Warmup Best Practices
Guidelines
```text
WARMUP STEPS:
─────────────

Rule of thumb: warmup = 1-10% of total training steps

For translation (original paper):
  Total steps: ~100,000
  Warmup: 4,000 (4%)

For language models:
  Total tokens: ~10B
  Warmup: ~2,000 steps or 1% of total

For fine-tuning:
  Warmup: ~5-10% of total steps
  Usually fewer steps in absolute terms than pre-training


CHOOSING MAX LEARNING RATE:
───────────────────────────

Base rule: larger models generally need a smaller peak LR

Original Transformer (d_model = 512, warmup = 4,000):
  peak LR = 512^-0.5 × 4000^-0.5 ≈ 7e-4

GPT-style (d_model, peak LR, as used in the GPT-3 paper):
   768-dim: ~6e-4
  1024-dim: ~3e-4
  2048-dim: ~2e-4

With Adam:
  β1 = 0.9, β2 = 0.98, eps = 1e-9 (original Transformer, translation)
  β1 = 0.9, β2 = 0.999 (PyTorch default; large LLM runs often use β2 = 0.95)


SIGNS OF WRONG LR:
──────────────────

LR too high:
  - Loss spikes or explodes
  - Training is unstable
  - NaN values

LR too low:
  - Very slow progress
  - Never reaches a good loss
  - Plateaus early at a poor loss

Warmup too short:
  - Early instability
  - Attention patterns don't form properly

Warmup too long:
  - Slower convergence
  - May not reach best performance
```
Using with PyTorch Schedulers
LambdaLR Integration
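Before wiring the schedules into PyTorch, it helps to pin down what `LambdaLR` does with the lambda: per the PyTorch documentation, it sets each parameter group's LR to that group's *initial* lr multiplied by `lr_lambda(step)`, and it already evaluates `lr_lambda(0)` when it is constructed. The pure-Python sketch below mimics that rule; `emulate_lambda_lr`, `transformer_lambda` and `cosine_lambda` are hypothetical helpers, not PyTorch APIs. It shows why a lambda that returns an absolute LR needs an optimizer created with `lr=1.0`, while a lambda that returns a ratio needs the optimizer's lr set to the peak LR.

```python
import math
from typing import Callable, List


def emulate_lambda_lr(base_lr: float, lr_lambda: Callable[[int], float],
                      steps: List[int]) -> List[float]:
    """Mimic LambdaLR's rule: effective LR = initial lr * lr_lambda(step)."""
    return [base_lr * lr_lambda(step) for step in steps]


def transformer_lambda(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    # Returns an absolute LR, so the optimizer should be created with lr=1.0
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


def cosine_lambda(step: int, warmup_steps: int = 100, total_steps: int = 1000,
                  min_ratio: float = 0.01) -> float:
    # Returns a multiplier in [0, 1], so the optimizer's lr should be the peak LR
    if step < warmup_steps:
        return step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_ratio + 0.5 * (1 - min_ratio) * (1 + math.cos(math.pi * progress))


# Absolute-LR lambda with lr=1.0: the effective LR is the formula's value
print(emulate_lambda_lr(1.0, transformer_lambda, [0, 4000]))
# Same lambda with lr=3e-4: every LR is silently scaled down by a factor of 3e-4
print(emulate_lambda_lr(3e-4, transformer_lambda, [0, 4000]))
# Ratio lambda with lr=3e-4 (the intended peak): 0, then ~3e-4 after warmup, ~3e-6 at the end
print(emulate_lambda_lr(3e-4, cosine_lambda, [0, 100, 1000]))
```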
```python
def get_transformer_scheduler(
    optimizer: optim.Optimizer,
    d_model: int,
    warmup_steps: int,
    factor: float = 1.0
) -> LambdaLR:
    """
    Create Transformer scheduler using PyTorch's LambdaLR.

    LambdaLR sets each group's LR to (initial lr) × lr_lambda(step), so create
    the optimizer with lr=1.0 to make lr_lambda's value the actual LR.
    Works with standard PyTorch training loops. For checkpointing, the
    scheduler's state_dict() stores the step counter but not plain functions
    like lr_lambda, so rebuild the scheduler with the same arguments before
    calling load_state_dict().

    Args:
        optimizer: Optimizer instance
        d_model: Model dimension
        warmup_steps: Warmup steps
        factor: Additional scaling

    Returns:
        LambdaLR scheduler
    """
    def lr_lambda(step: int) -> float:
        # LambdaLR calls this with step=0 on construction; step^-0.5 is undefined there
        step = max(step, 1)
        return factor * (
            d_model ** (-0.5) *
            min(step ** (-0.5), step * warmup_steps ** (-1.5))
        )

    return LambdaLR(optimizer, lr_lambda)


def get_cosine_scheduler(
    optimizer: optim.Optimizer,
    warmup_steps: int,
    total_steps: int,
    min_ratio: float = 0.01
) -> LambdaLR:
    """
    Create cosine scheduler with warmup using LambdaLR.

    lr_lambda returns a multiplier in [0, 1], so set the optimizer's lr to the
    desired peak LR. The LR decays to min_ratio × peak by total_steps.
    """
    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            # Linear warmup; step 0 gives LR 0 for the very first optimizer step
            return step / warmup_steps
        # max(1, ...) avoids division by zero when total_steps == warmup_steps
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        progress = min(1.0, progress)
        return min_ratio + 0.5 * (1 - min_ratio) * (1 + math.cos(math.pi * progress))

    return LambdaLR(optimizer, lr_lambda)
```
Example Usage
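```python
# Setup: Transformer, train_loader, criterion (e.g. nn.CrossEntropyLoss) and
# num_epochs are assumed to be defined elsewhere
model = Transformer(...)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1.0,  # LambdaLR multiplies this by lr_lambda(step), so 1.0 makes lr_lambda the actual LR
    betas=(0.9, 0.98),
    eps=1e-9
)

scheduler = get_transformer_scheduler(
    optimizer,
    d_model=512,
    warmup_steps=4000
)

# Training loop
for epoch in range(num_epochs):
    for batch in train_loader:
        optimizer.zero_grad()

        # Teacher forcing: the decoder sees target[:, :-1] and predicts target[:, 1:]
        # (assumes the model does not shift targets itself and returns [batch, tgt_len, vocab] logits)
        tgt_in, tgt_out = batch['target'][:, :-1], batch['target'][:, 1:]
        logits = model(batch['source'], tgt_in)
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()  # Update the LR after every optimizer step, not once per epoch

    # Log the learning rate at the end of each epoch
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch}, LR: {current_lr:.6f}")
```
Gradient Clipping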
Complementary to Scheduling
```text
While warmup keeps the LR small early in training,
gradient clipping caps oversized gradients at every step.

TYPES:
──────

1. Clip by value:
   grad = torch.clamp(grad, -max_value, max_value)   # each element independently

   Pros: Simple
   Cons: Changes gradient direction

2. Clip by norm (recommended):
   total_norm = sqrt(sum((g ** 2).sum() for g in grads))   # one norm across all parameters
   if total_norm > max_norm:
       grads = [g * max_norm / total_norm for g in grads]

   Pros: Preserves gradient direction
   Cons: Slightly more compute


USAGE:
──────

# After loss.backward(), before optimizer.step()
total_norm = torch.nn.utils.clip_grad_norm_(
    model.parameters(),
    max_norm=1.0
)  # returns the total norm measured before clipping

# Or clip by value
torch.nn.utils.clip_grad_value_(
    model.parameters(),
    clip_value=1.0
)


TYPICAL VALUES:
───────────────

Translation:        max_norm = 1.0 - 5.0
Language modeling:  max_norm = 1.0
Fine-tuning:        max_norm = 1.0

Monitor gradient norms during training (clip_grad_norm_ returns them).
If clipping triggers on most steps, the LR may be too high.
```
Summary
Learning Rate Scheduling Key Points
| Schedule | Formula | Use Case |
|---|---|---|
| Transformer | d_model⁻⁰·⁵ × min(step⁻⁰·⁵, step × warmup⁻¹·⁵) | Original paper |
| Cosine | Linear warmup + cosine decay | Modern training |
| Linear | Linear warmup + linear decay | Simple alternative |
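To see how the three rows of the table differ in practice, the sketch below evaluates each schedule at a few steps. It assumes the original paper's d_model = 512 and warmup_steps = 4000, a hypothetical 100,000-step run, min_lr = 0, and gives the cosine and linear schedules the same peak as the Transformer schedule so that only the shape differs; all names are illustrative.

```python
import math

# Assumed settings: the paper's d_model and warmup, plus a hypothetical run length
D_MODEL, WARMUP, TOTAL = 512, 4000, 100_000


def transformer_lr(step: int) -> float:
    """Linear warmup, then inverse square-root decay (original paper)."""
    step = max(step, 1)
    return D_MODEL ** -0.5 * min(step ** -0.5, step * WARMUP ** -1.5)


# Same peak for all three schedules, decaying to zero for cosine and linear
PEAK, MIN_LR = transformer_lr(WARMUP), 0.0


def decay_progress(step: int) -> float:
    """Fraction of the post-warmup phase completed, capped at 1.0."""
    return min(1.0, (step - WARMUP) / (TOTAL - WARMUP))


def cosine_lr(step: int) -> float:
    if step < WARMUP:
        return PEAK * step / WARMUP
    return MIN_LR + 0.5 * (PEAK - MIN_LR) * (1 + math.cos(math.pi * decay_progress(step)))


def linear_lr(step: int) -> float:
    if step < WARMUP:
        return PEAK * step / WARMUP
    return PEAK - (PEAK - MIN_LR) * decay_progress(step)


print(f"{'step':>7}  {'transformer':>11}  {'cosine':>9}  {'linear':>9}")
for step in (1000, 4000, 28_000, 52_000, 76_000, 100_000):
    print(f"{step:>7}  {transformer_lr(step):>11.2e}  {cosine_lr(step):>9.2e}  {linear_lr(step):>9.2e}")
```

The inverse square-root schedule never reaches zero (at step 100,000 it is still 20% of its peak), while cosine and linear both end at min_lr. Cosine keeps the LR above the linear schedule during the first half of the decay and below it during the second half; the two meet at the midpoint.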
Best Practices
- Always use warmup (4000 steps for translation)
- Combine with gradient clipping (max_norm=1.0)
- Monitor learning rate and loss closely
- Save scheduler state in checkpoints
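For the hand-written scheduler classes in this section, saving scheduler state means persisting at least the step counter, since every LR is a function of it. The sketch below is a hypothetical variant of `TransformerScheduler` with `state_dict()` / `load_state_dict()` methods, driven by a stub object that only has `param_groups` (the only optimizer attribute these schedulers touch). It checks that a run interrupted after 2,500 steps and resumed from the saved state ends at the same LR as an uninterrupted run. In real training you would store this dict next to `model.state_dict()` and `optimizer.state_dict()` in the checkpoint passed to `torch.save`.

```python
class StubOptimizer:
    """Hypothetical stand-in: the schedulers in this section only touch `param_groups`."""

    def __init__(self):
        self.param_groups = [{"lr": 0.0}]


class ResumableTransformerScheduler:
    """Hypothetical TransformerScheduler variant with state_dict()/load_state_dict()."""

    def __init__(self, optimizer, d_model=512, warmup_steps=4000, factor=1.0):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.factor = factor
        self._step = 0
        self.step()  # same convention as TransformerScheduler: start at step 1

    def _rate(self) -> float:
        s = self._step
        return self.factor * self.d_model ** -0.5 * min(s ** -0.5, s * self.warmup_steps ** -1.5)

    def _apply(self) -> None:
        for group in self.optimizer.param_groups:
            group["lr"] = self._rate()

    def step(self) -> None:
        self._step += 1
        self._apply()

    def state_dict(self) -> dict:
        # The counter plus the hyperparameters fully determine the LR
        return {"step": self._step, "d_model": self.d_model,
                "warmup_steps": self.warmup_steps, "factor": self.factor}

    def load_state_dict(self, state: dict) -> None:
        self._step = state["step"]
        self.d_model = state["d_model"]
        self.warmup_steps = state["warmup_steps"]
        self.factor = state["factor"]
        self._apply()  # restore the LR immediately, before the next optimizer step


# Reference run: 5,000 uninterrupted scheduler steps
opt_a = StubOptimizer()
sched_a = ResumableTransformerScheduler(opt_a)
for _ in range(5000):
    sched_a.step()

# Interrupted run: save after 2,500 steps, rebuild from scratch, reload, continue
opt_b = StubOptimizer()
sched_b = ResumableTransformerScheduler(opt_b)
for _ in range(2500):
    sched_b.step()
saved = sched_b.state_dict()

opt_c = StubOptimizer()
sched_c = ResumableTransformerScheduler(opt_c)  # a fresh scheduler starts back at step 1
sched_c.load_state_dict(saved)
for _ in range(2500):
    sched_c.step()

print(opt_a.param_groups[0]["lr"] == opt_c.param_groups[0]["lr"])
```

Without the reload, `sched_c` would restart its warmup from step 1 and briefly train with a much smaller LR than the interrupted run had reached.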
Exercises
Implementation
- Implement cyclical learning rate schedule with warmup.
- Add support for learning rate finder (range test).
- Implement layer-wise learning rate decay.
Analysis
- Compare convergence speed of different schedules.
- Find optimal warmup for your dataset size.
In the next section, we'll implement the complete Training Loop with all components integrated: data loading, loss computation, optimization, and logging.