Chapter 1

Automatic Differentiation and the Chain Rule

Mathematical Bedrock

Introduction

Every massive language model you have ever used (GPT-4, Claude, Gemini, Llama, DeepSeek) was trained by repeating one operation hundreds of thousands to millions of times: compute the gradient of a scalar loss with respect to every parameter, then take a small step downhill. The optimizer choice changes, the loss changes, the architecture changes, the data changes. The gradient step does not. It is the heart muscle of deep learning.

For a 70-billion-parameter model that means computing seventy billion partial derivatives per training step. By hand it is unthinkable. By finite differences it would take more compute than humanity has ever produced. The reason it is even possible — and roughly free, in the sense that the gradient costs only a constant factor more than the forward pass — is a single 1970s idea called reverse-mode automatic differentiation, built on top of one rule from first-year calculus: the chain rule.

Why this matters: Backpropagation is not a separate algorithm bolted onto neural networks. It is reverse-mode autodiff applied to a computational graph. Once you see the graph and the chain rule together, every variation — gradient checkpointing, mixed precision, FSDP, gradient accumulation, even Adam itself — becomes a memory or numerics trick around the same primitive.

1.4.1 The Real Problem: A Hundred Billion Knobs

Training a deep network is, mathematically, an enormous optimisation problem. We pick a loss function $L(\theta)$ that measures how badly the model is doing, where $\theta$ is the vector of all the model's parameters (every weight matrix, every bias, every layer-norm scale, every embedding vector) packed end to end. For a modern frontier model, $\theta \in \mathbb{R}^N$ with $N$ on the order of $10^{11}$.

To improve the model we need to know, for each of those $10^{11}$ numbers, the answer to one question: if I nudge this one knob up by a tiny amount, does the loss go up or down, and by how much? That answer is the partial derivative $\partial L / \partial \theta_i$. Stack all of them into a vector and you get the gradient $\nabla_\theta L \in \mathbb{R}^N$.

Here is the brutal arithmetic. If you tried to estimate each partial derivative the naive way (perturb one parameter, re-run the entire forward pass, see how much the loss changed) you would do $N$ forward passes per gradient. For a 70B model with a forward pass that takes ~500 ms on an H100 cluster, one gradient step would take $7 \times 10^{10} \cdot 0.5\,\text{s} \approx 1{,}100\;\text{years}$.

Finite differences do not scale. They cost $O(N)$ forward passes per gradient and are also numerically unstable: too small a perturbation drowns in floating-point noise, too large a perturbation is contaminated by second-order curvature.
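To make both problems concrete, here is a minimal sketch of a forward-difference gradient on a toy loss. The names `loss` and `finite_difference_grad` are illustrative assumptions, not part of any library. The call counter shows the one-extra-forward-pass-per-parameter cost, and the loop at the end shows the error growing at both extremes of the step size.

```python
# Illustrative sketch: forward-difference gradient of a toy loss.
# `loss` and `finite_difference_grad` are hypothetical names for this example.

def loss(theta):
    """Toy scalar loss: sum of cubes, so dL/dtheta_i = 3 * theta_i**2."""
    return sum(t ** 3 for t in theta)

def finite_difference_grad(f, theta, eps=1e-6):
    """Estimate the gradient with one extra forward pass per parameter."""
    calls = 1
    base = f(theta)                      # one pass at the unperturbed point
    grad = []
    for i in range(len(theta)):
        bumped = list(theta)
        bumped[i] += eps                 # perturb a single parameter
        grad.append((f(bumped) - base) / eps)
        calls += 1
    return grad, calls

theta = [1.0, 2.0, 3.0]
grad, calls = finite_difference_grad(loss, theta)
exact = [3 * t ** 2 for t in theta]
print(calls)  # N + 1 = 4 forward passes for N = 3 parameters

# Step-size sensitivity: compare a large, a moderate, and a tiny step.
errors = {}
for eps in (1e-1, 1e-6, 1e-15):
    approx, _ = finite_difference_grad(loss, theta, eps)
    errors[eps] = max(abs(a - e) for a, e in zip(approx, exact))
    print(f"eps={eps:g}  max abs error={errors[eps]:.3g}")
```

The moderate step gives the smallest error of the three; the large step is dominated by curvature and the tiny one by floating-point rounding.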

Reverse-mode autodiff gives us the entire gradient, all $10^{11}$ components, in one forward pass plus one backward pass, with the backward pass costing a small constant multiple of the forward pass (commonly around two times). So the total cost is $O(\text{cost of } L)$, not $O(N \cdot \text{cost of } L)$. This is the only reason deep learning is computationally possible.


1.4.2 Four Ways to Compute a Derivative

It helps to see why reverse-mode wins by laying it next to its competitors. There are four classical approaches:

| Method | What it does | Cost for N inputs, 1 output | Used in deep learning? |
|---|---|---|---|
| Symbolic | Treat the program as an algebraic expression and apply differentiation rules to produce a closed-form formula. | Formula explodes ("expression swell"), quickly unmanageable. | No. |
| Numerical (finite differences) | Perturb each input slightly, re-evaluate, divide: (L(θ+ε) − L(θ))/ε. | O(N) forward passes per gradient. Numerically noisy. | Only for gradient-checking. |
| Forward-mode autodiff | Carry a 'dual number' (value, derivative) through every op; differentiate as you go. | O(N) passes: one extra pass per input direction. | Yes, but only for tiny input dimension. |
| Reverse-mode autodiff | Record the computation as a DAG on the forward pass, then walk it in reverse applying the chain rule. | O(1) passes: one backward pass for the whole gradient. | Yes, this is backpropagation. |

For a function with many inputs and one output (which is exactly the deep-learning shape: many parameters, one loss), reverse-mode is the only viable choice. The rest of this section is dedicated to building it from first principles.


1.4.3 The Idea: Calculus on a Graph

The first idea is to stop thinking of the model as a giant tangled formula and start thinking of it as a directed acyclic graph. Every primitive operation — add, multiply, matmul, exp, log, softmax, GELU — is a node. Edges carry intermediate values. Inputs flow in at one side, the loss pops out at the other.

Take an absurdly small example: $y = (Wx + b)^2$. As a graph that is six nodes, three leaves ($W$, $x$, $b$) and three op outputs ($u$, $v$, $y$):

W ── × ── u ── + ── v ── (·)² ── y
x ───┘    b ───┘

The forward pass evaluates left to right: with $W = 2$, $x = 3$, $b = 1$, we get $u = 6$, $v = 7$, $y = 49$. Each node remembers two things: its output value, and a pointer to the small local rule that knows how to differentiate it.

The backward pass walks the same graph from right to left. At each node we have already received an incoming derivative from downstream (the gradient of the loss w.r.t. the node's output) and we use the local rule to convert it into derivatives w.r.t. the node's inputs. Repeat until we hit the leaves. Done.

The mental model: forward = a fluid filling the graph from inputs to loss; backward = the same fluid running the other way carrying gradient instead of value. Nothing about the graph itself changes; only the direction of flow.

1.4.4 The Chain Rule, Symbol by Symbol

The local rule at every node is the chain rule. In its scalar form:

$$\frac{dL}{dx} \;=\; \frac{dL}{du}\,\cdot\,\frac{du}{dx}.$$

Read it like a relay race. $dL/du$ is the baton handed to us by the downstream runner: we never compute it, we just receive it. $du/dx$ is the local derivative we do know because we wrote the op (it is just the derivative of this one primitive). Multiplying the two gives the baton we hand to the next runner upstream.

For an op with multiple inputs the same rule applies to each input in turn, with a partial derivative in place of the ordinary one: if $u = f(x_1, x_2, \dots, x_k)$, then

$$\frac{dL}{dx_i} \;=\; \frac{dL}{du}\,\cdot\,\frac{\partial u}{\partial x_i}.$$

And for the case where one variable feeds many downstream ops (say $x$ is read by both $u_1$ and $u_2$) the rule becomes a sum over children, the multivariable chain rule:

$$\frac{dL}{dx} \;=\; \sum_{j}\,\frac{dL}{du_j}\,\cdot\,\frac{\partial u_j}{\partial x}.$$

That single equation is the whole engine. Every other rule ($d(a+b) = da + db$, $d(ab) = a\,db + b\,da$, $d(\sigma(z)) = \sigma(z)(1-\sigma(z))\,dz$, the matrix calculus identities behind matmul) is just a way of specialising it to one primitive at a time.
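As a quick check of the sum-over-children rule, the sketch below uses a made-up function in which $x$ feeds two ops, $u_1 = x^2$ and $u_2 = 3x$, with $L = u_1 u_2$. The function and variable names are assumptions chosen for illustration only.

```python
# Illustrative check of the sum-over-children rule on L = u1 * u2,
# with u1 = x**2 and u2 = 3*x, so x feeds two downstream ops.

x = 2.0

# Forward pass.
u1 = x ** 2        # 4.0
u2 = 3.0 * x       # 6.0
L = u1 * u2        # 24.0

# Batons arriving at u1 and u2 from the product node.
dL_du1 = u2
dL_du2 = u1

# Local derivatives of each child with respect to x.
du1_dx = 2.0 * x
du2_dx = 3.0

# Multivariable chain rule: sum the contribution of every child.
dL_dx = dL_du1 * du1_dx + dL_du2 * du2_dx

# Closed form for comparison: L = 3*x**3, so dL/dx = 9*x**2.
print(dL_dx, 9.0 * x ** 2)   # 36.0 36.0
```

Dropping either term of the sum would give 24 or 12 instead of 36, which is the kind of silent error the accumulation step in a real engine has to prevent.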

For vector-valued nodes (which is the normal case in a real network) the same rule holds, with the local derivative now a Jacobian $\partial u/\partial x \in \mathbb{R}^{m \times n}$ and the upstream baton a vector. We almost never materialise these Jacobians explicitly; instead we use vector-Jacobian products $v^\top J$, which is exactly what the local backward closures compute.
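A small sketch of what "never materialise the Jacobian" means in practice, using an elementwise tanh layer as the example op. The helper names `tanh_vjp` and `tanh_jacobian` are hypothetical; the explicit Jacobian is built here only to confirm that the cheap product gives the same answer.

```python
import math

# Illustrative sketch: the backward rule of an elementwise tanh layer.
# `tanh_vjp` and `tanh_jacobian` are hypothetical helper names.

def tanh_forward(x):
    return [math.tanh(xi) for xi in x]

def tanh_vjp(u, v):
    """Return v^T J with O(n) work; J itself is never built.
    For u = tanh(x), J is diagonal with entries 1 - u_i**2."""
    return [vi * (1.0 - ui ** 2) for ui, vi in zip(u, v)]

def tanh_jacobian(u):
    """Explicit n x n Jacobian, built only to check the VJP."""
    n = len(u)
    return [[(1.0 - u[i] ** 2) if i == j else 0.0 for j in range(n)]
            for i in range(n)]

x = [0.5, -1.0, 2.0]
v = [1.0, 2.0, 3.0]          # upstream baton dL/du
u = tanh_forward(x)

fast = tanh_vjp(u, v)        # what a backward closure would compute
J = tanh_jacobian(u)
slow = [sum(v[i] * J[i][j] for i in range(3)) for j in range(3)]

print(all(abs(a - b) < 1e-12 for a, b in zip(fast, slow)))   # True
```

For a layer of width n the explicit Jacobian has n² entries, while the product needs only n multiplications.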


1.4.5 Forward Mode vs Reverse Mode

Both modes apply the chain rule; they differ only in the order in which they multiply the Jacobians. Suppose the network is the composition $L = f_K \circ f_{K-1} \circ \cdots \circ f_1(\theta)$ with Jacobians $J_1, J_2, \dots, J_K$. The total Jacobian is the product $J = J_K\,J_{K-1}\,\cdots\,J_1$. Matrix product is associative, so we can multiply from either end:

  1. Forward mode computes $(J_K\,(J_{K-1}\,(\cdots\,(J_1\,e_i))))$: it picks one input direction $e_i$ and pushes it forward through every layer. Cost: one full forward-style pass per input.
  2. Reverse mode computes $(((v^\top\,J_K)\,J_{K-1})\,\cdots)\,J_1$: it starts from a cotangent on the output (for a scalar loss this is just the number 1) and pulls it back through every layer. Cost: one full backward-style pass per output.

Deep learning gives us many inputs (parameters) and one output (loss). Forward mode would cost one pass per parameter, which is back to the $O(N)$ nightmare. Reverse mode costs one pass for the entire gradient. That asymmetry is the whole game.
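The per-input cost of forward mode is easy to see in a minimal dual-number sketch. The `Dual` class below is a hypothetical helper written for this illustration, supporting only addition and multiplication; it is applied to the toy function $(Wx + b)^2$ used in this section.

```python
# Illustrative sketch of forward-mode autodiff with dual numbers.
# The `Dual` class is a hypothetical helper written for this example.

class Dual:
    def __init__(self, value, tangent=0.0):
        self.value = value      # primal value
        self.tangent = tangent  # derivative along the chosen input direction

    def __add__(self, other):
        return Dual(self.value + other.value, self.tangent + other.tangent)

    def __mul__(self, other):
        # Product rule, applied as the forward computation proceeds.
        return Dual(self.value * other.value,
                    self.tangent * other.value + self.value * other.tangent)

def f(W, x, b):
    v = W * x + b
    return v * v                # (W*x + b)^2

inputs = {"W": 2.0, "x": 3.0, "b": 1.0}
grads, passes = {}, 0
for name in inputs:
    # Seed tangent 1.0 on one input and 0.0 on the others: the direction e_i.
    args = {k: Dual(val, 1.0 if k == name else 0.0)
            for k, val in inputs.items()}
    out = f(**args)
    grads[name] = out.tangent
    passes += 1

print(grads, passes)   # {'W': 42.0, 'x': 28.0, 'b': 14.0} 3
```

Three inputs need three passes; with $N$ parameters the loop would run $N$ times, which is the asymmetry described above.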

Reverse mode buys cheap gradients with expensive memory. To run backward we must keep every forward intermediate value around. For a 70B transformer on a 4096-token sequence, those activations weigh hundreds of gigabytes. Almost every modern training trick — gradient checkpointing, ZeRO/FSDP, mixed precision — is a way to soften that memory cost without losing the compute win. We will see why in §1.4.11.

1.4.6 Interactive: Build a Computational Graph

Pick an example below and step through the forward and backward passes one node at a time. Watch how the values fill in on the way forward and the gradients fill in on the way back. The arrows on the edges flip direction between the two phases — that flip is literally the difference between value flow and gradient flow.


1.4.7 Interactive: Watch the Chain Rule Multiply

Pick a composed function such as $y = (3x+1)^2$ or $y = \sin(x^2)$ and drag the input. The visualiser shows each intermediate function and its local derivative; the final gradient at the bottom is the product of all the local derivatives. That product is the chain rule, made visible.


1.4.8 Manual Numerical Walkthrough

Time to derive a gradient with a pencil. We will use the same toy as the graph above: $y = (Wx + b)^2$ with $W = 2$, $x = 3$, $b = 1$, and we will treat $y$ itself as the loss. The numbers are small enough to fit in your head, but the structure is exactly the same as a billion-parameter run; only the dimensions change.


Set up the graph. Three leaves (W, x, b) feed three ops (multiply, add, square) producing three intermediates (u, v, y).

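| Node | Formula | Value | Local derivative w.r.t. each input |
|---|---|---|---|
| W | leaf | 2 | n/a |
| x | leaf | 3 | n/a |
| b | leaf | 1 | n/a |
| u | W · x | 6 | ∂u/∂W = x = 3, ∂u/∂x = W = 2 |
| v | u + b | 7 | ∂v/∂u = 1, ∂v/∂b = 1 |
| y | v² | 49 | ∂y/∂v = 2v = 14 |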

Forward pass. Plug values left to right: $u = 2 \cdot 3 = 6$, $v = 6 + 1 = 7$, $y = 7^2 = 49$. Done. Every intermediate value is now memorised by the graph.

Seed the backward pass. The loss is $y$ itself, so $dL/dy = 1$. This is the cotangent that we push backward.

Step 1: through the square. Local rule: $dy/dv = 2v = 14$. Apply the chain rule:

$$\frac{dL}{dv} = \frac{dL}{dy} \cdot \frac{dy}{dv} = 1 \cdot 14 = 14.$$

Step 2: through the addition. Local rules: $dv/du = 1$, $dv/db = 1$. Addition does not divide the gradient; each of the two inputs receives an unchanged copy:

$$\frac{dL}{du} = \frac{dL}{dv} \cdot 1 = 14,\qquad \frac{dL}{db} = \frac{dL}{dv} \cdot 1 = 14.$$

Step 3: through the multiplication. Local rules: $du/dW = x = 3$, $du/dx = W = 2$. So

$$\frac{dL}{dW} = \frac{dL}{du} \cdot x = 14 \cdot 3 = 42,\qquad \frac{dL}{dx} = \frac{dL}{du} \cdot W = 14 \cdot 2 = 28.$$

Final gradients (read from the leaves): $dL/dW = 42$, $dL/dx = 28$, $dL/db = 14$. Three numbers, computed by three local rules and a handful of multiplications. No giant formula, no nested derivatives.

Sanity check by hand. Expand $y = (Wx+b)^2$ symbolically: $dy/dW = 2(Wx+b)\cdot x = 2 \cdot 7 \cdot 3 = 42$. Same number. The graph walk did exactly the symbolic derivative, but it would have been just as easy with $10^{11}$ nodes instead of three.
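The same sanity check can be done numerically, which is how gradient checking is usually done when no closed form is at hand. The sketch below compares the three hand-derived numbers against central differences; `f` and `central_diff` are hypothetical names for this example.

```python
# Illustrative check of the hand-derived gradients by central differences.
# `f` and `central_diff` are hypothetical names for this sketch.

def f(W, x, b):
    return (W * x + b) ** 2

def central_diff(fn, args, i, eps=1e-5):
    """Symmetric difference quotient with respect to argument i."""
    hi = list(args)
    hi[i] += eps
    lo = list(args)
    lo[i] -= eps
    return (fn(*hi) - fn(*lo)) / (2 * eps)

args = [2.0, 3.0, 1.0]                 # W, x, b
by_hand = [42.0, 28.0, 14.0]           # dL/dW, dL/dx, dL/db from the walkthrough
numeric = [central_diff(f, args, i) for i in range(3)]

for name, h, n in zip("Wxb", by_hand, numeric):
    print(f"d/d{name}: by hand {h}, numeric {n:.6f}")
```

Because the toy function is quadratic in each argument, the central difference has no truncation error here and any remaining gap is floating-point rounding.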


1.4.9 Plain Python: A Tiny Autograd Engine

We will now build the entire engine from §1.4.8 in about sixty lines of Python — no PyTorch, no NumPy, just floats and closures. The class is called Value and it is deliberately written in the Karpathy-micrograd style so that every line maps to one idea from the last few subsections.

Read the listing once top-to-bottom for the shape, then use the line-by-line notes for the connection back to the math.

Tiny reverse-mode autodiff engine
value.py
Line 1: Define the Value class

A Value is one node in the computational graph — a single scalar that knows how it was produced. Every primitive (add, multiply, square) returns a new Value whose .data is the result and whose ._backward knows how to push gradient to its parents.

EXAMPLE
Value(3.0)._op == '' (a leaf, no operation)
Line 2: __init__ signature

data is the forward value (a Python float). parents lists the Values that this one was computed from (empty tuple for inputs). op is a debug string so you can print the graph and tell which rule built each node.

EXAMPLE
Value(6.0, (W, x), '*')  ⇒ a node holding 6.0, born from W*x
Line 3: self.data, the forward number

The actual scalar value flowing forward through the graph. For W=2, x=3, b=1, after y = (W*x + b)**2 the .data fields along the way are 2, 3, 6, 1, 7, 49.

Line 4: self.grad, the partial derivative slot

Starts at 0.0 and is filled in by backward(). After backward(), self.grad equals dL/dself where L is whatever you called backward() on. We accumulate (+=) because in deep graphs a node can have many children, each contributing its own term to the chain-rule sum.

EXAMPLE
After y.backward() with our toy: W.grad = 42.0, b.grad = 14.0.
Line 5: self._parents, incoming edges

The tuple of Value objects that produced this one. The backward pass walks parents in topological order so every child's gradient is finalised before we ask it to push to its parents.

EXAMPLE
For y = u**2, y._parents == (u,) — exactly one parent.
Line 6: self._op, debug label

Pure documentation: '+', '*', '**2'. Real engines like PyTorch store an actual gradient-function object here (AddBackward0, MulBackward0, PowBackward0) which knows how to differentiate the op.

Line 7: self._backward, the local rule

A closure that, when called, reads out.grad (already filled in for this node) and adds the right contribution into each parent's .grad. Every primitive overrides this. Leaves keep the no-op lambda.

EXAMPLE
For out = a + b, _backward does a.grad += out.grad; b.grad += out.grad.
Line 9: __add__, overload Python's + operator

Called when you write a + b for two Values. We create a fresh Value whose .data is a.data + b.data and whose parents are (self, other). The local backward rule comes next.

EXAMPLE
Value(2.0) + Value(3.0)  →  Value(5.0, parents=(...,...), op='+')
Line 10: Forward, compute the sum

Builds the output node. self.data + other.data is plain Python arithmetic — no magic. The result Value remembers its parents so backward can find them.

EXAMPLE
out.data = self.data + other.data = 6.0 + 1.0 = 7.0 in the toy.
Line 11: Define the local backward closure

A nested function that captures self, other, and out. It will run during backward() once out.grad has been filled in by everything downstream of it.

Line 12: Push gradient to self

d(self + other)/d(self) = 1, so self.grad gets one full copy of out.grad. We use += because self might be a parent of many other nodes; their contributions must sum.

EXAMPLE
If out.grad = 14, this line does self.grad += 1.0 * 14 → +=14.
Line 13: Push gradient to other

Symmetric to the previous line. d(self + other)/d(other) = 1, so other.grad picks up out.grad as well.

EXAMPLE
In the toy this is exactly how b ends up with .grad = 14.
Line 14: Attach the closure

Stash _backward on the output node so backward() can call it later. The closure has already captured self and other, so calling it needs no arguments.

Line 15: Return the new node

out is itself a Value, so it can be added, multiplied, or squared again. This is how the graph grows: every primitive returns a new node whose parents wire it into the DAG.

Line 17: __mul__, overload *

Same recipe as addition but for multiplication. We compute self.data * other.data forward, then derive each parent's gradient from the product rule.

EXAMPLE
W * x with W.data=2, x.data=3 gives a Value(6.0, parents=(W,x), op='*').
Line 18: Forward, compute the product

Just float multiplication. In the toy: 2.0 * 3.0 = 6.0. That number flows downstream into the addition with b.

Line 19: Define the product-rule backward

We need d(a*b)/da = b and d(a*b)/db = a. The closure will use the cached .data of each parent — that is why we needed the forward pass first.

Line 20: Push grad to self, scaled by other.data

d(a*b)/da = b, so we multiply out.grad by other.data before adding. Crucially we use other.data (a number captured during forward), not other.grad: local derivatives are built from forward values, not from other gradients.

EXAMPLE
For W * x, self is W and other is x. With x=3 and out.grad = 14, W.grad += 3 * 14 = 42, exactly our hand-derived answer.
Line 21: Push grad to other, scaled by self.data

Mirror image of the previous line. d(a*b)/db = a, so other.grad picks up self.data * out.grad.

EXAMPLE
Same out.grad = 14: with W=2, x.grad += 2 * 14 = 28.
Line 22: Attach closure to the output

Bind the local rule onto out so the engine can find it during backward().

Line 23: Return the product node

We hand back a Value that knows both its forward number and the recipe for splitting its incoming gradient between its two parents.

Line 25: pow2, the squaring primitive

We only need x → x² for our toy, so we expose it as a method instead of a general __pow__. This isolates the rule d(x²)/dx = 2x cleanly.

EXAMPLE
Value(7.0).pow2() → Value(49.0, parents=(<7.0 node>,), op='**2').
Line 26: Forward, square the data

Plain Python: self.data ** 2. In the toy, self.data = 7.0 here, so out.data = 49.0.

Line 27: Define power-rule backward

Squaring has a single parent (self), so the closure only needs to handle one accumulation.

Line 28: Power-rule line

d(x²)/dx = 2x. We multiply the upstream gradient out.grad by 2*self.data. In the toy, self.data = 7 and out.grad starts as 1 (because out IS y), so this writes 2*7*1 = 14 into v.grad — exactly dL/dv from the hand calculation.

EXAMPLE
v = u + b = 7;  y = v² = 49;  dL/dv = 2v · dL/dy = 14 · 1 = 14.
Line 29: Bind the closure

Attach the rule. Same pattern as __add__ and __mul__.

Line 30: Return the squared node

Hand back the new Value. Because the result is itself a Value, you could keep going: y.pow2().pow2() would just keep extending the graph.

Line 32: backward, the engine entry point

This is the method you call on the loss. It runs three phases: build a topological order, seed the loss gradient with 1, then walk the order in reverse calling each node's _backward closure.

Line 34: Topological-order setup

order will hold every Value in the graph such that parents come before children. seen is a set used as a 'visited' marker so each node is processed exactly once even if it has many children.

Line 35: DFS visitor, function definition

A recursive helper that walks the DAG from the loss upward through ._parents, adding each node to order only after its parents have already been added. This is the standard topological sort via DFS post-order.

Line 36: Skip already-visited nodes

Without this guard, a node shared by multiple children would be appended multiple times and we'd call its _backward more than once — double-counting its gradient.

Line 37: Mark visited

Add v to seen before recursing so cycles (none in a DAG, but still) cannot loop.

Line 38: Recurse into parents

For every parent p of v, visit p first. This guarantees that when we eventually append v, every node it depends on is already in order.

Line 39: Recursive call

Standard DFS recursion. Python's default recursion limit (~1000) is fine for our toy; PyTorch implements the same idea iteratively for graphs with millions of nodes.

Line 40: Post-order append

Crucially, we append AFTER the recursive call returns. That gives us a list where index 0 is the deepest input and the last element is the node we called backward() on.

Line 41: Kick off the visit from the root

We call visit(self) — self is the loss. After this line, order is the full topological ordering with the loss at the end.

Line 44: Seed dL/dL = 1

The loss differentiates against itself to 1. Setting self.grad = 1.0 is the boundary condition that makes the whole chain rule start; every other gradient is derived from this.

EXAMPLE
If you want to compute gradient of (loss * 2), you would seed with 2 instead.
Line 47: Walk in reverse

Iterate from the loss back toward the inputs. By the time we reach a node, every one of its children has already pushed their contribution into its .grad, so the local backward rule sees a fully-formed out.grad.

Line 48: Apply the local rule

Calling v._backward() reads v.grad and updates each parent's .grad. The call takes no arguments because each closure captured self / other / out when the op was created.

EXAMPLE
For our toy the calls that do real work are: y._backward (squaring) → (W*x+b)._backward (add) → (W*x)._backward (mul). The three leaves are visited too, but their _backward is the no-op lambda.
Line 52: Create the input scalars

Three leaf Values. They have empty parents and the default no-op _backward. Their .grad starts at 0.0 and will be filled in by backward().

EXAMPLE
W.data=2.0, x.data=3.0, b.data=1.0.
Line 53: x, the data input

Same construction as W. In a real network this would be activations from the previous layer, not a trainable parameter — but the engine treats them identically.

Line 54: b, the bias

Another leaf scalar. Together with W and x, these are the only three nodes whose .grad we ultimately care about.

Line 56: Build the forward graph

Reading left to right: W*x calls __mul__ (creates node u=6 with parents W, x), then + b calls __add__ (creates v=7 with parents u, b), then .pow2() creates y=49 with parent v. Six nodes total (three leaves, three op outputs), wired by the parent links.

EXAMPLE
Graph: W(2)  x(3) ──*── u(6)  b(1) ──+── v(7) ──**2── y(49)
Line 57: Trigger reverse-mode autodiff

Topo-sort the graph (order will be [W, x, u, b, v, y]), seed y.grad = 1, then call each node's _backward in reverse. After this line every leaf has a meaningful .grad.

Line 59: Forward value

Prints 49.0 — y = (2*3 + 1)² = 7² = 49.

Line 60: Gradient w.r.t. W

Prints 42.0. Hand check: dy/dW = 2(Wx+b)·x = 2·7·3 = 42. This is the number the optimizer would use to nudge W in the next training step.

Line 61: Gradient w.r.t. x

Prints 28.0. Hand check: dy/dx = 2(Wx+b)·W = 2·7·2 = 28. In a real network we usually do not update x (it is data), but downstream layers need this gradient to keep propagating.

Line 62: Gradient w.r.t. b

Prints 14.0. Hand check: dy/db = 2(Wx+b)·1 = 14. Matches the manual walkthrough exactly.

 1  class Value:
 2      def __init__(self, data, parents=(), op=""):
 3          self.data = data            # the actual number
 4          self.grad = 0.0             # gradient dL/dself, filled in by backward()
 5          self._parents = parents     # tuple of Value nodes that produced this one
 6          self._op = op               # debugging label: '+', '*', '**2', ...
 7          self._backward = lambda: None  # local rule that pushes grad to parents
 8
 9      def __add__(self, other):
10          out = Value(self.data + other.data, (self, other), "+")
11          def _backward():
12              self.grad  += 1.0 * out.grad   # d(self+other)/d(self)  = 1
13              other.grad += 1.0 * out.grad   # d(self+other)/d(other) = 1
14          out._backward = _backward
15          return out
16
17      def __mul__(self, other):
18          out = Value(self.data * other.data, (self, other), "*")
19          def _backward():
20              self.grad  += other.data * out.grad  # d(a*b)/da = b
21              other.grad += self.data  * out.grad  # d(a*b)/db = a
22          out._backward = _backward
23          return out
24
25      def pow2(self):
26          out = Value(self.data ** 2, (self,), "**2")
27          def _backward():
28              self.grad += 2.0 * self.data * out.grad   # d(x^2)/dx = 2x
29          out._backward = _backward
30          return out
31
32      def backward(self):
33          # 1) topological order: parents before children
34          order, seen = [], set()
35          def visit(v):
36              if v in seen: return
37              seen.add(v)
38              for p in v._parents:
39                  visit(p)
40              order.append(v)
41          visit(self)
42
43          # 2) seed: dL/dL = 1
44          self.grad = 1.0
45
46          # 3) walk backward and apply each local rule
47          for v in reversed(order):
48              v._backward()
49
50
51  # Run the exact toy from the walkthrough: y = (W*x + b)^2, with W=2, x=3, b=1.
52  W = Value(2.0)
53  x = Value(3.0)
54  b = Value(1.0)
55
56  y = (W * x + b).pow2()      # forward pass builds the graph as it goes
57  y.backward()                # reverse pass fills in every .grad
58
59  print(f"y      = {y.data}")    # 49.0
60  print(f"dy/dW  = {W.grad}")    # 42.0
61  print(f"dy/dx  = {x.grad}")    # 28.0
62  print(f"dy/db  = {b.grad}")    # 14.0

Three points worth pausing on. First, the engine never stores any formula, only numbers and a list of parent pointers. The chain rule emerges from the order in which we call closures, not from algebra. Second, self.grad uses += rather than =: a node can feed many children, and each child must contribute its term to the multivariable chain rule sum (§1.4.4). Third, the topological sort is non-negotiable: call the closures in the wrong order and some node will push its gradient to its parents before its own gradient has been finalised, silently losing terms.
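The second point is easy to demonstrate. The sketch below is a condensed variant of the Value class (multiplication only) with a hypothetical `accumulate` switch, added purely for illustration, that replaces `+=` with `=` so the lost term becomes visible.

```python
# Condensed variant of the Value class (multiplication only), with a
# hypothetical `accumulate` switch added purely to show what `=` would break.

class Value:
    def __init__(self, data, parents=(), accumulate=True):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._accumulate = accumulate
        self._backward = lambda: None

    def _push(self, parent, g):
        if self._accumulate:
            parent.grad += g     # correct: sum over all uses of the parent
        else:
            parent.grad = g      # wrong: the last writer wins

    def __mul__(self, other):
        out = Value(self.data * other.data, (self, other), self._accumulate)
        def _backward():
            out._push(self, other.data * out.grad)
            out._push(other, self.data * out.grad)
        out._backward = _backward
        return out

    def backward(self):
        order, seen = [], set()
        def visit(v):
            if v in seen:
                return
            seen.add(v)
            for p in v._parents:
                visit(p)
            order.append(v)
        visit(self)
        self.grad = 1.0
        for v in reversed(order):
            v._backward()

def grad_of_square(accumulate):
    x = Value(3.0, accumulate=accumulate)
    y = x * x                 # x is read twice by the same op
    y.backward()
    return x.grad

print(grad_of_square(True))   # 6.0, matches d(x^2)/dx = 2x
print(grad_of_square(False))  # 3.0, one of the two terms was overwritten
```

No exception is raised in the broken variant, which is why this class of bug is described above as silent.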


1.4.10 PyTorch: The Same Math at Industrial Scale

PyTorch's autograd is, structurally, our Value class with three upgrades: each node holds an N-dimensional tensor instead of one float; each _backward is a hand-tuned C++/CUDA kernel that handles the vector-Jacobian product without ever materialising the Jacobian; and the topological sort runs iteratively so it scales to millions of nodes. The user-facing API is deliberately almost identical to the toy.
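To illustrate the third upgrade, here is one way to write the same DFS post-order without recursion, using an explicit stack. This is a sketch of the general technique, not PyTorch's actual implementation; `Node` and `topo_order` are hypothetical names.

```python
# Illustrative sketch: DFS post-order without recursion.
# `Node` and `topo_order` are hypothetical names; this is not PyTorch's code.

class Node:
    def __init__(self, name, parents=()):
        self.name = name
        self._parents = parents

def topo_order(root):
    """Return nodes so that every parent appears before its children."""
    order, seen = [], set()
    stack = [(root, False)]            # (node, parents_already_expanded)
    while stack:
        node, expanded = stack.pop()
        if expanded:
            order.append(node)         # post-order: all parents are done
            continue
        if node in seen:
            continue
        seen.add(node)
        stack.append((node, True))     # come back after the parents
        for p in reversed(node._parents):
            stack.append((p, False))
    return order

# A chain far deeper than Python's default recursion limit.
node = Node("leaf")
for i in range(10_000):
    node = Node(f"op{i}", (node,))

order = topo_order(node)
print(len(order), order[0].name, order[-1].name)   # 10001 leaf op9999
```

The recursive `visit` in the toy engine would hit Python's recursion limit on this chain; the explicit stack does not.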

Same toy gradient — now in PyTorch
autograd_toy.py
Line 1: Import PyTorch

The torch module is the single entry point to tensors, autograd, GPU kernels, and the optimizer suite. Every massive-model training script begins here.

EXAMPLE
torch.__version__  →  '2.x.y'
Line 3: Comment, set up the leaves

Pure documentation. Below, we recreate the toy from §1.4.8 with PyTorch instead of our hand-rolled Value engine. The math is identical; the difference is that PyTorch will dispatch each op to optimised C++/CUDA kernels.

Line 4: W, a trainable scalar

torch.tensor(2.0) creates a 0-d tensor (a scalar). requires_grad=True tells autograd: 'remember every op that touches me and keep a slot W.grad for my partial derivative.' This is exactly the role of a leaf in our Value class.

EXAMPLE
W.shape == torch.Size([]);  W.is_leaf == True;  W.grad is None  (until backward).
Line 5: x, the input scalar

Same construction. In real training x would be a data tensor (a batch of images, a sequence of token ids embedded into vectors). We still set requires_grad=True here so we can read x.grad below; for real data you'd usually leave it False.

EXAMPLE
x.dtype == torch.float32  by default — note the implicit precision choice.
Line 6: b, the bias scalar

Third leaf. The three leaves W, x, b are the only tensors that will end up holding .grad after backward(). Intermediate tensors get gradients too, but PyTorch frees them right after to save memory.

Line 8: Comment, start the forward pass

Documentation. The next line looks like ordinary Python arithmetic, but every operator is overloaded: each one records an entry in the autograd tape (the dynamic computation graph PyTorch builds on the fly).

Line 9: Build the graph with one expression

Three ops happen, in order: W*x calls __mul__ (creates an intermediate tensor with grad_fn=MulBackward0); + b calls __add__ (AddBackward0); ** 2 calls __pow__ (PowBackward0). y is the root. In our Value engine these are the same three primitives, just slower.

EXAMPLE
Internally PyTorch stores a function object on each output: y.grad_fn → PowBackward0; y.grad_fn.next_functions points back to AddBackward0; and so on.
Line 11: Read the forward value

.item() pulls a 0-d tensor out as a Python float. We use it for printing because a tensor's repr is noisier than a plain number.

EXAMPLE
y.item() == 49.0  ⇔  y.data == tensor(49.).
Line 12: Inspect grad_fn

grad_fn is the autograd hook on a non-leaf tensor — it points to the backward function for the op that produced it. Leaves have grad_fn=None; here y was produced by **2, so it shows PowBackward0. This is the exact mechanism the engine uses to walk backward.

EXAMPLE
y.grad_fn.next_functions[0][0] → <AddBackward0> (the previous op).
Line 14: Comment, kick off the backward pass

Documentation. The next call replaces our hand-written reverse pass with a single library function that has been hardened by years of production use.

Line 15: y.backward(), the one-line autodiff call

Autograd seeds y.grad = 1 (because y is scalar — for non-scalar tensors you must pass an explicit gradient), then walks grad_fn pointers from y back toward the leaves, calling each backward op. Each leaf's .grad slot is populated. The non-leaf intermediates are then freed to release memory.

EXAMPLE
After this line: W.grad = tensor(42.), x.grad = tensor(28.), b.grad = tensor(14.).
Line 17: W.grad, the same 42 we computed by hand

PyTorch wrote the answer into the .grad attribute of the leaf tensor. This is the number an optimizer such as SGD or AdamW reads when it performs W ← W − lr·W.grad.

EXAMPLE
lr=0.01  ⇒  next W ≈ 2.0 − 0.01·42 = 1.58.
Line 18: x.grad, 28

The data gradient. Inside a deep network, this is exactly the signal that flows from one layer back into the previous layer, allowing the chain rule to keep going across hundreds of transformer blocks.

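Line 19: b.grad, 14

Matches the hand calculation. b's gradient is the smallest of the three because the add node's local derivative is 1, so b receives dL/dv = 14 unchanged, while W and x each pick up one more factor from the multiply node (x = 3 and W = 2 respectively).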

Line 21: Comment, accumulation gotcha

Documentation. The next two lines flag the single most common autograd footgun: gradients add into .grad, they do not overwrite.

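Line 22: Zero W.grad in place

If you run another forward pass and call backward() again without zeroing, the new gradient is added to the old one, so the optimizer would step with the sum instead of the fresh gradient. Optimizers expose .zero_grad() to clear gradients in bulk; here we do it by hand. (The same accumulation behaviour is what lets us implement gradient accumulation across micro-batches in §1.4.11.)

EXAMPLE
Without zero_(), recomputing y and calling y.backward() again would leave W.grad = 84.0, exactly double. (Calling backward() twice on the same y raises an error unless the first call passed retain_graph=True, because the graph's saved buffers are freed after the first pass.)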
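Line 23: Zero b.grad too

Same idea. In real training code you write opt.zero_grad() once and it handles every parameter the optimizer owns. In recent PyTorch releases it sets each .grad to None by default (set_to_none=True) rather than zeroing in place; either way the next backward() starts from a clean slate.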

 1  import torch
 2
 3  # 1) Declare the leaves. requires_grad=True turns on autograd tracking.
 4  W = torch.tensor(2.0, requires_grad=True)
 5  x = torch.tensor(3.0, requires_grad=True)
 6  b = torch.tensor(1.0, requires_grad=True)
 7
 8  # 2) Forward pass: PyTorch builds the same DAG behind the scenes.
 9  y = (W * x + b) ** 2          # y is a non-leaf tensor with grad_fn=PowBackward0
10
11  print(y.item())               # 49.0, same forward value as our toy engine
12  print(y.grad_fn)              # <PowBackward0 object at 0x...>
13
14  # 3) Reverse pass: one call walks the graph and fills every leaf .grad.
15  y.backward()
16
17  print(W.grad.item())          # 42.0
18  print(x.grad.item())          # 28.0
19  print(b.grad.item())          # 14.0
20
21  # 4) Gradients accumulate by default, so clear them before the next step.
22  W.grad.zero_()
23  b.grad.zero_()

The output is the same three numbers you derived by hand: 42, 28, 14. That is the entire promise of automatic differentiation — the mathematics does not change as the model scales from three nodes to three trillion. Only the kernels, the memory plan, and the parallelisation strategy do.

How to think about grad_fn: in our toy we stored _backward as a Python closure on every node. PyTorch stores a C++ Node object called the grad_fn. It is the same idea: you can inspect it from Python (y.grad_fn, y.grad_fn.next_functions), and you can check the gradients it produces against finite differences with torch.autograd.gradcheck.

1.4.11 At Massive Scale: Why Autodiff Is the Real Hero

Now scale our toy to a real frontier-model training run and watch what survives. The mathematics — the chain rule, the forward-then-reverse pass, the gradient accumulation — is identical. Everything else has to be redesigned.

Memory: the silent killer

Reverse mode is fast because every intermediate is reused exactly once on the way back. But that means every intermediate must be stored until then. For a 70B-parameter transformer training on sequences of 4096 tokens with batch 1 in bfloat16, the activations alone weigh on the order of 200–400 GB per data-parallel rank — more than the parameters themselves, sometimes by an order of magnitude.

| Memory category | Scales with | Rough size at 70B / 4k context |
| --- | --- | --- |
| Parameters (θ) | N | 140 GB in bf16, 70 GB in fp8 |
| Gradients (∂L/∂θ) | N | Same shape as parameters: 140 GB / 70 GB |
| Optimizer state (Adam: m, v) | 2N | Often the biggest single slice: 280 GB per moment in fp32, 560 GB for m and v together |
| Activations (autodiff bookkeeping) | batch × seq × hidden × layers | Dominates for long contexts: 200–600 GB |
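The rows that scale with N are plain arithmetic, sketched below under the assumption of decimal gigabytes and exactly 70 billion parameters. The helper name `gigabytes` is introduced here for illustration. Activations are left out because their size depends on architecture details not given in this section.

```python
N = 70_000_000_000                      # parameter count (70B)
BYTES = {"fp32": 4, "bf16": 2, "fp8": 1}
GB = 10**9                              # decimal gigabytes

def gigabytes(num_values, dtype):
    """Storage for num_values numbers of the given dtype, in GB."""
    return num_values * BYTES[dtype] / GB

params_bf16 = gigabytes(N, "bf16")      # one value per parameter
grads_bf16 = gigabytes(N, "bf16")       # same shape as the parameters
adam_fp32 = gigabytes(2 * N, "fp32")    # two moments (m and v) per parameter

print(params_bf16, grads_bf16, adam_fp32)   # 140.0 140.0 560.0
```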

Gradient checkpointing — trade compute for memory

The classical autodiff bargain is "keep every intermediate". Gradient checkpointing breaks that bargain on purpose: drop most activations during forward, and when backward needs one, recompute it from the nearest checkpoint. You roughly double the forward FLOPs (about a third more total compute, if backward costs about twice a forward) but can cut activation memory by an order of magnitude. The chain rule is unchanged; you just rebuild the missing intermediates on demand.
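A minimal sketch of the idea using torch.utils.checkpoint. The tiny `block` below is a hypothetical stand-in for one transformer layer, and the check only confirms that the gradients match; it does not measure the memory saving.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

# Hypothetical small block standing in for one layer of a real model.
block = torch.nn.Sequential(
    torch.nn.Linear(8, 32),
    torch.nn.Tanh(),
    torch.nn.Linear(32, 8),
)
x = torch.randn(4, 8, requires_grad=True)

# Baseline: autograd stores every intermediate of the block.
block(x).sum().backward()
baseline = [p.grad.clone() for p in block.parameters()]

block.zero_grad()
x.grad = None

# Checkpointed: only the block input is kept; the intermediates are
# recomputed when backward reaches this block.
checkpoint(block, x, use_reentrant=False).sum().backward()
checkpointed = [p.grad.clone() for p in block.parameters()]

same = all(torch.allclose(a, b) for a, b in zip(baseline, checkpointed))
print(same)
```

In a real model one would typically wrap each layer this way, so that only layer boundaries are stored.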

Gradient accumulation — virtually larger batches

Look at the += in our toy: it is exactly what lets you train with batch size 4096 on a GPU that fits only batch 4. Run a micro-batch, call backward(), do not call zero_grad(), run the next micro-batch, repeat 1024 times, then step the optimizer. Provided each micro-batch loss is divided by the number of micro-batches (for a mean-reduced loss), the accumulated gradient is mathematically identical to the full-batch gradient, except for layers that depend on batch statistics, such as batch norm. The autodiff primitive that makes this safe is +=, not the optimizer.
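The equivalence can be checked on a toy model. This sketch assumes a mean-reduced loss, so each micro-batch loss is divided by the number of micro-batches; the model, data, and sizes are made up for illustration.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
data = torch.randn(8, 4)
target = torch.randn(8, 1)
loss_fn = torch.nn.MSELoss()            # mean reduction

# Reference: gradient of the full batch of 8.
model.zero_grad()
loss_fn(model(data), target).backward()
full = model.weight.grad.clone()

# Same gradient from 4 micro-batches of 2, with no zeroing in between.
model.zero_grad()
num_micro = 4
for xb, yb in zip(data.chunk(num_micro), target.chunk(num_micro)):
    # Scale so the accumulated sum equals the full-batch mean.
    (loss_fn(model(xb), yb) / num_micro).backward()
accumulated = model.weight.grad.clone()

match = torch.allclose(full, accumulated, atol=1e-6)
print(match)
```

Only after the last micro-batch would the optimizer step and the gradients be zeroed.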

Mixed precision — bf16/fp16 forward, fp32 gradients where it counts

Forward intermediates are stored in bf16 (or fp8 on the newest hardware) to save memory and matmul cost. The gradients flowing back through those intermediates are bf16 as well, but the optimizer keeps a master fp32 copy of the weights and the Adam moments to avoid catastrophic loss of precision at update time. The autograd graph routes a cast op between every fp32 leaf and its bf16 use, so the chain rule walks back through the casts like any other op, and each cast converts the incoming gradient back to the leaf's dtype.
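The cast-in-the-graph behaviour can be seen in a few lines. This is a hand-rolled sketch with an explicit `.to(torch.bfloat16)`, not how torch.autocast is used in practice; the values are chosen to be exactly representable in bf16.

```python
import torch

# fp32 "master" weight: the leaf an optimizer would update.
w = torch.tensor([0.5, -1.25, 2.0], requires_grad=True)    # float32
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.bfloat16)

w_bf16 = w.to(torch.bfloat16)     # the cast is recorded as an op in the graph
y = (w_bf16 * x).sum()            # forward runs in bf16

y.backward()                      # backward walks through the cast

print(y.dtype)                    # torch.bfloat16
print(w.grad.dtype)               # torch.float32
print(w.grad)                     # tensor([1., 2., 3.])
```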

FSDP / ZeRO — shard the parameters across ranks

Even after checkpointing and mixed precision, the parameters and optimizer state alone outgrow any single GPU around 7B. ZeRO-3 and PyTorch FSDP shard $\theta$, the gradients, and the optimizer state across all data-parallel ranks. Before each layer's forward, every rank all-gathers that layer's full weights from the ranks that own the shards; afterwards it discards the parts it does not own. Backward does the same in reverse, but additionally reduce-scatters the gradients so only the owning rank holds $\partial L / \partial \theta_i$ for its shard. From the autograd engine's point of view nothing has changed: the gather/scatter are just more ops in the graph.
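The data flow can be imitated in a single process. The sketch below is a simulation only: it does not use torch.distributed or FSDP, the "ranks" are list entries, and the gradients are summed where a real setup would usually average them. All names and sizes are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)
world_size = 2
full_w = torch.randn(4, 3)

# Each simulated rank owns one shard of the weight (split along dim 0)
# and sees its own slice of the data.
shards = [s.clone() for s in full_w.chunk(world_size, dim=0)]
inputs = [torch.randn(5, 3) for _ in range(world_size)]

def local_loss(w, x):
    return (x @ w.t()).pow(2).sum()

per_rank_grads = []
for rank in range(world_size):
    # "All-gather": rebuild the full weight from every shard.
    w = torch.cat(shards, dim=0).requires_grad_(True)
    local_loss(w, inputs[rank]).backward()
    per_rank_grads.append(w.grad)     # full-size gradient on this rank
    # The gathered copy `w` would be freed here.

# "Reduce-scatter": sum across ranks, then each rank keeps only its shard.
summed = torch.stack(per_rank_grads).sum(dim=0)
grad_shards = list(summed.chunk(world_size, dim=0))

# Reference: unsharded gradient of the summed loss.
ref_w = full_w.clone().requires_grad_(True)
sum(local_loss(ref_w, x) for x in inputs).backward()
ref_shards = ref_w.grad.chunk(world_size, dim=0)

ok = all(torch.allclose(g, r, atol=1e-5) for g, r in zip(grad_shards, ref_shards))
print(ok)
```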

The deep lesson. Every memory and parallelism trick in massive-model training is a wrapper around the chain rule. The graph stays the same. We just teach the engine where to spill, recompute, shard, or cast as it walks the graph.

1.4.12 Engineering Reality and Pitfalls

The autograd engine is one of the most battle-tested pieces of code in PyTorch, but it can still be made to lie if you push it the wrong way. The list below covers the most frequent failure modes in real training runs.

  • Forgetting zero_grad(). Because gradients accumulate, missing a single zero turns every subsequent step into a runaway. Symptom: loss explodes after step 1 with no obvious cause.
  • In-place ops on tensors that backward needs. Writing x.add_(1) mutates the activation the backward pass expected to see. PyTorch raises a runtime error in many cases but not all — when it does not, you silently compute the wrong gradient.
  • Using .detach() without meaning to. Anything you copy via .detach(), convert to numpy, or print as a Python float is cut off from the graph. Gradients that should have flowed through it are silently dropped. Useful when you want it (e.g. teacher targets); catastrophic when you do not.
  • Non-differentiable ops. argmax and discrete sampling return integer tensors with no gradient, and the condition of a torch.where with a hard threshold receives none either (gradients flow only through the selected values). Either reroute through a differentiable surrogate (Gumbel-softmax, straight-through estimator) or accept that no learning signal flows there.
  • NaN gradients. The chain rule multiplies. One rogue infinity in a single intermediate corrupts every leaf upstream. Common causes: log(0), 1/0, sqrt of a negative value, fp16 overflow. fp16 mixed precision adds a GradScaler, which scales the loss so small gradients do not underflow and skips any optimizer step whose gradients contain inf or NaN.
  • Forgetting torch.no_grad() at inference. Without it, autograd still records the graph during eval — wasting memory and slowing inference. Production inference always wraps forward in with torch.no_grad(): or its successor torch.inference_mode().
  • Higher-order gradients are expensive. Gradients of gradients (for meta-learning, second-order optimisers, or influence functions) require create_graph=True, which keeps the backward graph itself differentiable. That stacks another autodiff pass on top — memory and compute double.
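Several of the pitfalls above can be reproduced in a few lines. The sketch below is illustrative; the tensors and values are made up, and each case is kept to the smallest graph that shows the effect.

```python
import torch

# 1) Unintended detach: the true gradient of 3*w*w at w=2 is 12.
w = torch.tensor(2.0, requires_grad=True)
a = (3 * w).detach()              # cut off from the graph
(a * w).backward()
detached_grad = w.grad.item()     # 6.0: the path through `a` is lost

# 2) no_grad: nothing is recorded, so there is no graph to walk.
with torch.no_grad():
    out = 3 * w
recorded = out.requires_grad      # False

# 3) argmax returns integer indices that carry no gradient.
scores = torch.tensor([0.1, 0.7, 0.2], requires_grad=True)
idx = torch.argmax(scores)
idx_tracks_grad = idx.requires_grad   # False

# 4) NaN from the chain rule: the local derivative of sqrt at 0 is
#    infinite, and the incoming gradient of 0 gives 0/0.
z = torch.tensor(0.0, requires_grad=True)
(torch.sqrt(z) * 0.0).backward()
nan_grad = z.grad.item()          # nan

print(detached_grad, recorded, idx_tracks_grad, nan_grad)
```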
Always gradient-check new ops. If you write a custom autograd Function with your own forward and backward, run torch.autograd.gradcheck before trusting it. It compares your analytic backward against finite differences at high precision and is the single best defense against subtle bugs.
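A minimal sketch of that check. `Square` and `BuggySquare` are hypothetical custom ops written for this example; the second has a deliberately wrong backward. gradcheck expects double-precision inputs that require grad.

```python
import torch

class Square(torch.autograd.Function):
    """Hypothetical custom op: y = x**2 with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out       # correct: d(x^2)/dx = 2x

class BuggySquare(torch.autograd.Function):
    """Same forward, but the backward forgets the factor of 2."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return x * grad_out           # wrong on purpose

x = torch.tensor([0.5, -1.0, 2.0], dtype=torch.double, requires_grad=True)

ok = torch.autograd.gradcheck(Square.apply, (x,))
bad = torch.autograd.gradcheck(BuggySquare.apply, (x,), raise_exception=False)
print(ok, bad)
```

With the default raise_exception=True, the buggy version would raise an error describing the mismatch instead of returning False.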

Summary

Modern training has one inner loop — forward, loss, backward, step — and reverse-mode automatic differentiation is the only thing that makes the backward step cheap enough to repeat ten million times. We derived it from first principles:

  1. A scalar loss over $10^{11}$ parameters is hopeless to differentiate by formula or by perturbation; reverse mode does the full gradient in one extra pass.
  2. A network is a DAG of primitive ops; the forward pass evaluates it, the backward pass walks the same DAG in reverse, applying the chain rule locally at each node.
  3. The whole engine fits in sixty lines of Python: a Value class, three operator overloads with += accumulation, a topological sort, and a seed of 1 at the loss. PyTorch is the same algorithm, just with fast kernels and a million times more bookkeeping.
  4. Everything else in a frontier-model training stack — gradient checkpointing, mixed precision, gradient accumulation, ZeRO/FSDP — is a memory or parallelism trick wrapped around this single primitive.

From here on in the book you can treat loss.backward() as one line of code that you fully understand. The next section (§1.5) takes the gradients we just learned to compute and asks the next question: given the gradient, how do we actually move? That is the realm of gradient descent, momentum, Adam, AdamW, and the optimisation landscape.
