Chapter 9

SGD and Momentum

Optimizers

Learning Objectives

By the end of this section, you will be able to:

  1. Explain the difference between batch, mini-batch, and stochastic gradient descent and why mini-batch SGD is the standard in practice
  2. Identify the ravine problem where vanilla SGD oscillates in high-curvature directions and crawls in low-curvature directions
  3. Derive the momentum update rule and explain how exponential averaging of gradients accelerates convergence
  4. Implement SGD with momentum from scratch in NumPy and compare it against vanilla SGD
  5. Use PyTorch's SGD optimizer with the momentum parameter to train a real network
  6. Describe Nesterov accelerated gradient and when it helps

Where We Left Off

In Chapter 8, we derived backpropagation and used it to compute gradients for all 31 parameters of our 2×2 diagonal flip network. Then we applied the simplest possible update rule:

$$w_{\text{new}} = w_{\text{old}} - \eta \cdot \frac{\partial L}{\partial w}$$

After one step with $\eta = 0.1$, the loss dropped from 0.896 to about 0.63, a 30% improvement. That felt good. But here's the question that Chapter 8 left unanswered: is this the best we can do?

The answer is no. The simple update rule we used — called vanilla gradient descent — has fundamental problems that get worse as networks get larger. This chapter introduces optimizers: smarter update rules that converge faster, more reliably, and with less tuning.

What changes in this chapter: In Chapter 8, we focused on computing gradients (the hard part). Now we take those gradients as given and focus on using them more intelligently. The gradient tells us which direction is downhill — the optimizer decides how far and how fast to step.

Stochastic Gradient Descent

Before improving the update rule, we need to address a practical issue: how many training examples do we use per gradient step?

Three Flavors of Gradient Descent

| Method | Examples per step | Gradient quality | Speed |
|---|---|---|---|
| Batch GD | All N | Perfect (true gradient) | Slow: one step per epoch |
| Stochastic GD (SGD) | 1 | Very noisy | Fast: N steps per epoch |
| Mini-batch SGD | B (e.g. 32, 64) | Moderate noise | Best tradeoff |

Batch gradient descent computes the gradient using all training examples, then takes one step. The gradient is exact but you only get one update per pass through the data. For our 16-image dataset, that's one step per epoch.

Stochastic gradient descent (SGD) computes the gradient from a single random example and updates immediately. Each gradient is noisy — it may point in a slightly wrong direction — but you get 16 updates per epoch instead of 1. In practice, the noise can actually help escape local minima.

Mini-batch SGD splits the difference: use $B$ examples per step (typically 32–512). The gradient is smoother than single-sample SGD but you still get multiple updates per epoch. This is what nearly all modern training uses.

In practice, "SGD" almost always means mini-batch SGD. When deep learning papers say "we trained with SGD," they mean mini-batch SGD with a batch size of 32–512, not single-sample updates. The term is used loosely but the meaning is consistent.
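As a rough sketch of how the three flavors differ only in batch size, the snippet below counts gradient updates per epoch for the chapter's 16-example dataset. The helper name `batches` and the mini-batch size of 4 are assumptions chosen for illustration; real training code would usually also shuffle the indices each epoch.

```python
def batches(n_examples, batch_size):
    """Yield one list of example indices per gradient step."""
    for start in range(0, n_examples, batch_size):
        yield list(range(start, min(start + batch_size, n_examples)))

N = 16  # size of the chapter's dataset

for name, batch_size in [("batch GD", N), ("single-sample SGD", 1), ("mini-batch SGD", 4)]:
    steps = list(batches(N, batch_size))
    print(f"{name:18s} batch_size={batch_size:2d} -> {len(steps)} updates per epoch")
```

The update rule itself is the same in all three cases; only the set of examples that feeds each gradient changes.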

The Noise Trade-Off

Mini-batch gradients are noisy estimates of the true gradient. Consider our 16-image dataset. The true gradient (averaged over all 16 images) points directly toward the minimum. But the gradient from image #3 alone might point slightly off to the side, because image #3 has different features than the average.

This noise has consequences. Instead of walking smoothly downhill, SGD takes a random walk with a downhill drift — zig-zagging toward the minimum rather than marching straight there. The drift gets you there eventually, but the zig-zags waste steps.

Noise is a feature, not a bug. The randomness helps SGD escape shallow local minima and saddle points that batch gradient descent would get stuck in. Research shows that the noise from mini-batches acts as implicit regularization, improving generalization. The goal isn't to eliminate noise — it's to keep the useful noise while reducing the harmful oscillations.
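To make "noisy estimate" concrete, here is a minimal sketch on a one-parameter model $y = w \cdot x$ with squared error. The data values are made up for illustration. It shows that individual examples disagree about the gradient, while their average matches the gradient of the mean loss (checked here with a finite difference).

```python
xs = [0.5, 1.0, 1.5, 2.0]   # hypothetical inputs
ys = [1.1, 1.9, 3.2, 3.9]   # hypothetical targets
w = 1.0                     # current weight

def example_loss(w, x, y):
    return (w * x - y) ** 2

def example_grad(w, x, y):
    # d/dw of (w*x - y)^2
    return 2.0 * (w * x - y) * x

def mean_loss(w):
    return sum(example_loss(w, x, y) for x, y in zip(xs, ys)) / len(xs)

# Each single example gives a different gradient estimate.
per_example = [example_grad(w, x, y) for x, y in zip(xs, ys)]
average_of_examples = sum(per_example) / len(per_example)

# Gradient of the mean loss, estimated by a central finite difference.
eps = 1e-6
full_gradient = (mean_loss(w + eps) - mean_loss(w - eps)) / (2 * eps)

print("per-example gradients:", per_example)
print("average of examples:  ", average_of_examples)
print("full-batch gradient:  ", full_gradient)
```

A mini-batch sits between the two extremes: averaging a few per-example gradients gives an estimate that scatters less than any single one.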

The Ravine Problem

The deepest problem with vanilla SGD isn't noise — it's geometry. Real loss surfaces have different curvatures in different directions. Imagine a long, narrow valley (a ravine):

  • Along the valley floor (the long axis), the surface slopes gently. The gradient is small. Progress is slow.
  • Across the valley walls (the short axis), the surface is steep. The gradient is large. Updates overshoot and bounce back and forth.

Mathematically, consider the loss function $L(w_1, w_2) = w_1^2 + 10 \cdot w_2^2$. The curvature in the $w_2$ direction (second derivative = 20) is ten times larger than in the $w_1$ direction (second derivative = 2). The contour lines form elongated ellipses: a ravine.
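Written out, the gradient and second derivatives of this loss are as follows. The symbol $\kappa$ for the curvature ratio is notation introduced here, not used elsewhere in the chapter.

```latex
\nabla L(w_1, w_2) = \begin{pmatrix} 2 w_1 \\ 20 w_2 \end{pmatrix},
\qquad
\nabla^2 L = \begin{pmatrix} 2 & 0 \\ 0 & 20 \end{pmatrix},
\qquad
\kappa = \frac{20}{2} = 10
```

At a point such as $(5, 2)$ the gradient is $(10, 40)$: it points mostly across the valley, even though the minimum lies mostly along it.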

The Learning Rate Dilemma

This creates an impossible dilemma for vanilla SGD:

| Learning rate | w₁ direction (gentle) | w₂ direction (steep) |
|---|---|---|
| Large (0.04) | Good progress | Oscillates wildly or diverges |
| Small (0.005) | Barely moves | Converges but wastes the opportunity |
| Medium (0.015) | Slow progress | Mild oscillation |

There is no single learning rate that works well for both directions simultaneously. You are forced to choose a conservative (small) learning rate to prevent the steep direction from diverging, which makes the gentle direction painfully slow.
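For a pure quadratic the dilemma can be checked by hand: plain gradient descent multiplies each coordinate by $1 - \eta \cdot (\text{second derivative})$ every step. The sketch below prints that factor for the example loss $w_1^2 + 10 w_2^2$. The rates 0.06 and 0.11 are additions chosen for illustration. On this exact surface and without gradient noise, the steep direction only starts flipping sign above $\eta = 0.05$ and only diverges above $\eta = 0.1$, so the table's labels are best read qualitatively, as describing steeper or noisier surfaces.

```python
SECOND_DERIVATIVE = {"w1": 2.0, "w2": 20.0}  # for L = w1^2 + 10*w2^2

def step_factor(lr, second_derivative):
    """Per-step multiplier of one coordinate under plain gradient descent."""
    return 1.0 - lr * second_derivative

def describe(factor):
    if abs(factor) > 1.0:
        return "diverges"
    if factor < 0.0:
        return "flips sign each step"
    return "shrinks monotonically"

for lr in [0.005, 0.015, 0.04, 0.06, 0.11]:
    parts = []
    for name, curvature in SECOND_DERIVATIVE.items():
        factor = step_factor(lr, curvature)
        parts.append(f"{name}: x{factor:+.2f} ({describe(factor)})")
    print(f"lr={lr:<5}  " + "   ".join(parts))
```

Whatever rate keeps the `w2` factor safely inside $(-1, 1)$ leaves the `w1` factor close to 1, which is the slow crawl along the valley floor.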

The ravine problem in one sentence: vanilla SGD can only take one-size-fits-all steps, but different directions need different step sizes. This mismatch is the #1 reason vanilla SGD is slow on real loss surfaces, where curvature ratios of 100:1 or 1000:1 are common.

This is where momentum comes in. It solves both problems — noise and ravines — with a single, elegant idea.


Momentum: The Physics Intuition

Imagine a heavy ball rolling down our ravine. Unlike vanilla SGD (which teleports to a new position each step), the ball has inertia. It keeps rolling in the direction it was already moving. This gives it two superpowers:

  1. Acceleration along the valley floor: the gradient consistently points along the valley. Each step adds more velocity in the same direction. The ball speeds up, taking ever-larger steps — exactly what we want.
  2. Dampening across the walls: the gradient alternates direction (left wall, right wall, left wall...). The velocity in this direction oscillates too, but each oscillation partially cancels the previous one. The ball averages out the bouncing and rolls smoothly along the floor.

The same logic applies to noisy gradients. If noise pushes the ball left on step 5 and right on step 6, the velocity absorbs both pushes and barely changes direction. The signal (consistent downhill) accumulates while the noise (random zig-zags) cancels.

The key insight: vanilla SGD acts on each gradient in isolation and forgets it after the step. Momentum says: "If the last 10 gradients all pointed left, I should trust that direction and take a bigger step." This is an exponential moving average of gradients: a simple running average that weights recent gradients more than old ones.

The Mathematics of Momentum

The momentum update introduces a velocity vector $\mathbf{v}$ that accumulates past gradients. At each step $t$:

Step 1: Update the velocity

$$\mathbf{v}_t = \beta \cdot \mathbf{v}_{t-1} + \nabla L(\mathbf{w}_t)$$

Step 2: Update the weights

$$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \cdot \mathbf{v}_t$$

Here $\beta$ is the momentum coefficient (typically 0.9) and $\eta$ is the learning rate. When $\beta = 0$, this reduces to vanilla SGD.
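The two steps translate almost literally into code. The following is a minimal sketch, not a library API: the function names `momentum_step` and `ravine_grad` are made up for illustration, and the ravine loss from earlier in the section stands in for a real network.

```python
def momentum_step(w, v, grad, lr=0.01, beta=0.9):
    """One momentum update on lists of floats.

    Step 1: v <- beta * v + grad
    Step 2: w <- w - lr * v
    """
    v_new = [beta * vi + gi for vi, gi in zip(v, grad)]
    w_new = [wi - lr * vi for wi, vi in zip(w, v_new)]
    return w_new, v_new

def ravine_grad(w):
    """Gradient of L(w1, w2) = w1^2 + 10 * w2^2."""
    return [2.0 * w[0], 20.0 * w[1]]

# One step from (5, 2) with zero initial velocity.
w, v = [5.0, 2.0], [0.0, 0.0]
w, v = momentum_step(w, v, ravine_grad(w))
print("after one step: w =", w, " v =", v)

# With beta = 0 the old velocity is discarded, so the step is vanilla SGD.
w_plain, _ = momentum_step([5.0, 2.0], [3.0, -1.0], ravine_grad([5.0, 2.0]), beta=0.0)
print("beta = 0 step:  w =", w_plain)
```

Returning new lists instead of mutating the inputs keeps the sketch easy to test; an optimizer in a real framework would update its buffers in place.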

Unrolling the Velocity

Let's expand $\mathbf{v}_t$ to see what it really computes:

$$\mathbf{v}_t = \nabla L_t + \beta \cdot \nabla L_{t-1} + \beta^2 \cdot \nabla L_{t-2} + \beta^3 \cdot \nabla L_{t-3} + \cdots$$

This is an exponentially weighted sum of gradients (an exponential moving average without the usual $1 - \beta$ normalization). With $\beta = 0.9$, the current gradient has weight 1, the previous gradient has weight $\beta = 0.9$, the one two steps ago has weight $\beta^2 = 0.81$, and so on. Gradients from the distant past fade away exponentially.
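A quick way to convince yourself that the recursive update and the unrolled sum agree is to compute both on the same gradient history. The gradient values below are illustrative, and the velocity is assumed to start at zero.

```python
beta = 0.9
grads = [10.0, 9.8, 9.424, 8.897]  # illustrative gradient history, oldest first

# Recursive form: v_t = beta * v_{t-1} + g_t, starting from v = 0.
v = 0.0
for g in grads:
    v = beta * v + g

# Unrolled form: the newest gradient has weight 1, the one before it beta, then beta^2, ...
unrolled = sum(beta ** age * g for age, g in enumerate(reversed(grads)))

print(f"recursive: {v:.6f}")
print(f"unrolled:  {unrolled:.6f}")
```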

The Effective Window

The sum of the weights is $1 + \beta + \beta^2 + \cdots = \frac{1}{1 - \beta}$. For $\beta = 0.9$, this is $\frac{1}{0.1} = 10$. So the velocity effectively averages the last ~10 gradients, weighted toward the most recent ones.

| β | Effective window | Behavior |
|---|---|---|
| 0.0 | 1 gradient | Vanilla SGD: no memory |
| 0.5 | ~2 gradients | Mild smoothing |
| 0.9 | ~10 gradients | Standard momentum: good balance |
| 0.95 | ~20 gradients | Heavy momentum: very smooth |
| 0.99 | ~100 gradients | Extreme momentum: slow to turn |
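The window sizes in the table follow directly from $\frac{1}{1-\beta}$. As a small sketch (the helper name `effective_window` is an assumption for illustration), the loop below also prints how much weight a gradient still carries once it is exactly one window old; for $\beta$ close to 1 this is roughly a third.

```python
def effective_window(beta):
    """Sum of the weights 1 + beta + beta^2 + ... = 1 / (1 - beta)."""
    return 1.0 / (1.0 - beta)

for beta in [0.0, 0.5, 0.9, 0.95, 0.99]:
    window = effective_window(beta)
    weight_at_edge = beta ** window  # weight of a gradient `window` steps old
    print(f"beta={beta:<4}  window={window:6.1f}  weight at window edge={weight_at_edge:.3f}")
```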

Why This Solves the Ravine Problem

Consider what happens in each direction of our ravine:

Along the valley floor (consistent direction): every gradient points roughly the same way. The velocity accumulates: $v \approx g + 0.9g + 0.81g + \cdots \approx 10g$. The effective step is 10× larger than vanilla SGD. We race along the valley floor.

Across the valley walls (oscillating direction): the gradient alternates sign. The velocity alternates too: $v \approx g - 0.9g + 0.81g - \cdots$. The terms partially cancel, giving $v \approx g \cdot \frac{1}{1+\beta} \approx 0.53g$. The effective step is smaller than vanilla SGD. The oscillation is dampened.
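Both limits can be checked numerically. This sketch feeds the velocity recursion an idealized gradient of fixed magnitude, once with a constant sign and once with an alternating sign. The unit magnitude and the 200 steps are arbitrary choices; real gradients are never this regular.

```python
beta, g = 0.9, 1.0

v_constant = 0.0      # gradient always +g (valley floor)
v_alternating = 0.0   # gradient flips between +g and -g (valley walls)

for t in range(200):
    v_constant = beta * v_constant + g
    v_alternating = beta * v_alternating + (g if t % 2 == 0 else -g)

print(f"constant gradient:    |v| = {abs(v_constant):.4f}   (1 / (1 - beta) = {1 / (1 - beta):.4f})")
print(f"alternating gradient: |v| = {abs(v_alternating):.4f}   (1 / (1 + beta) = {1 / (1 + beta):.4f})")
```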

Momentum is adaptive without being per-parameter. A single $\beta$ value automatically adjusts the effective step size in each direction based on gradient consistency. Consistent directions get amplified; oscillating directions get dampened. No hyperparameters beyond $\beta$ itself are needed.

Interactive: SGD vs Momentum

The visualization below shows vanilla SGD (orange) and SGD with momentum (blue) optimizing the same loss surface. Both start at the green dot (5, 2) and aim for the white circle at the origin (0, 0).


Try these experiments:

  1. Default settings (lr=0.015, β=0.9, noise=0): Press play. Watch momentum race ahead while SGD barely moves in $w_1$.
  2. Increase learning rate to ~0.04: SGD starts oscillating in $w_2$. Momentum stays smooth.
  3. Add noise (noise=5): SGD path becomes jagged. Momentum path stays smoother because it averages out the noise.
  4. Set β=0: The blue path matches orange exactly, confirming that $\beta = 0$ is vanilla SGD.
  5. Set β=0.99: Very heavy momentum. The ball overshoots initially but then corrects. Too much momentum is slow to change direction.
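If the interactive version is not available, the first experiment can be approximated offline. The sketch below assumes the surface is the example loss $w_1^2 + 10 w_2^2$ (the visualization's actual surface is not stated), uses no gradient noise, and runs an arbitrary 50 steps from the same start point (5, 2).

```python
def grad(w1, w2):
    """Gradient of the assumed surface L = w1^2 + 10 * w2^2."""
    return 2.0 * w1, 20.0 * w2

def run(lr, beta, steps, start=(5.0, 2.0)):
    """Run momentum SGD (beta=0 gives vanilla SGD) and return the final point."""
    w1, w2 = start
    v1 = v2 = 0.0
    for _ in range(steps):
        g1, g2 = grad(w1, w2)
        v1, v2 = beta * v1 + g1, beta * v2 + g2
        w1, w2 = w1 - lr * v1, w2 - lr * v2
    return w1, w2

def distance_to_origin(point):
    return (point[0] ** 2 + point[1] ** 2) ** 0.5

sgd_end = run(lr=0.015, beta=0.0, steps=50)
mom_end = run(lr=0.015, beta=0.9, steps=50)

print(f"vanilla SGD: w = ({sgd_end[0]:.4f}, {sgd_end[1]:.4f}), distance = {distance_to_origin(sgd_end):.4f}")
print(f"momentum:    w = ({mom_end[0]:.4f}, {mom_end[1]:.4f}), distance = {distance_to_origin(mom_end):.4f}")
```

Changing `lr`, `beta`, or `steps` reproduces the other experiments in the list, apart from the noise slider, which this sketch leaves out.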

NumPy: SGD vs Momentum from Scratch

Let's implement momentum from scratch and see the acceleration in numbers. We minimize $f(w) = w^2$ starting at $w = 5$. The gradient is $\frac{df}{dw} = 2w$. We use the same learning rate (0.01) for both vanilla SGD and momentum, so the only difference is the velocity buffer.

NumPy: SGD vs Momentum on f(w) = w²
sgd_vs_momentum.py
import numpy as np

NumPy provides fast numerical arrays and math operations. This 1D example never actually calls it: plain Python floats suffice. In real neural networks, NumPy (or PyTorch) handles millions of weight updates in vectorized C code.

EXECUTION STATE
numpy = Library for numerical computing — ndarray type, vectorized math, linear algebra
# Minimize f(w) = w² starting at w=5

Our toy loss function is f(w) = w². It has a single minimum at w = 0 where f(0) = 0. Starting at w = 5, the loss is f(5) = 25. The goal: get w as close to 0 as possible in few steps.

EXECUTION STATE
f(w) = w² = A simple parabola. Minimum at w=0. Starting loss: f(5) = 25. This is the simplest possible loss landscape — one weight, one valley.
# Gradient: df/dw = 2w

The derivative of w² is 2w. At w=5, the gradient is 10 — it says ‘increasing w would increase the loss by ~10 per unit’. The minus sign in SGD makes us go in the OPPOSITE direction (decrease w).

EXECUTION STATE
df/dw = 2w = At w=5: gradient = 10 (steep, loss rises fast if w increases). At w=1: gradient = 2 (gentle). At w=0: gradient = 0 (minimum, no slope).
# Minimum: w* = 0

The optimal weight is w* = 0 where the loss is zero. Both SGD and momentum aim to reach this point. The question is: how fast?

# === Vanilla SGD ===

First we run vanilla (plain) SGD: the simplest optimizer that updates weights by subtracting lr × gradient. No memory of past gradients — each step only uses the current gradient.

w_sgd = 5.0  # Starting weight

We initialize the weight at 5.0, far from the minimum at 0. The initial loss is f(5) = 25. This simulates starting training with random weights that produce a high loss.

EXECUTION STATE
w_sgd = 5.0 = Starting weight for vanilla SGD. Loss = 5² = 25. Distance from minimum: 5.0 units.
lr = 0.01  # Learning rate

The learning rate is 0.01 — each step, the weight changes by 1% of the gradient. With gradient 2w: the update is 0.01 × 2w = 0.02w. So each step, w shrinks by 2%. That means after 8 steps: w = 5 × 0.98⁸ = 4.25. Barely moved!

EXECUTION STATE
lr = 0.01 = Learning rate. Effective per-step decay: w_{new} = w × (1 - lr×2) = 0.98w. At this rate, w halves after ~35 steps. Much too slow!
→ Why so small? = In real networks, large lr causes divergence because some directions have high curvature. We are forced to use a small lr to stay stable — but that makes convergence painfully slow in low-curvature directions.
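Because each vanilla step multiplies w by the same factor, the whole trajectory has a closed form, $w_n = w_0 \cdot (1 - 2\eta)^n$. The sketch below uses it to check the figures quoted above; the variable names are illustrative.

```python
import math

lr, w0 = 0.01, 5.0
factor = 1 - 2 * lr                    # each step multiplies w by 0.98

w_after_8 = w0 * factor ** 8
steps_to_halve = math.log(0.5) / math.log(factor)
steps_to_reach_0_5 = math.log(0.5 / w0) / math.log(factor)

print(f"w after 8 steps:      {w_after_8:.4f}")
print(f"steps to halve w:     {steps_to_halve:.1f}")
print(f"steps to reach w=0.5: {steps_to_reach_0_5:.1f}")
```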
print("=== Vanilla SGD (lr=0.01) ===")

Prints a header to separate the vanilla SGD output from the momentum output below.

for step in range(8):  # 8 update steps

Run 8 gradient descent steps. Each step: compute gradient, update weight. Watch how slowly w moves toward 0 with vanilla SGD.

LOOP TRACE · 8 iterations
step=0
w before = 5.0000 → grad=10.0, update=-0.1 → w=4.9000
step=1
w before = 4.9000 → grad=9.8, update=-0.098 → w=4.8020
step=2
w before = 4.8020 → grad=9.604, update=-0.096 → w=4.7060
step=3
w before = 4.7060 → grad=9.412, update=-0.094 → w=4.6118
step=4
w before = 4.6118 → grad=9.224, update=-0.092 → w=4.5196
step=5
w before = 4.5196 → grad=9.039, update=-0.090 → w=4.4292
step=6
w before = 4.4292 → grad=8.858, update=-0.089 → w=4.3406
step=7
w before = 4.3406 → grad=8.681, update=-0.087 → w=4.2538
grad = 2 * w_sgd  # Compute gradient

The gradient of f(w) = w² is df/dw = 2w. This tells us the slope: positive gradient means ‘increasing w increases loss’, so we should decrease w. Notice the gradient gets SMALLER each step (10 → 9.8 → 9.6...) because w is moving toward 0.

EXECUTION STATE
grad at step 0 = 2 × 5.0 = 10.0 — steep slope, big push
grad at step 7 = 2 × 4.3406 = 8.681, only slightly smaller after 7 steps!
→ Problem = The gradient barely changes because w barely moves. Each step only subtracts ~0.09, so w shrinks by ~2% per step.
w_sgd = w_sgd - lr * grad  # The update rule

This is the vanilla SGD update: w_{new} = w_{old} - η × gradient. We move opposite to the gradient (minus sign). The step size is lr × gradient. Since lr=0.01 and grad≈10, the step is about 0.1 per iteration. After 8 steps, w barely moved from 5 to 4.25.

EXECUTION STATE
— (subtraction) = We go OPPOSITE to the gradient. Positive gradient → decrease w. This moves us downhill on the loss surface.
Step 0 update = 5.0 - 0.01×10.0 = 5.0 - 0.1 = 4.9
Step 7 update = 4.3406 - 0.01×8.681 = 4.3406 - 0.0868 = 4.2538
↑ result after 8 steps = w = 4.2538 (only moved 0.75 from starting point 5.0). Loss: 18.09 (started at 25). Painfully slow!
print(f"Step {step}: w={w_sgd:.4f}")

Prints the weight after each step. The output shows: Step 0: w=4.9000, Step 1: w=4.8020, ..., Step 7: w=4.2538. The tiny changes reveal the core problem: vanilla SGD is slow when the learning rate is small.

# === SGD + Momentum ===

Now we add momentum. The key idea: maintain a ‘velocity’ that accumulates past gradients. If the gradient keeps pointing the same direction (as it does here — always toward 0), velocity builds up and we take larger effective steps.

w_mom = 5.0  # Same starting point

Same starting weight as vanilla SGD so we can compare fairly. Both start at w=5 with loss=25.

EXECUTION STATE
w_mom = 5.0 = Starting weight for momentum SGD. Identical to vanilla SGD starting point. The only difference will be the update rule.
v = 0.0  # Initial velocity

The velocity starts at zero. Think of this as a ball at rest on the loss surface. As gradients push the ball, it picks up speed. The velocity is an exponential moving average of past gradients — it remembers where the gradients have been pointing.

EXECUTION STATE
v = 0.0 = Initial velocity (momentum buffer). No accumulated gradient yet. After step 0 it becomes 10.0 (= the first gradient). After step 1 it becomes 18.8 (= 0.9×10 + 9.8). It keeps growing!
→ Physics analogy = v is like the speed of a ball rolling downhill. Each gradient is like gravity pulling the ball. The ball accelerates — unlike vanilla SGD where each step is independent.
beta = 0.9  # Momentum coefficient

beta controls how much past velocity carries forward. At β=0.9, the velocity retains 90% of its previous value plus the new gradient. This means the velocity effectively averages the last ~10 gradients (1/(1-0.9) = 10). Higher β = smoother, longer memory but slower to change direction.

EXECUTION STATE
β = 0.9 = Momentum coefficient. 0.9 is the most common value in practice. Effective window: ~1/(1-β) = 10 gradients. With a constant gradient, the velocity settles at ~10× that gradient.
→ β = 0 = No momentum → identical to vanilla SGD
→ β = 0.99 = Very heavy momentum → averages ~100 gradients. Smoother but slow to respond to changes.
print("\n=== SGD + Momentum (lr=0.01, β=0.9) ===")

Prints a header for the momentum output section.

for step in range(8):  # 8 update steps with momentum

Same 8 steps as vanilla SGD. Watch how the velocity v builds up from 10 to 45, making each weight update progressively larger. After 8 steps, momentum reaches w=2.44 vs vanilla’s w=4.25.

LOOP TRACE · 8 iterations
step=0
grad=10.0, v: 0→10.0 = w: 5.000 → 4.900 (same as vanilla on step 0)
step=1
grad=9.8, v: 10→18.8 = w: 4.900 → 4.712 (vanilla: 4.802 — already pulling ahead!)
step=2
grad=9.424, v: 18.8→26.344 = w: 4.712 → 4.449 (vanilla: 4.706)
step=3
grad=8.897, v: 26.3→32.607 = w: 4.449 → 4.123 (vanilla: 4.612)
step=4
grad=8.245, v: 32.6→37.591 = w: 4.123 → 3.747 (vanilla: 4.520)
step=5
grad=7.493, v: 37.6→41.325 = w: 3.747 → 3.333 (vanilla: 4.429)
step=6
grad=6.667, v: 41.3→43.859 = w: 3.333 → 2.895 (vanilla: 4.341)
step=7
grad=5.790, v: 43.9→45.263 = w: 2.895 → 2.442 (vanilla: 4.254)
grad = 2 * w_mom  # Same gradient formula

The gradient is computed identically to vanilla SGD. The difference is what we DO with it. Instead of directly subtracting lr×grad from w, we first add it to the velocity buffer.

EXECUTION STATE
grad at step 0 = 2 × 5.0 = 10.0 (identical to vanilla)
grad at step 7 = 2 × 2.895 = 5.790 (much smaller than vanilla’s 8.681 because w has moved further!)
v = beta * v + grad  # THE momentum update

This is the heart of momentum. The velocity is an exponential moving average of gradients: v_new = 0.9 × v_old + current_gradient. Each step, 90% of the old velocity persists and the new gradient is added. If gradients consistently point the same way, v accumulates and grows larger than any single gradient.

EXECUTION STATE
📚 Exponential moving average = v_t = β·v_{t-1} + g_t. Unrolling: v_t = g_t + 0.9·g_{t-1} + 0.81·g_{t-2} + 0.729·g_{t-3} + ... Recent gradients weigh more, old ones decay exponentially.
Step 0: v = 0.9×0 + 10.0 = 10.0. No history yet, v = gradient (same as vanilla)
Step 1: v = 0.9×10 + 9.8 = 18.8. 90% of old velocity (9.0) PLUS new gradient (9.8). Already 1.9× the gradient!
Step 4: v = 0.9×32.6 + 8.245 = 37.6. Velocity is now 4.6× the current gradient. Momentum is really moving!
Step 7: v = 0.9×43.9 + 5.790 = 45.3. Velocity is 7.8× the current gradient. With a constant gradient the ratio would level off at 1/(1-β) = 10×; here the gradient keeps shrinking as w approaches 0, so the ratio keeps climbing.
→ Why it accelerates = Gradients all point left (‘decrease w’). Each step adds more push in the same direction. v snowballs like a rolling ball gaining speed on a slope.
w_mom = w_mom - lr * v  # Weight update with velocity

The weight update uses velocity instead of just the gradient. Since v is much larger than a single gradient, the effective step is much bigger. At step 7: lr×v = 0.01×45.3 = 0.453, while vanilla’s step is lr×grad = 0.01×8.7 = 0.087. Momentum takes a 5× bigger step!

EXECUTION STATE
Step 0: 5.0 - 0.01×10.0 = 4.9 (step size 0.10, same as vanilla)
Step 1: 4.9 - 0.01×18.8 = 4.712 (step size 0.188, already 2× vanilla!)
Step 4: 4.123 - 0.01×37.6 = 3.747 (step size 0.376, 4× vanilla!)
Step 7: 2.895 - 0.01×45.3 = 2.442 (step size 0.453, 5× vanilla!)
↑ result after 8 steps = w = 2.442 (moved 2.56 from start). Vanilla only moved 0.75. Momentum is 3.4× faster!
print(f"Step {step}: v={v:.3f}  w={w_mom:.4f}")

Prints velocity and weight after each step. The output reveals momentum’s acceleration: Step 0: v=10.0, w=4.9000 | Step 3: v=32.6, w=4.1225 | Step 7: v=45.3, w=2.4421.

print(f"\nAfter 8 steps:")

Prints the final comparison.

print(f"  Vanilla SGD: w = {w_sgd:.4f}")

Vanilla SGD result: w = 4.2538. The weight barely moved from 5.0 to 4.25. Loss dropped from 25 to 18.09, only a 28% reduction in 8 steps.

EXECUTION STATE
w_sgd = 4.2538 = Moved only 0.746 from starting point 5.0. Remaining loss: 18.09. At this rate, reaching w=0.5 takes ~114 steps!
print(f"  Momentum:    w = {w_mom:.4f}")

Momentum result: w = 2.4421. Moved 2.56 units from the starting point, 3.4× further than vanilla SGD! Loss dropped from 25 to 5.96, a 76% reduction. Continuing the same loop, momentum first drops below w=0.5 after 13 steps versus about 114 for vanilla SGD, although it then overshoots past 0 and oscillates before settling.

EXECUTION STATE
w_mom = 2.4421 = Moved 2.558 from starting point 5.0. Remaining loss: 5.96. 3.4× more progress than vanilla SGD in the same number of steps.
→ Why the speedup? = Velocity accumulated to ~45 while each gradient was only ~6–8. The rolling ball analogy: the ball builds speed on a consistent downhill slope, taking ever-larger steps.
import numpy as np

# ── Minimize f(w) = w² starting at w=5 ──
# Gradient: df/dw = 2w
# Minimum: w* = 0

# === Vanilla SGD ===
w_sgd = 5.0
lr = 0.01

print("=== Vanilla SGD (lr=0.01) ===")
for step in range(8):
    grad = 2 * w_sgd
    w_sgd = w_sgd - lr * grad
    print(f"Step {step}: w={w_sgd:.4f}")

# === SGD + Momentum ===
w_mom = 5.0
v = 0.0
beta = 0.9

print("\n=== SGD + Momentum (lr=0.01, β=0.9) ===")
for step in range(8):
    grad = 2 * w_mom
    v = beta * v + grad
    w_mom = w_mom - lr * v
    print(f"Step {step}: v={v:.3f}  w={w_mom:.4f}")

print(f"\nAfter 8 steps:")
print(f"  Vanilla SGD: w = {w_sgd:.4f}")
print(f"  Momentum:    w = {w_mom:.4f}")

The results tell the story:

| Metric | Vanilla SGD | Momentum (β=0.9) |
|---|---|---|
| w after 8 steps | 4.2538 | 2.4421 |
| Distance traveled | 0.746 | 2.558 |
| Loss after 8 steps | 18.09 | 5.96 |
| Loss reduction | 28% | 76% |

Momentum moved 3.4× further in the same number of steps. The velocity built up from 10 to 45.3, making each step progressively larger. This is the power of accumulated momentum in a consistent direction.

Why the same learning rate? We used lr=0.01 for both to show that momentum accelerates convergence without increasing the learning rate. In practice, you might also tune the learning rate when switching to momentum, but the acceleration is real even with identical lr.
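Eight steps only show the start of the story. The sketch below extends the same two update rules to 150 steps and records the first step at which |w| falls below 0.5. The helper names `trajectory` and `first_below` and the 150-step horizon are choices made for this illustration. One thing to watch for: with these settings momentum does not stop at the minimum but carries w slightly past zero before swinging back, which is the 1D version of the overshoot seen with heavy momentum.

```python
def trajectory(beta, steps, lr=0.01, w=5.0):
    """Return [w_0, w_1, ..., w_steps] for momentum SGD on f(w) = w^2."""
    v = 0.0
    path = [w]
    for _ in range(steps):
        v = beta * v + 2 * w   # gradient of w^2 is 2w
        w = w - lr * v
        path.append(w)
    return path

def first_below(path, threshold):
    """Index (number of updates) at which |w| first drops below threshold."""
    return next(i for i, w in enumerate(path) if abs(w) < threshold)

sgd_path = trajectory(beta=0.0, steps=150)
mom_path = trajectory(beta=0.9, steps=150)

print("vanilla SGD first below 0.5 after", first_below(sgd_path, 0.5), "steps")
print("momentum    first below 0.5 after", first_below(mom_path, 0.5), "steps")
print("momentum's most negative w:", min(mom_path))
```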

PyTorch: Training with Momentum

Now let's apply momentum to real neural network training. We train our diagonal flip network from Chapters 7–8 using both vanilla SGD and SGD with momentum, on all 16 binary images. The models start with identical random weights. The only difference is one line: momentum=0.9.

PyTorch: SGD vs SGD+Momentum on the Diagonal Flip Network
train_with_momentum.py
import torch

PyTorch is the deep learning framework we use throughout this book. It provides tensors (GPU-accelerated arrays), automatic differentiation (autograd), and neural network building blocks (nn.Module, nn.Linear, optimizers).

EXECUTION STATE
torch = PyTorch library — tensors, autograd, nn modules, optimizers. GPU support built in.
import torch.nn as nn

torch.nn contains neural network layers (Linear, Conv2d), activation functions (ReLU, Sigmoid), loss functions (MSELoss, CrossEntropyLoss), and the base Module class. We alias it as nn for brevity.

EXECUTION STATE
torch.nn = Neural network module — provides nn.Linear, nn.ReLU, nn.Module, etc. All layers in PyTorch inherit from nn.Module.
torch.manual_seed(42)  # Reproducibility

Sets the random number generator seed so that weight initialization is identical each time we run the code. Both models will start with the exact same random weights, making the comparison fair.

EXECUTION STATE
📚 torch.manual_seed() = PyTorch function: seeds the random number generator for reproducibility. Affects torch.randn(), weight initialization in nn.Linear, etc. Seed 42 is a convention (from Hitchhiker’s Guide).
→ Why 42? = Any integer works. 42 is tradition. The key: both models see the same seed, so they start with identical weights.
class FlipNet(nn.Module):  # Our network

Defines the same 4→3→4 network from Chapters 7–8: 4 input neurons (flattened 2×2 image), 3 hidden neurons with ReLU, 4 output neurons. The goal: learn the diagonal flip (matrix transpose) operation.

EXECUTION STATE
⬇ inherits: nn.Module = PyTorch’s base class for all neural networks. Provides parameter tracking, .forward(), .parameters(), device management, serialization.
→ Architecture = Input(4) → Linear(4,3) → ReLU → Linear(3,4) → Output(4). Total: 4×3+3 + 3×4+4 = 31 parameters.
def __init__(self):

Constructor: creates the layers. Called once when we instantiate FlipNet(). The layers are stored as attributes so PyTorch can find and track their parameters.

super().__init__()  # Initialize nn.Module

Calls the parent class (nn.Module) constructor. This sets up PyTorch’s internal bookkeeping: parameter registration, hook management, training/eval mode tracking. Must be called before creating any layers.

self.layer1 = nn.Linear(4, 3)  # Hidden layer

Creates a fully connected layer: 4 inputs → 3 outputs. Internally stores a weight matrix W(3×4) and bias vector b(3). Forward: output = input @ W.T + b. That’s 4×3 + 3 = 15 learnable parameters.

EXECUTION STATE
📚 nn.Linear(in_features, out_features) = PyTorch module: y = x @ W.T + b. Creates weight matrix (out×in) and bias vector (out). Parameters initialized with Kaiming uniform by default.
⬇ arg 1: in_features = 4 = Input dimension. Each input is a flattened 2×2 image = 4 pixels.
⬇ arg 2: out_features = 3 = Output dimension. 3 hidden neurons. Chosen to be smaller than input (compression). This forces the network to learn a compact representation.
self.layer2 = nn.Linear(3, 4)  # Output layer

Output layer: 3 hidden neurons → 4 output pixels. Weight matrix W(4×3), bias b(4). Total: 3×4 + 4 = 16 parameters. Combined with layer1: 31 total parameters.

EXECUTION STATE
⬇ arg 1: in_features = 3 = Matches the output of layer1. The 3 hidden activations are the input here.
⬇ arg 2: out_features = 4 = 4 output pixels — the flipped 2×2 image, flattened.
self.relu = nn.ReLU()  # Activation function

ReLU: f(x) = max(0, x). Applied between layer1 and layer2. Without activation, two linear layers collapse into one (matrix multiply is associative). ReLU adds non-linearity that lets the network learn the diagonal flip.

EXECUTION STATE
📚 nn.ReLU() = Rectified Linear Unit. max(0, x) for each element. Negative values become 0, positive values pass through unchanged. No learnable parameters.
def forward(self, x):  # Forward pass

Defines the forward pass: input → layer1 → ReLU → layer2 → output. Called automatically when you do model(x). PyTorch’s autograd records all operations for the backward pass.

EXECUTION STATE
⬇ input: x = A 4-element tensor: one flattened 2×2 image. Example: [1,0,1,1] for [[1,0],[1,1]].
⬆ returns = A 4-element tensor: the predicted flipped image. Should approach Y[i] = X[i][[0,2,1,3]] after training.
return self.layer2(self.relu(self.layer1(x)))

Composes the three operations: layer1(x) computes the weighted sum, relu() zeros out negatives, layer2() produces the final output. Reading inside-out: layer1 first, then relu, then layer2.

EXECUTION STATE
self.layer1(x) = x(4) @ W1.T(4×3) + b1(3) = z1(3). The pre-activation hidden values.
self.relu(...) = max(0, z1) = h(3). Zeros out negative hidden values.
self.layer2(...) = h(3) @ W2.T(3×4) + b2(4) = y_hat(4). The prediction.
# Dataset: all 16 binary 2×2 images

We generate all possible 2×2 binary images (0000, 0001, ..., 1111) and their diagonal flips as training data. 16 examples total — tiny by real-world standards but enough to demonstrate optimizer behavior.

X = torch.tensor([...])  # All 16 inputs

Creates a 16×4 tensor where each row is one flattened 2×2 binary image. Row 0: [0,0,0,0], Row 5: [0,1,0,1], Row 15: [1,1,1,1]. These are all possible inputs to our network.

EXECUTION STATE
X shape = (16, 4) — 16 images, 4 pixels each
X[0] = [0, 0, 0, 0] — all-black image
X[5] = [0, 1, 0, 1]
X[15] = [1, 1, 1, 1] — all-white image
Y = X[:, [0, 2, 1, 3]]  # Diagonal flips

The diagonal flip of a 2×2 matrix [[a,b],[c,d]] is [[a,c],[b,d]] — the transpose. Flattened: [a,b,c,d] → [a,c,b,d]. So we swap columns 1 and 2 (0-indexed). PyTorch fancy indexing does this: X[:, [0,2,1,3]] keeps columns 0,3 and swaps 1↔2.

EXECUTION STATE
📚 X[:, [0, 2, 1, 3]] = PyTorch fancy indexing: select columns in the order [0, 2, 1, 3]. Column 0 stays, column 1 → position 2, column 2 → position 1, column 3 stays. This is the transpose of the 2×2 matrix.
Example: X[6]=[0,1,1,0] = Y[6] = X[6][[0,2,1,3]] = [0,1,1,0]. The matrix [[0,1],[1,0]] transposed is [[0,1],[1,0]]: the same, because it is symmetric.
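The dataset is small enough to enumerate without PyTorch. This is a plain-Python sketch of the same construction, assuming the rows are ordered as binary counting from 0000 to 1111 (consistent with the X[0], X[5] and X[15] values listed above); the walkthrough's own code builds X as a tensor instead.

```python
from itertools import product

# All 16 flattened 2x2 binary images, in counting order 0000 .. 1111.
X = [list(bits) for bits in product([0, 1], repeat=4)]

# Diagonal flip: [a, b, c, d] -> [a, c, b, d] (swap positions 1 and 2).
Y = [[a, c, b, d] for a, b, c, d in X]

# Images that equal their own flip are the symmetric matrices (b == c).
symmetric = [x for x, y in zip(X, Y) if x == y]

print("X[5] =", X[5], "-> Y[5] =", Y[5])
print("X[6] =", X[6], "-> Y[6] =", Y[6])
print("symmetric images:", len(symmetric), "of", len(X))
```

Half of the inputs are unchanged by the flip, so the network only has to learn a real change on the other half.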
# Train with vanilla SGD

First we train with vanilla SGD (no momentum). We set the seed again so both models start with identical random weights.

torch.manual_seed(42)  # Reset seed for fair comparison

Re-seeding ensures model_sgd and model_mom start with the EXACT same weights. Without this, different random initialization would confound the optimizer comparison.

model_sgd = FlipNet()  # Vanilla SGD model

Creates the network for vanilla SGD training. Because we just set seed=42, the weight initialization is deterministic and identical to the momentum model created below.

opt_sgd = torch.optim.SGD(model_sgd.parameters(), lr=0.05)

Creates a vanilla SGD optimizer: no momentum, no weight decay. lr=0.05 means each weight changes by 5% of its gradient per step. model.parameters() passes all 31 learnable parameters to the optimizer.

EXECUTION STATE
📚 torch.optim.SGD() = PyTorch’s SGD optimizer. When momentum=0 (default): w = w - lr × grad. This is the simplest optimizer — the baseline we compare against.
⬇ arg 1: model_sgd.parameters() = A generator yielding all learnable tensors: layer1.weight(3×4), layer1.bias(3), layer2.weight(4×3), layer2.bias(4). 31 parameters total.
⬇ arg 2: lr = 0.05 = Learning rate. Larger than the 0.01 we used in the 1D example because neural network gradients tend to be smaller (loss is averaged over outputs).
# Train with SGD + Momentum

Now we create an identical network but train it with momentum. The ONLY difference is the optimizer.

torch.manual_seed(42)  # Same starting weights

Reset seed to 42 again. This guarantees model_mom has the exact same initial weights as model_sgd. Any difference in training will be purely due to the optimizer.

model_mom = FlipNet()  # Momentum model

Creates a second network instance. Thanks to manual_seed(42), this has identical weights to model_sgd. We will compare how quickly each optimizer reduces the loss.

opt_mom = torch.optim.SGD(..., momentum=0.9)  # THE key change

The ONLY difference from vanilla SGD: momentum=0.9. PyTorch maintains a velocity buffer for each parameter. Each step: v = 0.9*v + grad, then w = w - lr*v. This single flag transforms a slow optimizer into a fast one.

EXECUTION STATE
📚 torch.optim.SGD(params, lr, momentum) = With momentum > 0: maintains a velocity buffer v for each parameter. Update rule: v_t = momentum × v_{t-1} + grad_t, then w = w - lr × v_t. Identical to our NumPy implementation above.
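⬇ momentum = 0.9 = The momentum coefficient β. 0.9 is the conventional choice for SGD+momentum (PyTorch's own default is momentum=0, which is why it must be passed explicitly). Each past gradient is weighted by a power of β, so roughly the last 1/(1-β) = 10 gradients dominate the velocity. This is the same β we used in the NumPy example.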
→ Other common values = β=0.0: vanilla SGD. β=0.9: standard momentum. β=0.99: heavy momentum (used in some large-scale training).
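To make the "last ~10 gradients" rule of thumb concrete, here is a minimal pure-Python sketch. It assumes the update v = β·v + grad described above and simply unrolls it; the variable names are illustrative and not taken from the chapter's code.

```python
# Unrolling v_t = beta * v_{t-1} + g_t gives v_t = sum_k beta**k * g_{t-k}:
# a gradient that is k steps old carries weight beta**k.
beta = 0.9

weights = [beta ** k for k in range(200)]        # weight on the gradient from k steps ago
total_weight = sum(weights)                      # approaches 1 / (1 - beta)
recent_share = sum(weights[:10]) / total_weight  # share held by the 10 newest gradients

print(f"sum of weights        : {total_weight:.3f}")
print(f"share of last 10 grads: {recent_share:.3f}")

# With a constant gradient g = 1, the velocity saturates at 1 / (1 - beta).
v = 0.0
for _ in range(200):
    v = beta * v + 1.0
print(f"steady-state velocity : {v:.3f}")
```

The weights sum to about 1/(1-β) = 10, and the ten most recent gradients hold roughly 65% of that total. The same factor of 10 is the largest amplification momentum can give to a gradient that never changes direction.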
for epoch in range(30): - Training loop

Train both models for 30 epochs. Each epoch: iterate over all 16 images, compute loss, backprop, update. We run both models in the same loop to ensure they see the same data in the same order.

EXECUTION STATE
epoch = One complete pass through all 16 training images. 30 epochs = 30 × 16 = 480 gradient updates per model.
loss_sgd = loss_mom = 0.0 - Reset epoch loss

Accumulate the loss over all 16 images so we can report the average at the end of each epoch. Reset to 0 at the start of each epoch.

for i in range(16): - Iterate over all images

Process each of the 16 training images one at a time (batch size = 1). This is true stochastic gradient descent: the gradient from a single image is a noisy estimate of the average gradient over all 16 images. Note that this loop visits the images in the same fixed order every epoch; it does not shuffle them.

EXECUTION STATE
Batch size = 1 = Each gradient update uses one image. This makes gradients noisy but allows 16 updates per epoch instead of 1. The noise is exactly what momentum helps smooth out.
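The relationship between single-example gradients and the full-batch gradient can be sketched on a toy problem. The one-parameter model and the data below are hypothetical, chosen only so the arithmetic is easy to follow; they are not part of the FlipNet example.

```python
# Toy 1-parameter model y_hat = w * x with squared-error loss (illustrative data).
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.1, 3.9, 6.2, 7.8]
w = 0.5

def sample_grad(w, x, y):
    """d/dw of (w*x - y)**2 for a single example."""
    return 2.0 * x * (w * x - y)

per_sample = [sample_grad(w, x, y) for x, y in zip(xs, ys)]
full_batch = sum(per_sample) / len(per_sample)   # gradient of the mean loss

print("per-sample gradients:", [round(g, 2) for g in per_sample])
print("full-batch gradient :", round(full_batch, 2))
```

Each per-sample gradient points the same way here but with a very different magnitude. The full-batch gradient is their average, which is the quantity that a batch-size-1 update only approximates.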
pred = model_sgd(X[i]) - Vanilla forward pass

Runs the forward pass through the vanilla SGD model. Calling model_sgd(X[i]) invokes the forward() method: layer1 → ReLU → layer2. Returns a 4-element prediction tensor.

lv = torch.mean((pred - Y[i]) ** 2) - MSE loss

Mean squared error between prediction and target. Same loss function from Chapter 8. Example: if pred=[0.1, -0.1, -0.05, -0.25] and target=[1, 1, 0, 1], loss = mean of [0.81, 1.21, 0.0025, 1.5625] = 0.896.

EXECUTION STATE
📚 torch.mean() = Computes the mean of all elements in a tensor. Here it averages the 4 squared errors into a single scalar loss.
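The worked MSE example above can be checked in a few lines of plain Python. This sketch uses ordinary lists in place of tensors, so no PyTorch is needed.

```python
pred   = [0.1, -0.1, -0.05, -0.25]
target = [1.0, 1.0, 0.0, 1.0]

# Square each error, then average the four values into one scalar loss.
squared_errors = [(p - t) ** 2 for p, t in zip(pred, target)]
mse = sum(squared_errors) / len(squared_errors)

print([round(e, 4) for e in squared_errors])
print(round(mse, 3))
```

The result agrees with the 0.896 quoted in the text.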
opt_sgd.zero_grad() - Clear old gradients

PyTorch accumulates gradients by default (for gradient accumulation techniques). We must zero them before each backward pass, otherwise gradients from different images would add up incorrectly.

EXECUTION STATE
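📚 optimizer.zero_grad() = Resets .grad for all parameters managed by this optimizer (recent PyTorch versions set it to None by default; pass set_to_none=False to fill it with zeros instead). Must be called before loss.backward() to prevent stale gradient accumulation.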
lv.backward() - Compute gradients

Runs backpropagation: computes dL/dw for all 31 parameters using the chain rule, exactly as we did by hand in Chapter 8. After this call, each parameter tensor has a .grad attribute containing its gradient.

EXECUTION STATE
📚 tensor.backward() = PyTorch autograd: traverses the computation graph backwards, computing gradients for every parameter that contributed to this loss. Uses the chain rule automatically.
opt_sgd.step() - Update weights (vanilla SGD)

Applies the vanilla SGD update to all 31 parameters: w = w - 0.05 × grad. No momentum buffer — just the raw gradient times the learning rate. Each parameter is updated independently.

EXECUTION STATE
📚 optimizer.step() = Applies one optimization step. For vanilla SGD: w = w - lr × w.grad. For momentum SGD: v = β×v + w.grad, then w = w - lr×v.
loss_sgd += lv.item() - Track epoch loss

Adds this image’s loss to the running total. .item() converts a single-element tensor to a Python float. We divide by 16 at the end to get the average loss per image.

EXECUTION STATE
📚 tensor.item() = Extracts the value of a single-element tensor as a plain Python number (a float here). Used when accumulating the epoch loss so that the running total is an ordinary float rather than a tensor attached to the autograd graph.
pred = model_mom(X[i]) - Momentum forward pass

Same forward pass, but through the momentum model. The architectures are identical — only the weights differ because they were updated differently in previous steps.

lm = torch.mean((pred - Y[i]) ** 2) - Momentum MSE loss

Computes MSE loss for the momentum model. As training progresses, this loss should decrease faster than the vanilla SGD loss because momentum smooths out noisy gradients and accelerates consistent directions.

opt_mom.zero_grad() - Clear momentum model gradients

Same as vanilla: zero out gradients before backward. Note that zero_grad() does NOT clear the momentum velocity buffer — that persists across steps, which is the whole point of momentum.

EXECUTION STATE
→ Key distinction = zero_grad() clears .grad (the current gradient). The momentum buffer v (stored internally by the optimizer) is NOT affected. v persists and accumulates across steps.
lm.backward() - Compute momentum model gradients

Backpropagation through the momentum model. The gradient computation is identical to vanilla SGD — backprop doesn’t know or care about the optimizer. The optimizer only affects how gradients are USED in the update step.

opt_mom.step() - Update weights WITH momentum

This is where the magic happens. For each of the 31 parameters: v = 0.9×v + grad, then w = w - 0.05×v. The velocity buffer v remembers past gradients, giving larger effective steps in consistent directions.

EXECUTION STATE
→ Internally, for each parameter =
  1. v_new = 0.9 × v_old + param.grad
  2. param.data -= 0.05 × v_new
The velocity buffer v is stored per-parameter inside the optimizer.
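The two-step update can be tried out from scratch on a small "ravine". The sketch below is a pure-Python illustration, not the chapter's NumPy implementation: the quadratic loss, the curvature of 50, the learning rate of 0.01 and the helper names are all assumptions made for the example.

```python
def grad(w):
    """Gradient of the ravine L(w) = 0.5 * (w0**2 + 50 * w1**2)."""
    return [w[0], 50.0 * w[1]]

def loss(w):
    return 0.5 * (w[0] ** 2 + 50.0 * w[1] ** 2)

def steps_to_converge(beta, lr=0.01, tol=1e-6, max_steps=5000):
    """Run v = beta*v + g; w = w - lr*v and count steps until loss < tol."""
    w = [1.0, 1.0]
    v = [0.0, 0.0]          # one velocity entry per parameter
    for step in range(1, max_steps + 1):
        g = grad(w)
        v = [beta * vi + gi for vi, gi in zip(v, g)]
        w = [wi - lr * vi for wi, vi in zip(w, v)]
        if loss(w) < tol:
            return step
    return max_steps

steps_sgd = steps_to_converge(beta=0.0)   # beta = 0 reduces to vanilla SGD
steps_mom = steps_to_converge(beta=0.9)
print("vanilla :", steps_sgd, "steps")
print("momentum:", steps_mom, "steps")
```

With the learning rate kept small enough for the steep direction, vanilla SGD crawls along the gentle direction, while the momentum run needs far fewer steps to reach the same tolerance. The exact ratio depends on the curvatures and learning rate chosen here.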
loss_mom += lm.item() - Track momentum epoch loss

Track momentum model loss. At the end of each epoch, loss_mom/16 gives the average loss per image for the momentum model.

if epoch % 5 == 0: print(...) - Print every 5 epochs

Prints the average loss every 5 epochs (epochs 0, 5, 10, 15, 20, 25) to compare convergence speed. You should see momentum's loss dropping significantly faster than vanilla SGD's, especially in the first 15–20 epochs. The momentum model reaches low loss in roughly half the epochs.

EXECUTION STATE
Expected output pattern (illustrative values; exact numbers depend on your PyTorch version and hardware) =
Epoch  0: SGD≈0.42  Momentum≈0.42  (similar at the start; the two models only share their first couple of updates)
Epoch 10: SGD≈0.18  Momentum≈0.08  (momentum loss roughly 2× lower)
Epoch 25: SGD≈0.09  Momentum≈0.02  (momentum close to converged)
import torch
import torch.nn as nn

torch.manual_seed(42)

# ── Network (same architecture as Chapters 7-8) ──
class FlipNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(4, 3)
        self.layer2 = nn.Linear(3, 4)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.layer2(self.relu(self.layer1(x)))

# ── Dataset: all 16 binary 2×2 images ──
X = torch.tensor([[int(b) for b in f"{i:04b}"]
                 for i in range(16)], dtype=torch.float32)
Y = X[:, [0, 2, 1, 3]]  # diagonal flip

# ── Train with vanilla SGD ──
torch.manual_seed(42)
model_sgd = FlipNet()
opt_sgd = torch.optim.SGD(model_sgd.parameters(), lr=0.05)

# ── Train with SGD + Momentum ──
torch.manual_seed(42)
model_mom = FlipNet()
opt_mom = torch.optim.SGD(model_mom.parameters(),
                          lr=0.05, momentum=0.9)

for epoch in range(30):
    loss_sgd = loss_mom = 0.0
    for i in range(16):
        # --- Vanilla SGD ---
        pred = model_sgd(X[i])
        lv = torch.mean((pred - Y[i]) ** 2)
        opt_sgd.zero_grad()
        lv.backward()
        opt_sgd.step()
        loss_sgd += lv.item()
        # --- Momentum ---
        pred = model_mom(X[i])
        lm = torch.mean((pred - Y[i]) ** 2)
        opt_mom.zero_grad()
        lm.backward()
        opt_mom.step()
        loss_mom += lm.item()
    if epoch % 5 == 0:
        print(f"Epoch {epoch:2d}:  SGD={loss_sgd/16:.4f}"
              f"  Momentum={loss_mom/16:.4f}")

The difference is dramatic. With momentum, the network reaches low loss in roughly half the epochs. The momentum buffer $\mathbf{v}$ for each of the 31 parameters accumulates gradient information across the 16 training images, effectively remembering "which way is downhill on average" rather than being yanked around by individual images.

PyTorch's SGD implementation uses the same formula we derived: $v_t = \beta \cdot v_{t-1} + g_t$, then $w_{t+1} = w_t - \eta \cdot v_t$. No surprises. The velocity buffer is stored internally by the optimizer and persists across .step() calls. Calling .zero_grad() clears .grad but does NOT clear the velocity buffer.
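Both claims can be checked directly against PyTorch. The sketch below is a minimal check on a toy two-element parameter, not part of the FlipNet experiment. It assumes the optimizer keeps its velocity under the "momentum_buffer" key of opt.state, which is how recent PyTorch releases store it.

```python
import torch

w = torch.nn.Parameter(torch.tensor([1.0, -2.0]))
opt = torch.optim.SGD([w], lr=0.05, momentum=0.9)

v_manual = torch.zeros(2)           # hand-tracked velocity
w_manual = w.detach().clone()       # hand-tracked weights

for _ in range(3):
    loss = (w ** 2).sum()           # toy loss; its gradient is 2 * w
    opt.zero_grad()                 # resets .grad, not the velocity buffer
    loss.backward()
    g = w.grad.detach().clone()
    opt.step()

    v_manual = 0.9 * v_manual + g           # v_t = beta * v_{t-1} + g_t
    w_manual = w_manual - 0.05 * v_manual   # w_{t+1} = w_t - lr * v_t

# Assumption: the velocity lives in opt.state[param]["momentum_buffer"].
buffer_before = opt.state[w]["momentum_buffer"].clone()
opt.zero_grad()
buffer_after = opt.state[w]["momentum_buffer"]

print(torch.allclose(w.detach(), w_manual))       # optimizer matches the hand update
print(torch.equal(buffer_before, buffer_after))   # zero_grad() left the buffer alone
```

If both lines print True, the optimizer's weights match the hand-written update and the buffer was left untouched by zero_grad().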

Nesterov Accelerated Gradient

In 1983, Yurii Nesterov proposed a subtle but powerful improvement to momentum. The idea: instead of computing the gradient at the current position, compute it at the position you're about to jump to.

Standard Momentum vs Nesterov

Standard momentum: "I'm at position $\mathbf{w}$. What's the gradient here? OK, combine it with my velocity and move."

Nesterov momentum: "My velocity will carry me to roughly $\mathbf{w} - \eta \beta \mathbf{v}$. What's the gradient there? That's a better gradient because it accounts for where I'm actually going."

The Nesterov update:

$$\mathbf{v}_t = \beta \cdot \mathbf{v}_{t-1} + \nabla L(\mathbf{w}_t - \eta \beta \mathbf{v}_{t-1})$$

$$\mathbf{w}_{t+1} = \mathbf{w}_t - \eta \cdot \mathbf{v}_t$$

This "look-ahead" gradient gives Nesterov momentum a corrective quality. If the velocity is about to overshoot the minimum, the look-ahead gradient will already be pointing back — so the velocity starts slowing down before the overshoot happens instead of after.
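The corrective effect of the look-ahead gradient can be seen on a one-dimensional toy loss. The sketch below implements the two update rules as written above; the loss L(w) = 0.5·w², the learning rate of 0.1 and the function name run are assumptions made for illustration.

```python
def run(nesterov, steps=60, lr=0.1, beta=0.9, w0=1.0):
    """Minimise L(w) = 0.5 * w**2 (gradient = w) and record the path of w."""
    grad = lambda x: x
    w, v, path = w0, 0.0, [w0]
    for _ in range(steps):
        # Nesterov evaluates the gradient at the look-ahead point w - lr*beta*v.
        point = w - lr * beta * v if nesterov else w
        v = beta * v + grad(point)
        w = w - lr * v
        path.append(w)
    return path

path_std = run(nesterov=False)
path_nag = run(nesterov=True)
print(f"deepest overshoot, standard: {min(path_std):.3f}")
print(f"deepest overshoot, Nesterov: {min(path_nag):.3f}")
```

Both runs start at w = 1 and swing past the minimum at w = 0. The Nesterov run overshoots less, because its gradient starts pushing back before the weights have crossed zero.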

In PyTorch: One Flag

Enabling Nesterov momentum in PyTorch requires exactly one change:

torch.optim.SGD(params, lr=0.05, momentum=0.9, nesterov=True)

PyTorch only accepts nesterov=True together with a nonzero momentum and the default dampening=0.

| Optimizer | PyTorch code | Update rule |
| --- | --- | --- |
| Vanilla SGD | SGD(params, lr=0.05) | w = w - η·grad |
| Momentum | SGD(params, lr=0.05, momentum=0.9) | v = βv + grad; w = w - ηv |
| Nesterov | SGD(params, lr=0.05, momentum=0.9, nesterov=True) | Look-ahead gradient, then momentum |
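Libraries usually avoid a separate look-ahead forward pass. A common reformulation, which to my understanding is the one PyTorch uses for nesterov=True, stores the look-ahead point itself as the parameter and applies w ← w - η(grad + βv). The sketch below checks numerically that this form tracks the look-ahead equations on a toy loss; the names theta and buf are illustrative.

```python
def grad(x):
    """Gradient of the toy loss L(x) = 0.5 * x**2."""
    return x

lr, beta, steps = 0.1, 0.9, 20

# (a) Look-ahead form: gradient evaluated at w - lr*beta*v.
w, v = 1.0, 0.0
# (b) Reformulated form: theta plays the role of the look-ahead point.
theta, buf = 1.0, 0.0

max_gap = 0.0
for _ in range(steps):
    # theta should equal the point where form (a) evaluates its gradient
    max_gap = max(max_gap, abs(theta - (w - lr * beta * v)))

    v = beta * v + grad(w - lr * beta * v)
    w = w - lr * v

    g = grad(theta)
    buf = beta * buf + g
    theta = theta - lr * (g + beta * buf)

print(f"largest gap between the two forms: {max_gap:.2e}")
```

The gap stays at floating-point round-off. The two forms describe the same trajectory, and differ only in which point is stored as "the weights".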

In practice, Nesterov often gives a small improvement over standard momentum. It is especially helpful when the loss surface has sharp turns: Nesterov "sees around the corner" while standard momentum overshoots and has to backtrack.

When to use Nesterov: if you are already using SGD with momentum, switching to nesterov=True is essentially free (same computational cost). Many practitioners use it by default. The improvement is usually modest (on the order of 5–10% faster convergence, and problem-dependent).

Connection to Modern Training

SGD with momentum is not just a historical artifact — it remains deeply relevant in modern deep learning:

Momentum as the Foundation

Adam (Section 2) extends momentum with per-parameter adaptive learning rates. Adam maintains two exponential moving averages: the first moment (the gradient mean, like momentum) and the second moment (the mean of squared gradients, an uncentered variance). Understanding momentum is essential to understanding Adam.

Large-scale training of models like GPT, LLaMA, and BERT actually uses Adam or AdamW (Adam with weight decay), which build directly on the momentum concept. The velocity buffer idea is unchanged — Adam just adds a second buffer for gradient magnitude.
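As a preview of how the two buffers fit together, here is a minimal single-parameter sketch of the standard Adam update. It is an illustration only: the toy loss L(w) = w² and the learning rate of 0.1 are assumptions chosen to make the first step easy to read, and Section 2 covers the method properly.

```python
import math

def adam_step(w, g, m, v, t, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter (t counts from 1)."""
    m = beta1 * m + (1 - beta1) * g          # first moment: momentum-style average
    v = beta2 * v + (1 - beta2) * g * g      # second moment: average of squared gradients
    m_hat = m / (1 - beta1 ** t)             # bias correction for the zero initialisation
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, m, v

w, m, v = 1.0, 0.0, 0.0
g = 2.0 * w                                  # gradient of the toy loss L(w) = w**2
w, m, v = adam_step(w, g, m, v, t=1)
print(f"w after one step: {w:.4f}   m = {m:.4f}   v = {v:.4f}")
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the weight moves by almost exactly the learning rate, whatever the gradient's size. That normalisation is what the second buffer adds on top of momentum.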

SGD vs Adam in Practice

| Aspect | SGD + Momentum | Adam |
| --- | --- | --- |
| Convergence speed | Slower to tune, but excellent final loss | Fast convergence, good default settings |
| Generalization | Often better on held-out data | Can overfit more easily |
| Hyperparameter sensitivity | Needs careful lr tuning | More robust to lr choice |
| Memory overhead | 1 buffer per parameter (v) | 2 buffers per parameter (m, v) |
| Modern usage | Image classification (ResNets) | NLP, transformers, most research |

Interestingly, well-tuned SGD with momentum often achieves better generalization than Adam on computer vision tasks. This is why many image classification papers (ResNets, for example) still use SGD+momentum. One hypothesis: Adam's per-parameter adaptation finds sharper minima that generalize worse.

Connection to Transformer Training

When training transformers, momentum appears in several places:

  • AdamW optimizer: the standard optimizer for transformer training maintains a momentum buffer $m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t$ with $\beta_1 = 0.9$. This is the same exponential averaging we studied in this section, with an extra $(1-\beta_1)$ factor that rescales the buffer into a true weighted average
  • Learning rate warmup: transformers typically warm up the learning rate from 0 to the target value over the first 1,000–4,000 steps, then decay it. This interacts with momentum because the velocity buffer needs time to build up
  • Gradient accumulation: when training with batch sizes of 1M+ tokens (e.g. Chinchilla, LLaMA), gradients are accumulated across micro-batches. The accumulated gradient acts like a batch gradient, and the optimizer's momentum provides additional temporal smoothing on top
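The relationship between the SGD-momentum buffer from this section and the Adam-style buffer in the first bullet can be sketched in a few lines. The gradient values below are arbitrary illustrative numbers.

```python
grads = [0.5, -1.2, 0.8, 2.0, -0.3, 1.1]   # arbitrary illustrative gradients
beta = 0.9

v = 0.0   # SGD-momentum buffer:  v = beta * v + g
m = 0.0   # Adam-style buffer:    m = beta * m + (1 - beta) * g
for g in grads:
    v = beta * v + g
    m = beta * m + (1 - beta) * g

print(f"v = {v:.4f}, m = {m:.4f}, m / v = {m / v:.4f}")
```

The two buffers differ only by the constant factor (1 - β). The direction they encode is identical; the scale difference is effectively absorbed into the learning rate.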

Summary

  1. SGD processes mini-batches instead of the full dataset, giving noisy but frequent gradient updates. Mini-batch SGD is the default in practice.
  2. The ravine problem: when curvatures differ across directions, vanilla SGD oscillates in steep directions and crawls in gentle ones. No single learning rate works well for all directions.
  3. Momentum adds a velocity buffer $v_t = \beta v_{t-1} + \nabla L$ that is an exponentially weighted average of past gradients (up to a constant factor). Consistent directions accelerate; oscillating directions dampen.
  4. The standard setting is $\beta = 0.9$, which weights roughly the last 10 gradients. On ill-conditioned (ravine-like) problems this can give roughly a 3–5× convergence speedup over vanilla SGD; the exact gain depends on the loss surface.
  5. In PyTorch: just add momentum=0.9 to your SGD optimizer. For Nesterov's improvement, add nesterov=True.
  6. Momentum is the foundation of Adam, which adds per-parameter adaptive learning rates. We cover Adam in Section 2.
Looking ahead: Momentum gives every parameter the same effective learning rate (just smoothed). But what if some parameters need a larger step than others? That's the idea behind adaptive learning rate methods like Adam, which we explore in Section 2.