Chapter 8

Computational Graphs

Backpropagation from Scratch

Learning Objectives

By the end of this section, you will be able to:

  1. Understand computational graphs as the fundamental data structure for representing neural network computations
  2. Visualize forward and backward passes as information flowing through the graph in opposite directions
  3. Apply the chain rule on graphs to systematically compute gradients for any computation
  4. Explain how autograd systems work by recording operations and replaying them in reverse
  5. Analyze gradient flow patterns including addition, multiplication, and branching operations
  6. Understand the memory vs computation trade-offs in modern deep learning frameworks
Why This Matters: Computational graphs are the data structure behind automatic differentiation. They turn the problem of computing gradients with respect to millions of parameters into a systematic, mechanical algorithm. Every major deep learning framework (PyTorch, TensorFlow, JAX) is built on this foundation.

The Big Picture

Imagine you need to find the derivative of a complex function involving dozens of operations and millions of parameters. Doing this by hand would be impossibly tedious and error-prone. Computational graphs provide an elegant solution: represent the computation as a directed graph, then compute gradients automatically using the chain rule.

The Origin Story

The idea of using graphs to represent computations dates back to the 1960s, but the breakthrough for neural networks came in the 1980s with the development of backpropagation. Researchers realized that by structuring computations as directed acyclic graphs (DAGs), they could:

  1. Compute outputs efficiently in a single forward pass
  2. Compute all gradients efficiently in a single backward pass
  3. Reuse intermediate values, avoiding redundant computation

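This insight is what makes training neural networks with billions of parameters feasible. One backward pass yields the gradient with respect to every parameter at a cost comparable to one forward pass. By contrast, estimating each partial derivative separately (for example, by finite differences) would need a separate forward pass per parameter.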

Two Key Insights

| Insight | Explanation | Consequence |
|---|---|---|
| Modularity | Complex functions can be decomposed into simple operations | Each operation only needs a local gradient |
| Chain rule | Gradients flow through the graph via multiplication | Global gradients are computed from local ones |

What Is a Computational Graph?

A computational graph is a directed graph where:

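  • Nodes represent values (inputs, intermediate results, outputs). Each non-input node is computed by one operation, so nodes are often labeled by that operation
  • Edges represent data flow: an edge from u to v means u is an input to the operation that computes v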

A Simple Example

Consider the function $L = (wx + b - y)^2$, which is the squared error loss of linear regression for a single training example $(x, y)$. We can break it into elementary operations:

$$\begin{aligned} z_1 &= w \cdot x & \text{(multiply)} \\ z_2 &= z_1 + b & \text{(add)} \\ z_3 &= z_2 - y & \text{(subtract)} \\ L &= z_3^2 & \text{(square)} \end{aligned}$$

Each operation becomes a node in our graph, with edges showing the flow of data from inputs to outputs. This decomposition is key: each elementary operation has a simple, known gradient.
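The decomposition above can be sketched in plain Python as an explicit list of nodes, each naming its operation and its inputs, evaluated in order. This is a minimal illustration, not how any particular framework stores its graphs. The `GRAPH` list, the `OPS` table, the `forward` helper, and the sample values x = 2, w = 3, b = 1, y = 5 are all assumptions made for this example.

```python
# Each node: (output name, operation, input names), listed in topological order
GRAPH = [
    ("z1", "mul", ("w", "x")),
    ("z2", "add", ("z1", "b")),
    ("z3", "sub", ("z2", "y")),
    ("L", "square", ("z3",)),
]

# The elementary operations this toy graph may use
OPS = {
    "mul": lambda a, b: a * b,
    "add": lambda a, b: a + b,
    "sub": lambda a, b: a - b,
    "square": lambda a: a * a,
}


def forward(graph, inputs):
    """Evaluate every node in order; return all values (inputs and intermediates)."""
    values = dict(inputs)
    for name, op, args in graph:
        values[name] = OPS[op](*(values[a] for a in args))
    return values


values = forward(GRAPH, {"x": 2.0, "w": 3.0, "b": 1.0, "y": 5.0})
print(values["z1"], values["z2"], values["z3"], values["L"])  # 6.0 7.0 2.0 4.0
```

`forward` keeps every intermediate value in `values`. A backward pass can therefore reuse these values instead of recomputing them, which is the caching advantage listed under "Why Graphs?".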

Why Graphs?

Using a graph structure provides several advantages:

  • Automatic differentiation: Gradients computed systematically, not derived by hand
  • Efficiency: Intermediate values cached during forward pass, reused in backward pass
  • Generality: Works for any computation that can be expressed as a DAG
  • Parallelization: Independent operations can be computed in parallel

Anatomy of a Computational Graph

Let's examine the components of a computational graph in detail:

Node Types

| Node type | Description | Example |
|---|---|---|
| Input | External values provided to the computation | x, w, b, y (training data and parameters) |
| Operation | Computes a value from its inputs | +, ×, σ, ReLU, matmul |
| Output | Final result of the computation | Loss L, prediction ŷ |

Edge Information

Each edge in the graph carries two pieces of information:

  1. Forward direction: The value flowing from one node to the next
  2. Backward direction: The local gradient $\partial \text{output} / \partial \text{input}$

The Gradient Function

Every operation in the graph has an associated gradient function (or grad_fn in PyTorch terminology). This function computes how the operation's output changes with respect to each input.

$$\text{For } z = f(x, y): \quad \frac{\partial z}{\partial x}, \frac{\partial z}{\partial y}$$

Local vs Global Gradients

Local gradients are the derivatives of a single operation with respect to its inputs. Global gradients are the derivatives of the final output with respect to any node in the graph. Backpropagation computes global gradients by chaining local gradients.
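The following hand-written backward pass for the running example $L = (wx + b - y)^2$ shows local gradients being chained into global ones. It is a sketch with illustrative values (x = 2, w = 3, b = 1, y = 5). Each backward line multiplies the upstream gradient by exactly one local gradient.

```python
# Forward pass for L = (w*x + b - y)^2, keeping every intermediate value
x, w, b, y = 2.0, 3.0, 1.0, 5.0
z1 = w * x      # 6.0
z2 = z1 + b     # 7.0
z3 = z2 - y     # 2.0
L = z3 ** 2     # 4.0

# Backward pass: each global gradient = upstream gradient * local gradient
dL_dL = 1.0
dL_dz3 = dL_dL * (2 * z3)   # d(z3^2)/dz3 = 2*z3
dL_dz2 = dL_dz3 * 1.0       # d(z2 - y)/dz2 = 1
dL_dy = dL_dz3 * -1.0       # d(z2 - y)/dy = -1
dL_dz1 = dL_dz2 * 1.0       # d(z1 + b)/dz1 = 1
dL_db = dL_dz2 * 1.0        # d(z1 + b)/db = 1
dL_dw = dL_dz1 * x          # d(w*x)/dw = x (x is reused from the forward pass)
dL_dx = dL_dz1 * w          # d(w*x)/dx = w

print(dL_dw, dL_db, dL_dx, dL_dy)  # 8.0 4.0 12.0 -4.0
```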

Interactive: Forward and Backward Pass

Explore the computational graph for $L = (wx + b - t)^2$, where $t$ is the target and $y = wx + b$ is the prediction. In the forward pass, values flow from inputs to outputs. In the backward pass, gradients flow from the output back to the inputs.

Interactive Computational Graph: visualize the forward pass (computing values) and the backward pass (computing gradients).

[Figure: graph for L = (wx + b - t)² with inputs x = 2.00, w = 3.00, b = 1.00, t = 5.00; intermediate nodes z = wx = 6.00 (×), y = z + b = 7.00 (+), y - t = 2.00 (-), and output L = 4.00 (²).]

Computation summary:

  • Forward: L = (wx + b - t)² = (7 - 5)² = 4.00
  • ∂L/∂w = 2(y - t) · x = 8.00
  • ∂L/∂b = 2(y - t) = 4.00
  • ∂L/∂x = 2(y - t) · w = 12.00

Quick Check

In the backward pass, what is the gradient ∂L/∂z at the node z = wx when ∂L/∂y = 2(y-t) and y = z + b?


The Chain Rule on Graphs

The chain rule is the mathematical foundation of backpropagation. On computational graphs, it takes a particularly elegant form.

Single Path

For a simple chain $x \to f \to g \to L$, the chain rule gives:

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial g} \cdot \frac{\partial g}{\partial f} \cdot \frac{\partial f}{\partial x}$$

This is just multiplying local gradients along the path from $L$ to $x$.
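As a sanity check of the single-path formula, the sketch below picks concrete functions for illustration: $f(x) = x^2$, $g(f) = \sin f$, and $L(g) = g^2$. It multiplies the three local gradients and compares the product with a central finite-difference estimate.

```python
import math

x = 1.5
# Forward pass along the chain x -> f -> g -> L
f = x ** 2          # f(x) = x^2
g = math.sin(f)     # g(f) = sin(f)
L = g ** 2          # L(g) = g^2

# Chain rule: multiply the three local gradients along the single path
dL_dg = 2 * g
dg_df = math.cos(f)
df_dx = 2 * x
dL_dx = dL_dg * dg_df * df_dx


# Independent check with a central finite difference
def loss(x):
    return math.sin(x ** 2) ** 2


h = 1e-6
numeric = (loss(x + h) - loss(x - h)) / (2 * h)
print(f"chain rule: {dL_dx:.6f}  finite difference: {numeric:.6f}")
```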

Multiple Paths

When a node has multiple outgoing edges (its value is used by multiple downstream operations), we sum the gradients from all paths:

$$\frac{\partial L}{\partial x} = \sum_{\text{all paths}} \prod_{\text{edges on path}} (\text{local gradient})$$

This is the multivariate chain rule. If $x$ influences $L$ through multiple intermediate nodes $y_1, y_2, \ldots$:

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y_1} \cdot \frac{\partial y_1}{\partial x} + \frac{\partial L}{\partial y_2} \cdot \frac{\partial y_2}{\partial x} + \cdots$$
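Here is a small numeric example of the multivariate chain rule. It assumes $y_1 = 3x$, $y_2 = x^2$, and $L = y_1 y_2$, so that $L = 3x^3$. Summing the two path contributions reproduces the directly computed derivative $9x^2$.

```python
x = 2.0
# x fans out into two intermediate nodes
y1 = 3 * x        # path 1
y2 = x ** 2       # path 2
L = y1 * y2       # L = 3x^3

# Multivariate chain rule: add up the contribution of each path
path1 = y2 * 3            # (dL/dy1) * (dy1/dx)
path2 = y1 * (2 * x)      # (dL/dy2) * (dy2/dx)
dL_dx = path1 + path2

print(path1, path2, dL_dx)  # 12.0 24.0 36.0
print(9 * x ** 2)           # 36.0, the derivative of 3x^3 computed directly
```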

The Backward Pass Algorithm

The backward pass computes gradients efficiently by:

  1. Starting at the output with $\partial L / \partial L = 1$
  2. Processing nodes in reverse topological order
  3. At each node, multiplying incoming gradient by local gradient
  4. Accumulating gradients when paths merge (multiple outgoing edges)
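The four steps above can be sketched as a tiny scalar reverse-mode engine. The `Node` class and the `add`, `mul`, and `backward` helpers are hypothetical names for illustration, not a framework API. `backward` builds a topological order with depth-first search, seeds $\partial L/\partial L = 1$, and accumulates with `+=` so that fan-out is handled correctly.

```python
class Node:
    """A value in the graph plus the local gradient d(self)/d(parent) for each parent."""

    def __init__(self, value, parents=(), local_grads=()):
        self.value = value
        self.parents = parents          # nodes this value was computed from
        self.local_grads = local_grads  # one local gradient per parent
        self.grad = 0.0                 # global gradient dL/d(self), filled in by backward


def add(a, b):
    return Node(a.value + b.value, (a, b), (1.0, 1.0))


def mul(a, b):
    return Node(a.value * b.value, (a, b), (b.value, a.value))


def backward(output):
    # Step 2 needs a topological order: depth-first search puts parents before children
    order, seen = [], set()

    def visit(node):
        if id(node) not in seen:
            seen.add(id(node))
            for parent in node.parents:
                visit(parent)
            order.append(node)

    visit(output)
    output.grad = 1.0                       # step 1: dL/dL = 1
    for node in reversed(order):            # step 2: reverse topological order
        for parent, local in zip(node.parents, node.local_grads):
            parent.grad += node.grad * local  # steps 3 and 4: multiply, then accumulate


x = Node(3.0)
L = add(mul(x, x), x)   # L = x^2 + x, so x is used three times
backward(L)
print(L.value, x.grad)  # 12.0 7.0
```

In this example x receives three contributions (3 + 3 from the product and 1 from the sum). They accumulate to x.grad = 7, which matches d(x² + x)/dx = 2x + 1 at x = 3.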

Efficiency

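The backward pass visits each node exactly once, so its cost is $O(n)$, where $n$ is the number of operations. This is the same order as the forward pass, up to a small constant factor. A single backward pass also produces the gradient with respect to every input at once.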

Interactive: Chain Rule Exploration

Explore how the chain rule propagates gradients through different computational graph structures: single-path chains, multi-path graphs, and neural network layers.

Single-path example: f(g(x)) with g(x) = x² and f(g) = sin(g). The graph is the chain x → g(x) → f(g), and the chain rule applies directly:

∂f/∂x = (∂f/∂g) × (∂g/∂x) = cos(g) × 2x = 2x·cos(x²)

Key Insight: Local Gradients

Each edge carries a local gradient (∂output/∂input for that operation). Backpropagation multiplies these local gradients along each path from output to input.

Quick Check

If x feeds into both a = x² and b = x³, and the output is y = a + b, what is ∂y/∂x?


Key Gradient Flow Patterns

Understanding how gradients flow through common operations is essential for debugging neural networks and designing architectures.

Addition Gate

For $z = x + y$:

$$\frac{\partial z}{\partial x} = 1, \quad \frac{\partial z}{\partial y} = 1$$

Pattern: Addition acts as a gradient distributor. The incoming gradient is passed unchanged to both inputs.

Multiplication Gate

For $z = x \cdot y$:

$$\frac{\partial z}{\partial x} = y, \quad \frac{\partial z}{\partial y} = x$$

Pattern: Multiplication acts as a gradient swapper. Each input receives the gradient multiplied by the other input's value.

Max Gate

For $z = \max(x, y)$:

$$\frac{\partial z}{\partial x} = \begin{cases} 1 & \text{if } x > y \\ 0 & \text{otherwise} \end{cases}$$

Pattern: Max acts as a gradient router. The full gradient flows only to the larger input; the smaller input receives zero gradient.

Copy/Fan-out

When a value $x$ is used multiple times (as copies $z_1 = z_2 = \cdots = x$ feeding different operations):

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial z_1} + \frac{\partial L}{\partial z_2} + \cdots$$

Pattern: Copying acts as a gradient accumulator. Gradients from all uses are summed.

| Operation | Forward | Backward pattern |
|---|---|---|
| Add | z = x + y | ∂z/∂x = ∂z/∂y = 1 (distribute) |
| Multiply | z = x × y | ∂z/∂x = y, ∂z/∂y = x (swap) |
| Max | z = max(x, y) | Gradient to the winner only (route) |
| Copy | z₁ = z₂ = x | Sum all incoming gradients (accumulate) |
| ReLU | z = max(0, x) | Pass if x > 0, else 0 |
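The patterns in the table can be checked with PyTorch's `torch.autograd.grad`, which backpropagates a chosen upstream gradient through an expression. In this sketch the upstream gradient 3.0 and the inputs x = 2, y = 5 are arbitrary illustrative values, and `input_grads` is a helper defined only for this example.

```python
import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(5.0, requires_grad=True)
upstream = torch.tensor(3.0)  # stand-in for dL/dz arriving from later in the graph


def input_grads(z, inputs):
    """Backpropagate `upstream` through z and return dL/d(input) for each input."""
    grads = torch.autograd.grad(z, inputs, grad_outputs=upstream)
    return [g.item() for g in grads]


print("add: ", input_grads(x + y, (x, y)))                # distribute: [3.0, 3.0]
print("mul: ", input_grads(x * y, (x, y)))                # swap: [15.0, 6.0]
print("max: ", input_grads(torch.maximum(x, y), (x, y)))  # route: [0.0, 3.0]
print("copy:", input_grads(4 * x + 7 * x, (x,)))          # accumulate: 3*4 + 3*7 = [33.0]
print("relu:", input_grads(torch.relu(x - y), (x,)))      # x - y < 0, so no gradient: [0.0]
```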

Dead ReLU Problem

Since ReLU passes gradient only when its input is positive, a neuron whose pre-activation is negative for every input ("dead neuron") always outputs zero and receives zero gradient. Its incoming weights therefore stop updating entirely. This is one reason for variants like Leaky ReLU.
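The following minimal sketch shows a dead neuron. It assumes a hypothetical single neuron whose weights and bias are all negative while its inputs are nonnegative, so the pre-activation is negative for every example. With ReLU every gradient is exactly zero. Swapping in `F.leaky_relu` with negative slope 0.01 restores a small gradient.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
inputs = torch.rand(8, 3)  # 8 nonnegative input vectors with 3 features each

# Hypothetical dead neuron: negative weights and bias keep the pre-activation < 0
w = torch.tensor([-1.0, -2.0, -0.5], requires_grad=True)
b = torch.tensor(-1.0, requires_grad=True)

pre = inputs @ w + b
F.relu(pre).sum().backward()
relu_w_grad, relu_b_grad = w.grad.clone(), b.grad.clone()
print("ReLU grads:", relu_w_grad, relu_b_grad)        # all zeros: no learning signal

w.grad, b.grad = None, None                            # reset before the second backward
pre = inputs @ w + b                                   # rebuild the graph (the first was freed)
F.leaky_relu(pre, negative_slope=0.01).sum().backward()
leaky_w_grad, leaky_b_grad = w.grad.clone(), b.grad.clone()
print("Leaky ReLU grads:", leaky_w_grad, leaky_b_grad)  # small but nonzero
```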

See Chapter 9 Section 7

For practical code to detect dead neurons and debug training issues, see Chapter 9 Section 7: Debugging Neural Networks.

Autograd: Automatic Differentiation

Modern deep learning frameworks implement automatic differentiation (autograd) to compute gradients without manual derivation. Understanding how this works demystifies framework internals.

The Two Phases

  1. Forward Pass (Tape Recording): As operations execute, the framework records each operation, its inputs, and its output onto a "tape" (also called a "trace" or "computational graph").
  2. Backward Pass (Tape Playback): Starting from the output, the framework walks the tape in reverse, applying the chain rule at each step to compute gradients.
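The two phases can be sketched with a toy tape in plain Python. The `Tape` class and its `leaf`, `op`, and `backward` methods are hypothetical. PyTorch itself links `grad_fn` objects into a graph rather than keeping a flat list like this. The example records y = (w·x + b)² with x = 2, w = 3, b = 1 as three entries: MUL, ADD, SQUARE.

```python
class Tape:
    """Toy tape: records (output, inputs, backward_fn) in execution order."""

    def __init__(self):
        self.values = {}    # name -> scalar value
        self.entries = []   # the tape itself

    def leaf(self, name, value):
        self.values[name] = value

    def op(self, out, inputs, forward_fn, backward_fn):
        # Forward pass: compute the value now and record how to differentiate it later
        args = [self.values[n] for n in inputs]
        self.values[out] = forward_fn(*args)
        self.entries.append((out, inputs, backward_fn))

    def backward(self, output):
        # Playback: walk the tape in reverse, applying the chain rule at each entry
        grads = {output: 1.0}
        for out, inputs, backward_fn in reversed(self.entries):
            g_out = grads.get(out, 0.0)
            args = [self.values[n] for n in inputs]
            for name, g in zip(inputs, backward_fn(g_out, *args)):
                grads[name] = grads.get(name, 0.0) + g  # accumulate if a value is reused
        return grads


tape = Tape()
tape.leaf("x", 2.0)
tape.leaf("w", 3.0)
tape.leaf("b", 1.0)
tape.op("t1", ("w", "x"), lambda w, x: w * x, lambda g, w, x: (g * x, g * w))  # MUL
tape.op("t2", ("t1", "b"), lambda t1, b: t1 + b, lambda g, t1, b: (g, g))      # ADD
tape.op("y", ("t2",), lambda t2: t2 ** 2, lambda g, t2: (2 * t2 * g,))         # SQUARE

grads = tape.backward("y")
print(tape.values["y"], grads["w"], grads["b"], grads["x"])  # 49.0 28.0 14.0 42.0
```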

What Gets Stored?

For each operation, the tape stores:

  • The operation type (add, multiply, matmul, etc.)
  • References to input tensors
  • The output tensor
  • A gradient function (grad_fn) that knows how to compute local gradients
  • Any intermediate values needed for the backward pass

Define-by-Run vs Define-and-Run

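| Paradigm | Frameworks | How it works |
|---|---|---|
| Define-by-run | PyTorch, TensorFlow 2.x (eager mode), JAX | Graph built dynamically during execution (JAX traces the Python function when it is transformed, e.g. by jax.grad) |
| Define-and-run | TensorFlow 1.x | Graph defined first, then executed |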

PyTorch's define-by-run approach means the graph is constructed fresh for each forward pass, enabling dynamic computation (different graphs for different inputs) and easier debugging.
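The following small sketch shows what define-by-run allows. Ordinary Python control flow decides which operations are recorded, so different inputs produce different graphs. The function `f` is a made-up example.

```python
import torch


def f(x):
    # Ordinary Python control flow: the branch taken decides which ops get recorded
    if x > 0:
        return x ** 3
    return -2 * x


results = {}
for value in (2.0, -2.0):
    x = torch.tensor(value, requires_grad=True)
    y = f(x)                       # a fresh graph is recorded on every call
    y.backward()
    results[value] = (type(y.grad_fn).__name__, x.grad.item())
    print(value, results[value])   # grad is 3x^2 = 12.0 for x = 2, and -2.0 for x = -2
```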


Interactive: How Autograd Works

This example shows how PyTorch's automatic differentiation records operations during the forward pass and traverses them during the backward pass. First, create input tensors with requires_grad=True so that PyTorch records every operation applied to them:

```python
import torch

# Initialize inputs
x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)
```

Computing y = (w·x + b)² then records three entries on the computational tape:

  • MUL: t1 = 6, inputs [x, w]
  • ADD: t2 = 7, inputs [t1, b]
  • SQUARE: y = 49, inputs [t2]

The process has three phases:

  1. Forward pass: compute output values while recording operations to the tape.
  2. Build graph: attach a gradient function (grad_fn) to each recorded node. The grad_fn knows how to compute that operation's local gradients during the backward pass.
  3. Backward pass: walk the tape in reverse, multiplying local gradients via the chain rule.

Quick Check

In PyTorch, when does the computational graph get built?


Implementation in PyTorch

Let's see how computational graphs work in PyTorch, from basic tensor operations to understanding the grad_fn chain.

Understanding Computational Graphs in PyTorch

computational_graph.py
```python
import torch

# Create input tensors with gradient tracking
x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)

# Forward pass - operations are recorded
z = w * x        # z = 6.0
y = z + b        # y = 7.0
loss = y ** 2    # loss = 49.0

# Examine the computational graph
print(f"loss.grad_fn: {loss.grad_fn}")
print(f"  -> next: {loss.grad_fn.next_functions}")

# Backward pass - traverse graph in reverse
loss.backward()

# Gradients are now computed
print(f"\n∂L/∂w = {w.grad}")  # 2*y*x = 2*7*2 = 28
print(f"∂L/∂b = {b.grad}")    # 2*y*1 = 2*7 = 14
print(f"∂L/∂x = {x.grad}")    # 2*y*w = 2*7*3 = 42
```

Notes on the code:

  • requires_grad=True: This flag tells PyTorch to track all operations on this tensor and build a computational graph. Only tensors with requires_grad=True (and results computed from them) take part in gradient computation.
  • Forward pass: Each operation creates a new tensor with a grad_fn that remembers how to compute its gradient. The graph is built implicitly as operations execute. For example, z.grad_fn is a MulBackward0 object and y.grad_fn is an AddBackward0 object.
  • Examining grad_fn: Every tensor created by an operation on tracked tensors has a grad_fn attribute, an object that computes the local gradients of that operation. Its next_functions attribute points to the grad_fn objects of the parent operations.
  • Backward pass: backward() traverses the graph from the loss back to all leaf tensors, applying the chain rule. Gradients are accumulated in the .grad attribute of each leaf tensor.
  • Reading gradients: After backward(), each leaf tensor with requires_grad=True has its .grad attribute populated with ∂loss/∂tensor. These are the gradients used for optimization, e.g. ∂L/∂w = 2y × x = 2(7)(2) = 28. Non-leaf tensors such as z and y do not keep .grad unless you call retain_grad() on them.

Gradient Accumulation

An important detail: gradients in PyTorch are accumulated, not replaced. If you call backward() multiple times without zeroing gradients, they add up:

Gradient Accumulation Behavior

gradient_accumulation.py
```python
import torch

x = torch.tensor(3.0, requires_grad=True)

# First forward-backward
y1 = x ** 2
y1.backward()
print(f"After first backward: x.grad = {x.grad}")  # 6.0

# Second forward-backward (without zeroing)
y2 = x ** 3
y2.backward()
print(f"After second backward: x.grad = {x.grad}")  # 6.0 + 27.0 = 33.0

# The right way: zero gradients before each backward
x.grad.zero_()
y3 = x ** 2
y3.backward()
print(f"After zeroing and backward: x.grad = {x.grad}")  # 6.0
```

  • First gradient: ∂(x²)/∂x = 2x = 2(3) = 6.0. This is stored in x.grad.
  • Accumulated gradient: Without zeroing, the new gradient (27.0 from x³) is added to the old one (6.0), giving 33.0. This is often a bug.
  • Zeroing gradients: Use .zero_() or optimizer.zero_grad() before each backward pass to clear accumulated gradients. This is essential in training loops.

Training Loop Pattern

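Unless you deliberately want to accumulate gradients (for example, across several mini-batches), zero them before each backward pass: optimizer.zero_grad(), then loss.backward(), then optimizer.step(). This pattern prevents accidental gradient accumulation.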
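Here is a minimal training loop that follows this pattern. It assumes a toy linear-regression problem: synthetic data from y = 2x + 1, an `nn.Linear` model, and plain SGD with learning rate 0.1, all chosen for illustration.

```python
import torch

torch.manual_seed(0)
x = torch.linspace(-1, 1, 50).unsqueeze(1)  # 50 inputs, shape (50, 1)
y = 2 * x + 1                                # synthetic targets from y = 2x + 1

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

for step in range(500):
    optimizer.zero_grad()        # 1. clear gradients left over from the previous step
    loss = loss_fn(model(x), y)  # forward pass records a fresh graph
    loss.backward()              # 2. fill .grad on model.weight and model.bias
    optimizer.step()             # 3. update the parameters using .grad

print(model.weight.item(), model.bias.item())  # should be very close to 2.0 and 1.0
```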

Common Operations and Their Gradients

Here are the gradient rules for common neural network operations. Understanding these helps debug gradient flow issues.

| Operation | Forward | Local gradient |
|---|---|---|
| Add | z = x + y | ∂z/∂x = 1, ∂z/∂y = 1 |
| Subtract | z = x - y | ∂z/∂x = 1, ∂z/∂y = -1 |
| Multiply | z = xy | ∂z/∂x = y, ∂z/∂y = x |
| Divide | z = x/y | ∂z/∂x = 1/y, ∂z/∂y = -x/y² |
| Power | z = xⁿ | ∂z/∂x = nxⁿ⁻¹ |
| Exp | z = eˣ | ∂z/∂x = eˣ = z |
| Log | z = ln(x) | ∂z/∂x = 1/x |
| Sigmoid | z = σ(x) | ∂z/∂x = σ(x)(1 - σ(x)) = z(1 - z) |
| Tanh | z = tanh(x) | ∂z/∂x = 1 - tanh²(x) = 1 - z² |
| ReLU | z = max(0, x) | ∂z/∂x = 1 if x > 0 else 0 |
| Softmax | zᵢ = exp(xᵢ) / Σⱼ exp(xⱼ) | Full Jacobian matrix: ∂zᵢ/∂xⱼ = zᵢ(δᵢⱼ - zⱼ) |
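Several rows of the table can be checked against autograd. This sketch uses arbitrary input values. It compares PyTorch's gradients for sigmoid, tanh, and exp with the closed forms $z(1-z)$, $1 - z^2$, and $z$. It also compares the softmax Jacobian from `torch.autograd.functional.jacobian` with $\operatorname{diag}(z) - z z^\top$.

```python
import torch

x = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)

# Elementwise ops: compare autograd against the closed-form local gradients
checks = {}
for name, fn, closed_form in [
    ("sigmoid", torch.sigmoid, lambda z: z * (1 - z)),
    ("tanh", torch.tanh, lambda z: 1 - z ** 2),
    ("exp", torch.exp, lambda z: z),
]:
    z = fn(x)
    # Summing gives every element an upstream gradient of 1, so grad holds dz_i/dx_i
    (grad,) = torch.autograd.grad(z.sum(), x)
    checks[name] = torch.allclose(grad, closed_form(z.detach()))

# Softmax mixes all inputs, so its local gradient is a Jacobian: J[i, j] = z_i * (delta_ij - z_j)
xs = x.detach()
s = torch.softmax(xs, dim=0)
J_autograd = torch.autograd.functional.jacobian(lambda t: torch.softmax(t, dim=0), xs)
J_formula = torch.diag(s) - torch.outer(s, s)
checks["softmax"] = torch.allclose(J_autograd, J_formula, atol=1e-6)
print(checks)
```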

Matrix Operations

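For matrix operations like $Y = XW$, the gradients involve transposing matrices: $\partial L/\partial X = (\partial L/\partial Y) W^\top$ and $\partial L/\partial W = X^\top (\partial L/\partial Y)$. Each gradient must have the same shape as the matrix it is taken with respect to, and that requirement fixes where the transposes go.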
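This sketch checks both formulas against autograd for random matrices. An arbitrary upstream gradient `G` stands in for $\partial L/\partial Y$.

```python
import torch

torch.manual_seed(0)
X = torch.randn(4, 3, requires_grad=True)
W = torch.randn(3, 2, requires_grad=True)
Y = X @ W

G = torch.randn(4, 2)   # stand-in for the upstream gradient dL/dY (same shape as Y)
Y.backward(G)           # autograd computes dL/dX and dL/dW for this upstream gradient

dX = G @ W.detach().T   # (4x2) @ (2x3) -> 4x3, the shape of X
dW = X.detach().T @ G   # (3x4) @ (4x2) -> 3x2, the shape of W
print(torch.allclose(X.grad, dX, atol=1e-6), torch.allclose(W.grad, dW, atol=1e-6))
```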

Memory vs Computation Trade-offs

Building computational graphs comes with memory costs. Understanding these trade-offs is crucial for training large models.

What Gets Stored?

During the forward pass, activations (intermediate values) must be saved for the backward pass. For a network with $L$ layers, this includes:

  • Activations: All intermediate outputs (grows with batch size and model width)
  • Graph structure: Links between operations (relatively small)
  • Intermediate values: Values needed for gradients (e.g., for BatchNorm)

Memory-Saving Techniques

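| Technique | How it works | Trade-off |
|---|---|---|
| Gradient checkpointing | Save only some activations; recompute the rest during backward | Extra compute (roughly one additional forward pass), much less activation memory |
| Mixed precision | Run most operations in FP16/BF16 while keeping FP32 master weights | Roughly halves activation memory; FP16 usually needs loss scaling |
| Activation compression | Compress stored activations | Some precision loss |
| torch.no_grad() | Don't build the graph (inference) | No gradients available |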
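memory_optimization.py
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

model = nn.Sequential(nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, 10))  # illustrative model
inputs = torch.randn(32, 128)

# Disable graph building for inference
with torch.no_grad():
    output = model(inputs)  # No grad_fn is recorded, which saves memory

# Gradient checkpointing: activations inside `model` are recomputed during backward
output = checkpoint(model, inputs, use_reentrant=False)
output.sum().backward()
```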

Knowledge Check

Test your understanding of computational graphs with this quiz:


In a computational graph, what does each node represent?


Summary

Computational graphs are the foundation of modern deep learning, enabling efficient and automatic gradient computation.

Key Concepts

| Concept | Key insight | Formula/Example |
|---|---|---|
| Computational graph | DAG representing a computation | Nodes = values (labeled by the operations that produce them), edges = data flow |
| Forward pass | Compute outputs, build graph | Values flow input → output |
| Backward pass | Compute gradients via chain rule | Gradients flow output → input |
| Local gradient | Derivative of a single operation | ∂output/∂input for each edge |
| Chain rule | Multiply along paths, sum across paths | ∂L/∂x = Σ over paths of ∏ local gradients |
| Autograd | Automatic graph construction and traversal | requires_grad=True enables tracking |

Key Takeaways

  1. Computational graphs decompose complex functions into simple, differentiable operations
  2. The forward pass computes values while building the graph; backward pass computes gradients
  3. The chain rule on graphs: multiply local gradients along paths, sum across multiple paths
  4. Autograd systems record operations to a "tape" and replay in reverse for gradients
  5. Understanding gradient flow patterns (add distributes, multiply swaps, max routes) helps debugging
  6. Memory-computation trade-offs are important for training large models

Looking Ahead

Now that you understand computational graphs, you're ready to see the full backpropagation algorithm in the next section. We'll derive the exact equations for gradient computation in multi-layer networks and implement backprop from scratch.


Exercises

Conceptual Questions

  1. Draw the computational graph for $L = \log(\sigma(wx + b))$, where $\sigma$ is the sigmoid function. Label all intermediate nodes.
  2. For the graph in question 1, write out the local gradient at each edge. Then compute $\partial L / \partial w$ using the chain rule.
  3. Explain why the gradient at a node that splits into multiple branches must be the sum of gradients from all branches, not the product or average.
  4. If a ReLU neuron always outputs 0 (its input is always negative), what happens to its gradient? Why is this called a "dead neuron"?

Coding Exercises

  1. Manual Gradient Check: Create a simple computation $y = 3x^2 + 2x + 1$ in PyTorch. Compute $\partial y / \partial x$ using autograd, then verify by computing the gradient manually.
  2. Graph Exploration: Build a small neural network (2 layers) in PyTorch. After a forward pass, explore the grad_fn chain by following output.grad_fn.next_functions. Print the operation names.
  3. Memory Analysis: Create tensors of increasing size with and without requires_grad=True. Measure memory usage to understand the overhead of gradient tracking.

Solution Hints

  • For Q1: You'll need nodes for wx, wx+b, σ(·), and log(·)
  • For Q3: Think of x contributing to L through multiple independent paths
  • For coding: Use torch.cuda.memory_allocated() or the memory_profiler package

Challenge Exercise

Implement a Mini Autograd: Create a simple Variable class that:

  1. Stores a value and tracks whether it requires gradients
  2. Implements __add__, __mul__, and __pow__ that return new Variables and record the operation
  3. Has a backward() method that computes gradients using the chain rule

Start with just scalar values before attempting vectors/matrices.


In the next section, we'll use computational graphs to derive the backpropagation algorithm for multi-layer neural networks—the core algorithm that makes deep learning possible.