Chapter 4
Section 30 of 178

Building Custom Autograd Functions

PyTorch Fundamentals

Learning Objectives

By the end of this section, you will be able to:

  1. Implement custom autograd Functions using torch.autograd.Function
  2. Define forward and backward passes with correct gradient computation
  3. Use ctx.save_for_backward() to store tensors needed for gradients
  4. Handle multiple inputs and outputs in custom functions
  5. Apply the Straight-Through Estimator for non-differentiable operations
  6. Implement higher-order gradients with double backward support
  7. Debug and optimize custom autograd functions
Why This Matters: While PyTorch provides hundreds of differentiable operations out of the box, real-world research and applications often require custom operations - whether for novel activation functions, specialized loss functions, efficiency optimizations, or operations that PyTorch doesn't support natively. Mastering custom autograd functions unlocks the full power of automatic differentiation.

The Big Picture

In the previous section, we learned that PyTorch's autograd system automatically computes gradients by recording operations in a computational graph. Each built-in operation (like torch.add or torch.matmul) comes with a forward computation and a matching backward formula. torch.autograd.Function is the Python interface that lets you supply both for your own operations.

When we create a custom autograd function, we're essentially teaching PyTorch:

  1. How to compute the output given inputs (forward pass)
  2. How to compute gradients given the upstream gradient (backward pass)

The Mathematical Foundation

Consider a function $y = f(x)$ in a larger computation graph. During backpropagation, we receive $\frac{\partial L}{\partial y}$ (the gradient of the loss with respect to our output). We need to compute $\frac{\partial L}{\partial x}$:

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x} = \frac{\partial L}{\partial y} \cdot \frac{\partial f(x)}{\partial x}$$

The term $\frac{\partial L}{\partial y}$ is the upstream gradient (provided to us as grad_output). The term $\frac{\partial f(x)}{\partial x}$ is the local gradient (we must compute this). Our job in backward() is to return their product. For elementwise functions this is an elementwise product; in general, backward() computes a vector-Jacobian product, which has the same shape as the input x.
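A quick numerical check of this recipe, using only built-in autograd. The functions are arbitrary choices for illustration: for $y = x^3$ and $L = \sum 2y$, the upstream gradient is $\partial L/\partial y = 2$ and the local gradient is $3x^2$, so their product should match what autograd computes.

```python
import torch

x = torch.tensor([1.0, 2.0, -3.0], requires_grad=True)
y = x ** 3            # y = f(x), local gradient dy/dx = 3x^2
L = (2 * y).sum()     # upstream gradient dL/dy = 2 for every element
L.backward()

upstream = torch.full_like(x, 2.0)   # dL/dy
local = 3 * x.detach() ** 2          # dy/dx
manual = upstream * local            # chain rule, elementwise

print(torch.allclose(x.grad, manual))  # True
```

This is exactly the multiplication a custom backward() performs: autograd hands it the upstream term, and it supplies the local term.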

When Do You Need Custom Functions?

| Use Case | Example | Why Custom |
|---|---|---|
| Novel operations | Custom activation functions | PyTorch doesn't have built-in support |
| Efficiency | Fused operations | Combine multiple ops for speed |
| Non-differentiable ops | Quantization, rounding | Need STE or custom gradients |
| Numerical stability | Log-sum-exp variants | Control gradient computation exactly |
| Research | New layer types | Implement paper algorithms precisely |

Anatomy of a Custom Function

Every custom autograd function inherits from torch.autograd.Function and must implement two static methods, forward() and backward():

Custom Autograd Function Template

Notes on the template:

  • @staticmethod decorator: Both forward() and backward() must be static methods. They don't have access to self; use ctx instead.
  • The ctx parameter: A context object that bridges the forward and backward passes. Use it to save tensors and non-tensor data.
  • save_for_backward(): Saves tensors that will be needed in backward(). They are retrieved via ctx.saved_tensors.
  • Storing non-tensor data: Non-tensor values (scalars, shapes, flags) can be stored as attributes on ctx.
  • grad_output parameter: The upstream gradient, i.e. the gradient of the final loss with respect to this function's output.
  • Computing grad_input: Multiply the upstream gradient by the local gradient (chain rule). Here: dL/dx = dL/dy * dy/dx = grad_output * 2x.
  • Return value: Must return gradients for ALL inputs to forward(), in the same order. Return None for non-tensor or non-differentiable inputs.
  • Using .apply(): Custom functions are called via .apply(), not directly. This ensures proper autograd bookkeeping.

🐍custom_function_template.py
```python
import torch

class MyFunction(torch.autograd.Function):
    """Custom autograd function template."""

    @staticmethod
    def forward(ctx, x, *args):
        """
        Computes the output of the function.

        Args:
            ctx: Context object for saving tensors
            x: Input tensor
            *args: Additional (non-tensor) arguments

        Returns:
            Output tensor(s)
        """
        # Save tensors needed for backward
        ctx.save_for_backward(x)

        # Store non-tensor data as attributes
        ctx.some_value = 42
        ctx.num_extra_args = len(args)

        # Compute and return output
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        """
        Computes gradients of the loss w.r.t. inputs.

        Args:
            ctx: Context object with saved tensors
            grad_output: Gradient of loss w.r.t. output

        Returns:
            Gradients for each input (same order as forward)
        """
        # Retrieve saved tensors
        x, = ctx.saved_tensors

        # Compute local gradient and multiply by upstream
        grad_input = grad_output * 2 * x

        # Return one gradient per forward() input:
        # None for each extra argument that doesn't need a gradient
        return (grad_input,) + (None,) * ctx.num_extra_args

# Usage: call with .apply()
x = torch.tensor([3.0], requires_grad=True)
y = MyFunction.apply(x)
y.backward()
print(x.grad)  # tensor([6.])
```

Critical Rules

  • Never call MyFunction(x) directly - always use MyFunction.apply(x)
  • backward() must return gradients for every input to forward()
  • Return None for inputs that don't need gradients (integers, strings, etc.)
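To see the second rule in action, here is a sketch with a hypothetical Scale function whose second input is a plain Python float. The correct backward() returns two values (the second is None); the broken variant returns only one, and autograd rejects it with a RuntimeError during the backward pass.

```python
import torch

class Scale(torch.autograd.Function):
    """y = x * factor, where factor is a plain Python float."""

    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # Correct: one entry per forward() input, None for the float factor
        return grad_output * ctx.factor, None

class ScaleBroken(torch.autograd.Function):
    """Same forward, but backward() forgets the entry for factor."""

    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.factor   # bug: two inputs, one gradient

x = torch.ones(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])

raised = False
try:
    ScaleBroken.apply(torch.ones(3, requires_grad=True), 2.0).sum().backward()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```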

Key Insight: Custom autograd functions let you define forward() for computation and backward() for gradient calculation. The ctx object bridges both passes, allowing you to save values from forward that are needed in backward. This is essential for operations PyTorch doesn't support natively or when you need optimized gradient computation.


The forward() Method

The forward() method defines how your function computes its output. It receives a context object ctx and input tensor(s), and should return output tensor(s).

Key Responsibilities

  1. Compute the output: Apply your function to the inputs
  2. Save tensors for backward: Store any values needed for gradient computation
  3. Save non-tensor attributes: Store scalars, shapes, or flags on ctx
🐍forward_example.py
```python
import torch

class ScaledSigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        # Compute output
        y = scale * torch.sigmoid(x)

        # Save for backward (tensors)
        ctx.save_for_backward(y)  # Save output, not input!

        # Save non-tensor data
        ctx.scale = scale

        return y
```

Optimization: Save Output Instead of Input

For a plain sigmoid $y = \sigma(x)$, the derivative is $\sigma'(x) = \sigma(x)(1-\sigma(x)) = y(1-y)$. Saving the output y is cheaper than saving x and recomputing the sigmoid in backward. In ScaledSigmoid, backward() recovers $\sigma(x) = y / \text{scale}$ from the saved output.
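A minimal sketch of the save-the-output pattern for a plain sigmoid (the class name Sigmoid is illustrative), checked against the gradient of PyTorch's built-in torch.sigmoid:

```python
import torch

class Sigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.sigmoid(x)
        ctx.save_for_backward(y)   # save the output; x itself is not needed
        return y

    @staticmethod
    def backward(ctx, grad_output):
        y, = ctx.saved_tensors
        return grad_output * y * (1 - y)   # sigma'(x) = y (1 - y)

x = torch.randn(5, dtype=torch.float64, requires_grad=True)
Sigmoid.apply(x).sum().backward()
ours = x.grad.clone()

x.grad = None
torch.sigmoid(x).sum().backward()      # built-in reference
print(torch.allclose(ours, x.grad))    # True
```

If a caller later modifies the returned output in place, autograd's version check will raise an error in backward, because the saved tensor is that same output.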

The backward() Method

The backward() method computes gradients. It receives the context ctx (with saved tensors) and grad_output (the upstream gradient).

backward() Implementation

Notes on the implementation:

  • Retrieving saved tensors: ctx.saved_tensors returns a tuple. Unpack it in the same order you saved the tensors.
  • Computing the local gradient: This is the derivative of your function with respect to its input.
  • Chain rule application: Multiply the upstream gradient by the local gradient to get the gradient to propagate.
  • Returning gradients: Return one gradient for each input to forward(). Return None for non-differentiable inputs like scale.

🐍backward_example.py
```python
import torch

class ScaledSigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        y = scale * torch.sigmoid(x)
        ctx.save_for_backward(y)
        ctx.scale = scale
        return y

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors and attributes
        y, = ctx.saved_tensors
        scale = ctx.scale

        # Compute local gradient
        # d(scale * sigmoid(x))/dx = scale * sigmoid(x) * (1 - sigmoid(x))
        # Since y = scale * sigmoid(x), sigmoid(x) = y / scale (assumes scale != 0)
        sig = y / scale
        local_grad = scale * sig * (1 - sig)

        # Apply chain rule: grad_input = grad_output * local_grad
        grad_x = grad_output * local_grad

        # Return gradients for x and scale
        # scale is a plain Python number here, so it gets None
        return grad_x, None
```

The Chain Rule in backward()

The fundamental equation you implement in every backward():

$$\texttt{grad\_input} = \texttt{grad\_output} \times \frac{\partial\, \texttt{output}}{\partial\, \texttt{input}}$$
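The × above is a plain elementwise product only when the function acts elementwise. For operations that mix elements, backward() computes a vector-Jacobian product, which usually involves transposes. A sketch for a hypothetical MatVec function computing y = W x, verified with torch.autograd.gradcheck:

```python
import torch
from torch.autograd import gradcheck

class MatVec(torch.autograd.Function):
    """y = W @ x for W of shape (m, n) and x of shape (n,)."""

    @staticmethod
    def forward(ctx, W, x):
        ctx.save_for_backward(W, x)
        return W @ x

    @staticmethod
    def backward(ctx, grad_output):           # grad_output: shape (m,)
        W, x = ctx.saved_tensors
        grad_W = torch.outer(grad_output, x)  # dL/dW_ij = g_i * x_j, shape (m, n)
        grad_x = W.t() @ grad_output          # dL/dx = W^T g, shape (n,)
        return grad_W, grad_x

W = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
x = torch.randn(3, dtype=torch.float64, requires_grad=True)
print(gradcheck(MatVec.apply, (W, x)))  # True
```

Each returned gradient has the same shape as the corresponding input, which is a useful sanity check when deriving such formulas.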

Quick Check

If forward() takes 3 inputs (x, y, z), how many values must backward() return?


Gradient Flow Through Custom Functions

Consider a small graph with inputs x = 2.0 and y = 3.0, where Loss = MySquare(x) + MyReLU(y): MySquare computes x² and MyReLU computes max(0, y). During the backward pass, each custom function receives grad_output and multiplies it by its local derivative. The backward pass starts from ∇Loss = 1.0, and the Add node (a + b) passes 1.0 unchanged to both branches. MySquare multiplies by 2x, so the gradient at x is 1 × 2 × 2 = 4, while MyReLU passes the gradient unchanged because its input is positive, so the gradient at y is 1.
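The same graph written with two minimal custom Functions. MySquare and MyReLU are the names used in the walkthrough; their implementations here are a sketch:

```python
import torch

class MySquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * 2 * x                      # local derivative 2x

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, y):
        ctx.save_for_backward(y)
        return y.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        y, = ctx.saved_tensors
        return grad_output * (y > 0).to(grad_output.dtype)  # 1 where y > 0, else 0

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0, requires_grad=True)
loss = MySquare.apply(x) + MyReLU.apply(y)   # Add node: a + b
loss.backward()                              # starts from dLoss/dLoss = 1.0
print(x.grad, y.grad)  # tensor(4.) tensor(1.)
```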


Saving Tensors for Backward

The ctx.save_for_backward() method is how you pass information from forward to backward. There are important rules to follow:

What to Save

| Save This | Example | When |
|---|---|---|
| Input tensors | ctx.save_for_backward(x) | When the gradient formula uses the input |
| Output tensors | ctx.save_for_backward(y) | When the gradient formula uses the output |
| Intermediate tensors | ctx.save_for_backward(hidden) | When a computed intermediate value is needed |
| Non-tensor data | ctx.alpha = alpha | For scalars, shapes, flags |
🐍save_patterns.py
```python
import torch

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y, alpha, use_relu):
        z = alpha * x + y
        if use_relu:
            z = torch.relu(z)

        # Save tensors (must be tensors!). x and y illustrate saving inputs;
        # the output z is needed to apply the ReLU mask in backward.
        ctx.save_for_backward(x, y, z)

        # Save non-tensor data as attributes
        ctx.alpha = alpha        # float
        ctx.use_relu = use_relu  # bool
        return z

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve tensors (same order as saved)
        x, y, z = ctx.saved_tensors

        # Retrieve non-tensor data
        alpha = ctx.alpha

        # If ReLU was applied, gradient only flows where the output is positive
        if ctx.use_relu:
            grad_output = grad_output * (z > 0).to(grad_output.dtype)

        # ctx.needs_input_grad holds one bool per forward() input
        need_x_grad, need_y_grad, _, _ = ctx.needs_input_grad

        grad_x = grad_output * alpha if need_x_grad else None
        grad_y = grad_output if need_y_grad else None

        return grad_x, grad_y, None, None  # None for alpha and use_relu
```

Memory Considerations

Saved tensors stay in memory until backward() has run (or the graph is freed). For very large tensors, consider whether you can recompute values in backward() instead of saving them. Trade off memory against computation based on your use case.
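A sketch of the recompute option, using a hypothetical SigmoidOfProduct function y = sigmoid(x * w). Its gradients need x, w and s = sigmoid(x * w). Since x and w must be saved anyway, this version recomputes s in backward() rather than storing a third tensor:

```python
import torch

class SigmoidOfProduct(torch.autograd.Function):
    """y = sigmoid(x * w); backward needs x, w and s = sigmoid(x * w)."""

    @staticmethod
    def forward(ctx, x, w):
        # Save only x and w; s is recomputed in backward
        ctx.save_for_backward(x, w)
        return torch.sigmoid(x * w)

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        s = torch.sigmoid(x * w)            # recomputed: extra compute, less memory
        common = grad_output * s * (1 - s)  # dL/d(x*w)
        return common * w, common * x

x = torch.randn(5, dtype=torch.float64, requires_grad=True)
w = torch.randn(5, dtype=torch.float64, requires_grad=True)
print(torch.autograd.gradcheck(SigmoidOfProduct.apply, (x, w)))  # True
```

For whole blocks of layers, torch.utils.checkpoint applies the same recompute-in-backward idea without writing a custom Function.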

Functions with Multiple Inputs

Many operations take multiple inputs. backward() must return a gradient (or None) for each of them:

Function with Multiple Inputs

Notes on the example:

  • Multiple input parameters: forward() can take any number of inputs. Each needs a corresponding return value in backward().
  • Gradient for x1: Since y = w1*x1 + w2*x2, the partial derivative dy/dx1 = w1.
  • Gradient for x2: Similarly, dy/dx2 = w2.
  • None for non-differentiable inputs: w1 and w2 are plain Python scalars that are not part of the computation graph, so backward() returns None for them.

🐍multiple_inputs.py
```python
import torch

class WeightedSum(torch.autograd.Function):
    """Computes y = w1*x1 + w2*x2"""

    @staticmethod
    def forward(ctx, x1, x2, w1, w2):
        # The gradients depend only on w1 and w2, so no tensors need saving
        ctx.w1 = w1
        ctx.w2 = w2
        return w1 * x1 + w2 * x2

    @staticmethod
    def backward(ctx, grad_output):
        w1, w2 = ctx.w1, ctx.w2

        # Gradient for each input
        # dy/dx1 = w1, so grad_x1 = grad_output * w1
        grad_x1 = grad_output * w1

        # dy/dx2 = w2
        grad_x2 = grad_output * w2

        # w1 and w2 are Python scalars, so they get None
        return grad_x1, grad_x2, None, None

# Usage
x1 = torch.tensor([1.0, 2.0], requires_grad=True)
x2 = torch.tensor([3.0, 4.0], requires_grad=True)
y = WeightedSum.apply(x1, x2, 0.3, 0.7)
y.sum().backward()
print(x1.grad)  # tensor([0.3000, 0.3000])
print(x2.grad)  # tensor([0.7000, 0.7000])
```

Straight-Through Estimator

Some operations have zero gradients almost everywhere, which blocks gradient flow:

  • sign(x) - derivative is 0 everywhere except at 0, where it is undefined
  • round(x) - derivative is 0 almost everywhere (piecewise-constant step function)
  • floor(x), ceil(x) - derivative is 0 almost everywhere
  • Quantization functions - derivative is 0 almost everywhere

The Straight-Through Estimator (STE) solves this by using a different function during backward than forward:

$$\text{Forward: } y = f(x) \quad \text{(non-differentiable)}$$
$$\text{Backward: } \frac{\partial y}{\partial x} \approx g'(x) \quad \text{(differentiable proxy)}$$
🐍ste.py
```python
import torch

class SignSTE(torch.autograd.Function):
    """Sign function with Straight-Through Estimator."""

    @staticmethod
    def forward(ctx, x):
        # Forward: actual sign function
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Backward: pretend it's identity (STE)
        return grad_output  # Just pass through!

# More sophisticated: clipped STE
class SignClippedSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Only pass gradient where |x| <= 1
        grad_input = grad_output.clone()
        grad_input[x.abs() > 1] = 0
        return grad_input
```
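A small comparison of the three options on the same input: the built-in torch.sign (true derivative), the identity STE, and the clipped STE. The two STE classes are re-declared so the snippet runs on its own; the clipped version uses mask multiplication instead of in-place indexing, which gives the same result.

```python
import torch

class SignSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output                       # identity gradient

class SignClippedSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * (x.abs() <= 1).to(grad_output.dtype)

x = torch.tensor([-2.0, -0.5, 0.3, 1.5], requires_grad=True)

torch.sign(x).sum().backward()            # built-in sign: true derivative is 0
g_builtin = x.grad.clone()

x.grad = None
SignSTE.apply(x).sum().backward()         # identity STE
g_ste = x.grad.clone()

x.grad = None
SignClippedSTE.apply(x).sum().backward()  # clipped STE
g_clipped = x.grad.clone()

print(g_builtin)  # tensor([0., 0., 0., 0.])
print(g_ste)      # tensor([1., 1., 1., 1.])
print(g_clipped)  # tensor([0., 1., 1., 0.])
```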

STE in Practice

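The Straight-Through Estimator is widely used in:
  • Binary Neural Networks - weights (and often activations) are binarized to +1/-1 in forward, while full-precision latent weights are updated using STE gradients
  • Quantization-Aware Training - simulate low-bit inference in forward while keeping gradient flow for optimization
  • Discrete VAEs - quantize or sample discrete codes in forward and backpropagate through that step (VQ-VAE, for example, uses a straight-through gradient)
  • Gumbel-Softmax - a continuous relaxation of categorical sampling; its "hard" straight-through variant combines it with STE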

Key Insight: The Straight-Through Estimator is a "gradient hack" - it uses one function for forward (e.g., sign) and a different function for backward (e.g., identity). This biased gradient estimate is surprisingly effective in practice, enabling training of models that would otherwise have zero gradients.


Higher-Order Gradients

Sometimes you need gradients of gradients (Hessians, second derivatives). For this, your backward() must itself be differentiable:

Double Backward Support

Notes on the example:

  • Differentiable operations in backward(): For double backward, the operations in backward() must themselves be differentiable. Avoid in-place ops.
  • create_graph=True: This tells autograd to build a graph for the gradient computation, enabling second derivatives.
  • Second derivative: Computing the gradient of the gradient gives d²y/dx² = d(2x)/dx = 2.

🐍double_backward.py
```python
import torch

class DoubleBackwardSquare(torch.autograd.Function):
    """Square function supporting second derivatives."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors

        # For double backward, we need this computation to be
        # part of the graph. Use differentiable operations!
        grad_input = grad_output * 2 * x

        return grad_input

# Test double backward
x = torch.tensor([3.0], requires_grad=True)
y = DoubleBackwardSquare.apply(x)

# First derivative
grad, = torch.autograd.grad(y, x, create_graph=True)
print(f"dy/dx = {grad.item()}")  # 6.0

# Second derivative
grad2, = torch.autograd.grad(grad, x)
print(f"d²y/dx² = {grad2.item()}")  # 2.0
```
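Besides spot-checking a single value, torch.autograd.gradgradcheck compares second derivatives through your backward() against finite differences, much like gradcheck does for first derivatives. A sketch with a hypothetical Cube function whose backward is built only from differentiable operations:

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

class Cube(torch.autograd.Function):
    """y = x^3 with a backward built from differentiable ops."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * 3 * x ** 2   # differentiable w.r.t. x and grad_output

x = torch.randn(4, dtype=torch.float64, requires_grad=True)
print(gradcheck(Cube.apply, (x,)))      # True: first derivatives
print(gradgradcheck(Cube.apply, (x,)))  # True: second derivatives through backward()
```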

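If your backward() is not itself differentiable (for example, it uses in-place updates or non-PyTorch code), decorate it with @once_differentiable. Autograd then runs backward() without recording a graph, so higher-order gradients through it are not supported and raise an error: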

🐍once_differentiable.py
```python
import torch
from torch.autograd.function import once_differentiable

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    @once_differentiable  # Disable double backward
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Can use non-differentiable ops here
        return grad_output * 2 * x
```
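A sketch of what the decorator changes in practice, using a hypothetical SquareOnce function: the first derivative works as usual, but asking for a second derivative through the decorated backward() fails with a RuntimeError.

```python
import torch
from torch.autograd.function import once_differentiable

class SquareOnce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * 2 * x

x = torch.tensor([3.0], requires_grad=True)
y = SquareOnce.apply(x)
grad, = torch.autograd.grad(y, x, create_graph=True)  # first derivative: 6.0

raised = False
try:
    torch.autograd.grad(grad, x)   # second derivative through backward()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```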

Practical Examples

Example 1: Clipped ReLU

🐍clipped_relu.py
```python
import torch

class ClippedReLU(torch.autograd.Function):
    """ReLU with maximum value clipping: min(max(0, x), clip_val)"""

    @staticmethod
    def forward(ctx, x, clip_val=6.0):
        ctx.save_for_backward(x)
        ctx.clip_val = clip_val
        return x.clamp(min=0, max=clip_val)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        clip_val = ctx.clip_val

        # Gradient is 1 where 0 <= x <= clip_val, else 0
        # (the boundary points x = 0 and x = clip_val get gradient 1)
        grad_input = grad_output.clone()
        grad_input[x < 0] = 0
        grad_input[x > clip_val] = 0

        return grad_input, None

# Usage
x = torch.tensor([-1., 2., 5., 8.], requires_grad=True)
y = ClippedReLU.apply(x, 6.0)
print(y.detach())  # tensor([0., 2., 5., 6.])
y.sum().backward()
print(x.grad)  # tensor([0., 1., 1., 0.])
```

Example 2: Soft Threshold (Proximal Operator)

🐍soft_threshold.py
```python
import torch

class SoftThreshold(torch.autograd.Function):
    """Soft thresholding: sign(x) * max(|x| - threshold, 0)
    Used in sparse optimization (LASSO, etc.)
    """

    @staticmethod
    def forward(ctx, x, threshold):
        ctx.save_for_backward(x)
        ctx.threshold = threshold

        return torch.sign(x) * torch.relu(x.abs() - threshold)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        threshold = ctx.threshold

        # Gradient is 1 where |x| > threshold, else 0
        grad_input = grad_output.clone()
        grad_input[x.abs() <= threshold] = 0

        return grad_input, None
```
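Soft thresholding is the same operation as PyTorch's built-in torch.nn.functional.softshrink, which makes it a convenient reference for testing. A sketch comparing outputs and gradients (the class is re-declared with a mask-multiply backward so the snippet runs on its own; inputs avoid the kinks at ±threshold):

```python
import torch
import torch.nn.functional as F

class SoftThreshold(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, threshold):
        ctx.save_for_backward(x)
        ctx.threshold = threshold
        return torch.sign(x) * torch.relu(x.abs() - threshold)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * (x.abs() > ctx.threshold).to(grad_output.dtype), None

x = torch.tensor([-2.0, -0.3, 0.1, 0.8, 1.5], requires_grad=True)
ours = SoftThreshold.apply(x, 0.5)
ours.sum().backward()
grad_ours = x.grad.clone()

x.grad = None
ref = F.softshrink(x, lambd=0.5)   # built-in reference
ref.sum().backward()

print(torch.allclose(ours, ref), torch.equal(grad_ours, x.grad))  # True True
```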

Example 3: Gumbel-Softmax (Differentiable Sampling)

🐍gumbel_softmax.py
```python
import torch

class GumbelSoftmax(torch.autograd.Function):
    """Gumbel-Softmax for differentiable discrete sampling."""

    @staticmethod
    def forward(ctx, logits, temperature, hard=True):
        # Sample from Gumbel(0, 1)
        gumbels = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10) + 1e-10)

        # Apply temperature-scaled softmax
        y_soft = torch.softmax((logits + gumbels) / temperature, dim=-1)

        if hard:
            # Straight-through: one-hot sample in forward; backward() below
            # still uses the gradient of the soft sample. (Inside forward()
            # autograd is not recording, so the y_hard - y_soft.detach() + y_soft
            # trick would have no effect here.)
            index = y_soft.argmax(dim=-1, keepdim=True)
            y = torch.zeros_like(logits).scatter_(-1, index, 1.0)
        else:
            y = y_soft

        ctx.save_for_backward(y_soft)
        ctx.temperature = temperature
        return y

    @staticmethod
    def backward(ctx, grad_output):
        y_soft, = ctx.saved_tensors
        # Gradient flows through the soft sample: vector-Jacobian product of
        # softmax((logits + gumbels) / temperature) with respect to logits
        dot = (grad_output * y_soft).sum(dim=-1, keepdim=True)
        grad_logits = y_soft * (grad_output - dot) / ctx.temperature
        return grad_logits, None, None
```
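PyTorch also ships this operation as torch.nn.functional.gumbel_softmax (with tau, hard and dim arguments), implemented with ordinary autograd, which is usually the simpler choice in practice. A short usage sketch; the shapes, temperature and weights are arbitrary:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 5, requires_grad=True)

# hard=True: one-hot samples in forward, soft-sample gradients in backward
y = F.gumbel_softmax(logits, tau=0.5, hard=True, dim=-1)
print(y.detach().sum(dim=-1))   # each row is one-hot, so every row sums to 1

# A weighted sum so the gradient reaching the logits is non-trivial
weights = torch.arange(5.0)
(y * weights).sum().backward()
print(logits.grad.shape)        # torch.Size([4, 5])
```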

Performance Optimization

Custom functions can be optimized for speed and memory:

1. Use ctx.needs_input_grad

🐍needs_input_grad.py
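```python
import torch

class Mul(torch.autograd.Function):
    """Elementwise product x * y (same shapes assumed)."""

    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return x * y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        grad_x, grad_y = None, None

        # Only compute gradients that are needed
        if ctx.needs_input_grad[0]:
            grad_x = grad_output * y

        if ctx.needs_input_grad[1]:
            grad_y = grad_output * x

        return grad_x, grad_y
```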

2. Fuse Operations

🐍fused_ops.py
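```python
import torch

class FusedAddReLU(torch.autograd.Function):
    """Add + ReLU in one Function: saves only a bool mask for backward.

    In Python, x + y and clamp still run as separate kernels, so the gain
    here is memory; true kernel fusion needs a custom kernel or a compiler.
    """

    @staticmethod
    def forward(ctx, x, y):
        z = x + y
        mask = z > 0
        ctx.save_for_backward(mask)
        return z.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        mask, = ctx.saved_tensors
        grad = grad_output * mask.to(grad_output.dtype)
        return grad, grad  # Same gradient for both inputs (x and y must have the same shape)
```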
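Before relying on a hand-written combined op, it is worth checking it against the unfused composition. A sketch (the class is re-declared so the snippet runs on its own); the custom version stores a 1-byte bool mask, whereas the built-in ReLU path keeps its float output for backward:

```python
import torch

class FusedAddReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        z = x + y
        mask = z > 0                  # bool mask: 1 byte per element
        ctx.save_for_backward(mask)
        return z.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        mask, = ctx.saved_tensors
        grad = grad_output * mask.to(grad_output.dtype)
        return grad, grad             # d(x + y)/dx = d(x + y)/dy = 1

x = torch.randn(1000, requires_grad=True)
y = torch.randn(1000, requires_grad=True)
FusedAddReLU.apply(x, y).sum().backward()
gx, gy = x.grad.clone(), y.grad.clone()

x.grad = y.grad = None
torch.relu(x + y).sum().backward()    # unfused reference
print(torch.equal(gx, x.grad), torch.equal(gy, y.grad))  # True True
```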

3. CUDA Kernels

For maximum performance, implement custom CUDA kernels:

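```cpp
// custom_ops.cpp: declarations of the CUDA launchers (implemented in a .cu file)
#include <torch/extension.h>

torch::Tensor my_custom_forward_cuda(torch::Tensor x);
torch::Tensor my_custom_backward_cuda(torch::Tensor grad, torch::Tensor x);

// Expose the functions to Python as <extension>.forward / <extension>.backward
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &my_custom_forward_cuda, "custom forward (CUDA)");
    m.def("backward", &my_custom_backward_cuda, "custom backward (CUDA)");
}
```

🐍cuda_extension.py

```python
import torch
import my_custom_extension  # hypothetical compiled extension (e.g. built with torch.utils.cpp_extension)

class MyCUDAFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return my_custom_extension.forward(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return my_custom_extension.backward(grad_output, x)
```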

Debugging Custom Functions

1. Gradient Check

Use numerical gradient checking to verify your backward() is correct:

🐍gradcheck.py
```python
import torch
from torch.autograd import gradcheck

# MyFunction: any custom Function defined earlier in this section
# Create input with double precision for accuracy
x = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)

# Check gradients (by default gradcheck raises on a mismatch and returns True on success)
test_passed = gradcheck(MyFunction.apply, (x,), eps=1e-6, atol=1e-4)
print(f"Gradient check: {'PASSED' if test_passed else 'FAILED'}")
```
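To see gradcheck catch a real mistake, here is a hypothetical BuggyCube whose backward uses 2x² instead of 3x². Passing raise_exception=False makes gradcheck return False instead of raising:

```python
import torch
from torch.autograd import gradcheck

class BuggyCube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * 2 * x ** 2      # bug: should be 3 * x ** 2

torch.manual_seed(0)
x = torch.randn(5, dtype=torch.float64, requires_grad=True)
ok = gradcheck(BuggyCube.apply, (x,), eps=1e-6, atol=1e-4, raise_exception=False)
print(ok)  # False
```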

2. Common Errors and Fixes

| Error | Cause | Fix |
|---|---|---|
| RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation | In-place operation on a saved tensor | Clone the tensor before in-place ops |
| RuntimeError: Trying to backward through the graph a second time | Graph freed after the first backward | Use retain_graph=True |
| RuntimeError: ... returned an incorrect number of gradients | backward() returns the wrong count | Match the count to forward() inputs |
| Gradient doesn't match numerical gradient | Bug in the local gradient formula | Check the math, use gradcheck |
🐍debug_tips.py
```python
import torch

# Tips 1 and 2 show backward() bodies; in real code they live inside
# a torch.autograd.Function subclass under @staticmethod.

# Tip 1: Print shapes during development
def backward(ctx, grad_output):
    print(f"grad_output shape: {grad_output.shape}")
    x, = ctx.saved_tensors
    print(f"saved x shape: {x.shape}")
    ...

# Tip 2: Check for NaN/Inf
def backward(ctx, grad_output):
    if not torch.isfinite(grad_output).all():
        print("WARNING: NaN or Inf in grad_output!")
    ...

# Tip 3: Test with simple inputs first (MyFunction: the Function under test)
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = MyFunction.apply(x)
y.sum().backward()
print(x.grad)  # Should match manual calculation
```
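PyTorch's anomaly mode automates Tip 2: inside torch.autograd.detect_anomaly(), the engine checks each backward node's outputs and raises a RuntimeError naming the node that produced NaN. A sketch with a deliberately broken, hypothetical BrokenDouble function:

```python
import torch

class BrokenDouble(torch.autograd.Function):
    """Hypothetical broken function: correct forward, backward emits NaN."""

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * float("nan")    # deliberate bug

x = torch.ones(3, requires_grad=True)
raised = False
try:
    with torch.autograd.detect_anomaly():
        BrokenDouble.apply(x).sum().backward()
except RuntimeError as err:
    raised = True
    print("Caught:", err)   # the message names the backward node that returned NaN
```

Anomaly mode slows training noticeably, so it is best enabled only while debugging.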

Model Export and Deployment

Once you've built and trained your model, you need to deploy it efficiently. PyTorch provides several tools for model export and optimization.

TorchScript: JIT Compilation

TorchScript converts your PyTorch model to a serialized, optimized format that can run without Python.

🐍torchscript.py
```python
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x).relu()

model = MyModel()
model.eval()

# Method 1: Tracing (records operations with example input)
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)

# Method 2: Scripting (parses Python code directly)
scripted_model = torch.jit.script(model)

# Save for deployment (no Python needed to load)
traced_model.save("model_traced.pt")
scripted_model.save("model_scripted.pt")

# Load in C++ or Python
loaded = torch.jit.load("model_traced.pt")
output = loaded(torch.randn(1, 10))
```
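| Method | Pros | Cons | Use When |
|---|---|---|---|
| trace() | Simple; works with arbitrary Python code because it only records tensor ops | Records only the ops executed for the example input, so data-dependent control flow is baked in | Static models, no if/for on tensor values |
| script() | Captures control flow | Supports only a subset of Python | Dynamic shapes, conditionals |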
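A small illustration of the tracing pitfall from the table, using a hypothetical SignFlip module whose branch depends on the input's values. Tracing with a positive example bakes in the x * 2 branch (PyTorch emits a TracerWarning), while scripting compiles the if statement itself:

```python
import torch

class SignFlip(torch.nn.Module):
    """Hypothetical model whose behavior depends on the input's values."""

    def forward(self, x):
        if x.sum() > 0:        # data-dependent control flow
            return x * 2
        return -x

model = SignFlip()
pos = torch.ones(3)
neg = -torch.ones(3)

traced = torch.jit.trace(model, pos)    # records only the x * 2 branch
scripted = torch.jit.script(model)      # keeps the if statement

print(model(neg))     # tensor([1., 1., 1.])
print(traced(neg))    # tensor([-2., -2., -2.])  <- wrong branch baked in
print(scripted(neg))  # tensor([1., 1., 1.])
```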

ONNX Export

ONNX (Open Neural Network Exchange) is an open format for representing ML models, enabling deployment across different frameworks and hardware.

🐍onnx_export.py
```python
import torch
import torch.onnx

model = MyModel()  # MyModel: the module from the TorchScript example above
model.eval()
dummy_input = torch.randn(1, 10)

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},   # Variable batch size
        "output": {0: "batch_size"},
    },
    opset_version=17,  # Use recent opset for better op support
)

# Verify the exported model
import onnx
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)

# Run with ONNX Runtime (pip install onnxruntime)
import onnxruntime as ort
session = ort.InferenceSession("model.onnx")
outputs = session.run(None, {"input": dummy_input.numpy()})
```

Quantization

Quantization reduces model size and increases inference speed by using lower-precision arithmetic (int8 instead of float32).

🐍quantization.py
```python
import io

import torch
from torch.quantization import quantize_dynamic, prepare, prepare_qat, convert

model = MyModel()  # MyModel: the module from the TorchScript example above
model.eval()

def get_model_size(m):
    """Serialized state_dict size in MB (small helper, not a PyTorch API)."""
    buffer = io.BytesIO()
    torch.save(m.state_dict(), buffer)
    return buffer.getbuffer().nbytes / 1e6

# Dynamic Quantization (easiest, for CPU)
# Quantizes weights ahead of time, activations dynamically at runtime
quantized_model = quantize_dynamic(
    model,
    {torch.nn.Linear},  # Layers to quantize
    dtype=torch.qint8
)

# Static Quantization (better performance, requires calibration)
# Note: for static quantization and QAT, a real model should also wrap its
# input/output with QuantStub/DeQuantStub; MyModel omits them for brevity.
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_prepared = prepare(model)

# Calibrate with representative data
# Stand-in batches; replace with a DataLoader of real samples
calibration_loader = [torch.randn(8, 10) for _ in range(4)]
with torch.no_grad():
    for data in calibration_loader:
        model_prepared(data)

model_quantized = convert(model_prepared)

# Quantization-Aware Training (best accuracy)
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_qat = prepare_qat(model.train())
# Train as usual, then convert
model_quantized = convert(model_qat.eval())

# Compare sizes
print(f"Original: {get_model_size(model):.2f} MB")
print(f"Quantized: {get_model_size(quantized_model):.2f} MB")
```
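| Method | Accuracy | Speed (typical, hardware-dependent) | Effort |
|---|---|---|---|
| Dynamic | Good | 2-4x faster | One line of code |
| Static | Better | 2-4x faster | Needs calibration data |
| QAT | Best | 2-4x faster | Requires retraining |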

Deployment Stack

For production deployment, consider combining these tools: use torch.compile for training speedup, TorchScript or ONNX for export, and quantization for inference optimization. Many deployment platforms (TensorRT, OpenVINO, Core ML) accept ONNX models directly.


Summary

In this section, we learned how to extend PyTorch's autograd system with custom operations:

| Concept | Key Points |
|---|---|
| torch.autograd.Function | Base class for custom differentiable operations |
| forward(ctx, ...) | Compute output, save tensors for backward |
| backward(ctx, grad_output) | Compute grad_input = grad_output × local_gradient |
| ctx.save_for_backward() | Store tensors needed in the backward pass |
| .apply() | How to call custom functions (not a direct call) |
| Straight-Through Estimator | Enable gradients for non-differentiable ops |
| Higher-order gradients | Use differentiable ops in backward() |
| gradcheck() | Verify gradients numerically |

Key Takeaways

  1. Custom functions teach PyTorch new operations by defining forward (computation) and backward (gradient) passes.
  2. The chain rule is central: multiply upstream gradient by local gradient to get downstream gradient.
  3. Save efficiently: only save what you need for backward, and consider saving outputs vs inputs based on the gradient formula.
  4. STE makes non-differentiable operations trainable: use different functions for forward and backward to keep gradients flowing through ops whose true derivative is zero.
  5. Always verify with gradcheck: numerical gradient checking catches bugs in your gradient formulas.

Exercises

Conceptual Questions

  1. Explain why ctx.save_for_backward() only accepts tensors. How do you save non-tensor data?
  2. Why does backward() receive grad_output instead of computing the gradient from scratch?
  3. Describe a scenario where the Straight-Through Estimator produces biased gradients. Why might this still lead to successful training?

Coding Exercises

  1. GELU Activation: Implement the GELU activation function as a custom autograd function. GELU is defined as:
    $$\text{GELU}(x) = x \cdot \Phi(x) = x \cdot \frac{1}{2}\left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right]$$
    Verify with gradcheck and compare to torch.nn.functional.gelu.
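  2. Swish/SiLU Activation: Implement Swish: $f(x) = x \cdot \sigma(x)$. Implement it efficiently by saving what backward needs from forward: since $f'(x) = \sigma(x) + f(x)(1 - \sigma(x))$, saving $\sigma(x)$ together with the output avoids recomputing the sigmoid.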
  3. Batch Normalization: Implement a simplified batch normalization as a custom function (without learnable parameters):
    $$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}$$
    Compute gradients with respect to input x.
  4. Differentiable Sorting: Implement a function that computes soft ranks of elements. Use a temperature parameter to control softness.

Challenge Exercise

Implement a Custom Attention Mechanism: Create a custom autograd function for scaled dot-product attention:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$

Requirements:

  • Implement efficient forward pass that saves only necessary tensors
  • Implement backward pass for all three inputs (Q, K, V)
  • Support optional attention mask
  • Verify gradients match PyTorch's built-in attention
🐍challenge_starter.py
```python
import torch

class ScaledDotProductAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, K, V, scale, mask=None):
        # Q: (batch, heads, seq_len, d_k)
        # K: (batch, heads, seq_len, d_k)
        # V: (batch, heads, seq_len, d_v)
        # scale: typically 1 / sqrt(d_k); mask: 0 marks disallowed positions

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) * scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        # Save what you need for backward
        ctx.save_for_backward(Q, K, V, attn_weights)
        ctx.scale = scale

        return output

    @staticmethod
    def backward(ctx, grad_output):
        Q, K, V, attn_weights = ctx.saved_tensors
        scale = ctx.scale

        # TODO: Implement gradients for Q, K, V
        # Hint: Start with grad_V (easiest), then grad_attn_weights,
        # then propagate through softmax, then to Q and K

        grad_Q = ...
        grad_K = ...
        grad_V = ...

        return grad_Q, grad_K, grad_V, None, None
```

Implementation Hint

The gradient through softmax is: if $p = \text{softmax}(s)$, then $\frac{\partial L}{\partial s_i} = p_i \left(\frac{\partial L}{\partial p_i} - \sum_j p_j \frac{\partial L}{\partial p_j}\right)$. In vector form this is $\nabla_s = p \odot (\nabla_p - (\nabla_p \cdot p))$, where $\nabla_p \cdot p$ is a scalar for each row of attention weights.

Congratulations on completing Chapter 4: PyTorch Fundamentals! You now have a deep understanding of tensors, operations, GPU computing, autograd, and custom functions. In the next chapter, we'll use these foundations to build neural network layers and explore the nn.Module API.