Learning Objectives
By the end of this section, you will be able to:
- Implement custom autograd Functions using torch.autograd.Function
- Define forward and backward passes with correct gradient computation
- Use ctx.save_for_backward() to store tensors needed for gradients
- Handle multiple inputs and outputs in custom functions
- Apply the Straight-Through Estimator for non-differentiable operations
- Implement higher-order gradients with double backward support
- Debug and optimize custom autograd functions
Why This Matters: While PyTorch provides hundreds of differentiable operations out of the box, real-world research and applications often require custom operations - whether for novel activation functions, specialized loss functions, efficiency optimizations, or operations that PyTorch doesn't support natively. Mastering custom autograd functions unlocks the full power of automatic differentiation.
The Big Picture
In the previous section, we learned that PyTorch's autograd system automatically computes gradients by recording operations in a computational graph. Each built-in operation (like torch.add, torch.matmul) is implemented as an autograd.Function with both forward and backward passes defined.
When we create a custom autograd function, we're essentially teaching PyTorch:
- How to compute the output given inputs (forward pass)
- How to compute gradients given the upstream gradient (backward pass)
The Mathematical Foundation
Consider a function y = f(x) in a larger computation graph. During backpropagation, we receive ∂L/∂y (the gradient of the loss with respect to our output). We need to compute:

∂L/∂x = (∂L/∂y) · (∂y/∂x)

The term ∂L/∂y is the upstream gradient (provided to us as grad_output). The term ∂y/∂x is the local gradient (we must compute this). Our job in backward() is to return their product.
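A quick numerical check of this relationship, using only built-in ops (a minimal sketch before we write custom functions): the gradient PyTorch computes for x equals the upstream gradient times the local gradient.

```python
import torch

# L = sum(2 * y) with y = x ** 3, so dL/dx = (dL/dy) * (dy/dx) = 2 * 3x^2
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x ** 3
L = (2 * y).sum()
L.backward()

print(x.grad)  # tensor([ 6., 24.])
```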
When Do You Need Custom Functions?
| Use Case | Example | Why Custom |
|---|---|---|
| Novel operations | Custom activation functions | PyTorch doesn't have built-in support |
| Efficiency | Fused operations | Combine multiple ops for speed |
| Non-differentiable ops | Quantization, rounding | Need STE or custom gradients |
| Numerical stability | Log-sum-exp variants | Control gradient computation exactly |
| Research | New layer types | Implement paper algorithms precisely |
Anatomy of a Custom Function
Every custom autograd function inherits from torch.autograd.Function and must implement two static methods:
Critical Rules
- Never call MyFunction(x) directly - always use MyFunction.apply(x)
- backward() must return one gradient for every input to forward()
- Return None for inputs that don't need gradients (integers, strings, etc.)
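A minimal sketch illustrating all three rules at once, using a hypothetical Scale function whose second argument is a plain float:

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor):  # factor is a plain float, not a tensor
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward() input:
        # a gradient for x, None for the non-tensor factor
        return grad_output * ctx.factor, None

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = Scale.apply(x, 3.0)  # always .apply(), never Scale(x)
y.sum().backward()
print(x.grad)  # tensor([3., 3.])
```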
Interactive: Custom Function Builder
Explore how different custom autograd functions work. Select a function to see its forward computation and backward gradient calculation in action.
Custom Autograd Function Builder
class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return grad_output * 2 * x

Key Insight: Custom autograd functions let you define forward() for computation and backward() for gradient calculation. The ctx object bridges both passes, allowing you to save values from forward that are needed in backward. This is essential for operations PyTorch doesn't support natively or when you need optimized gradient computation.
The forward() Method
The forward() method defines how your function computes its output. It receives a context object ctx and input tensor(s), and should return output tensor(s).
Key Responsibilities
- Compute the output: Apply your function to the inputs
- Save tensors for backward: Store any values needed for gradient computation
- Save non-tensor attributes: Store scalars, shapes, or flags on ctx
class ScaledSigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        # Compute output
        y = scale * torch.sigmoid(x)

        # Save for backward (tensors)
        ctx.save_for_backward(y)  # Save output, not input!

        # Save non-tensor data
        ctx.scale = scale

        return y

Optimization: Save Output Instead of Input
Saving y is more efficient than saving x and recomputing the sigmoid in backward.

The backward() Method
The backward() method computes gradients. It receives the context ctx (with saved tensors) and grad_output (the upstream gradient).
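For example, here is a backward that completes the ScaledSigmoid sketch from above, reusing the saved output y and verified with gradcheck (the lambda wrapper just fixes the non-tensor scale argument):

```python
import torch

class ScaledSigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale):
        y = scale * torch.sigmoid(x)
        ctx.save_for_backward(y)  # save the output, not the input
        ctx.scale = scale
        return y

    @staticmethod
    def backward(ctx, grad_output):
        y, = ctx.saved_tensors
        # With s = sigmoid(x) = y / scale:
        # dy/dx = scale * s * (1 - s) = y * (1 - y / scale)
        grad_x = grad_output * y * (1 - y / ctx.scale)
        return grad_x, None  # None for the non-tensor scale

x = torch.randn(4, dtype=torch.float64, requires_grad=True)
print(torch.autograd.gradcheck(lambda t: ScaledSigmoid.apply(t, 2.0), (x,)))  # True
```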
The Chain Rule in backward()
The fundamental equation you implement in every backward():

grad_input = grad_output × (∂y/∂x)
Quick Check
If forward() takes 3 inputs (x, y, z), how many values must backward() return?
Interactive: Gradient Flow Visualization
Watch how gradients flow through custom functions during backpropagation. This visualization shows both the forward pass (computing values) and backward pass (computing gradients).
Gradient Flow Through Custom Functions
Gradient Flow: During backward pass, each custom function receives grad_output and multiplies it by its local derivative. MySquare multiplies by 2x (so grad=1×2×2=4), while MyReLU passes the gradient unchanged when input > 0.
Saving Tensors for Backward
The ctx.save_for_backward() method is how you pass information from forward to backward. There are important rules to follow:
What to Save
| Save This | Example | When |
|---|---|---|
| Input tensors | ctx.save_for_backward(x) | When gradient formula uses input |
| Output tensors | ctx.save_for_backward(y) | When gradient formula uses output |
| Intermediate tensors | ctx.save_for_backward(hidden) | When computed value is needed |
| Non-tensor data | ctx.alpha = alpha | For scalars, shapes, flags |
class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y, alpha, use_relu):
        z = alpha * x + y
        if use_relu:
            z = torch.relu(z)

        # Save tensors (must be tensors!); saving z lets backward mask the ReLU
        ctx.save_for_backward(x, y, z)

        # Save non-tensor data as attributes
        ctx.alpha = alpha        # float
        ctx.use_relu = use_relu  # bool

        # ctx.needs_input_grad tells you in backward which inputs need gradients

        return z

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve tensors
        x, y, z = ctx.saved_tensors

        # Retrieve non-tensor data
        alpha = ctx.alpha

        # Apply the ReLU mask if it was used in forward
        grad = grad_output
        if ctx.use_relu:
            grad = grad * (z > 0).to(grad.dtype)

        # Check which inputs need gradients
        need_x_grad, need_y_grad, _, _ = ctx.needs_input_grad

        grad_x = grad * alpha if need_x_grad else None
        grad_y = grad if need_y_grad else None

        return grad_x, grad_y, None, None  # None for alpha and use_relu

Memory Considerations
Functions with Multiple Inputs
Many operations take multiple inputs. You must compute gradients for each:
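A sketch of elementwise multiplication with two inputs, where each input gets its own gradient from the product rule:

```python
import torch

class Multiply(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return x * y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        # d(x*y)/dx = y and d(x*y)/dy = x: one gradient per forward input
        return grad_output * y, grad_output * x

a = torch.tensor([2.0, 3.0], requires_grad=True)
b = torch.tensor([4.0, 5.0], requires_grad=True)
Multiply.apply(a, b).sum().backward()
print(a.grad)  # tensor([4., 5.])
print(b.grad)  # tensor([2., 3.])
```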
Straight-Through Estimator
Some operations have zero gradients almost everywhere, which blocks gradient flow:
- sign(x) - derivative is 0 everywhere (undefined at 0)
- round(x) - derivative is 0 everywhere (step function)
- floor(x), ceil(x) - derivative is 0
- Quantization functions - derivative is 0
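You can see the problem directly: backpropagating through round() yields an all-zero gradient, so nothing upstream can learn.

```python
import torch

x = torch.tensor([0.3, 1.7, -2.2], requires_grad=True)
y = torch.round(x)   # piecewise constant: dy/dx = 0 almost everywhere
y.sum().backward()
print(x.grad)        # tensor([0., 0., 0.]) - no learning signal
```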
The Straight-Through Estimator (STE) solves this by using a different function during backward than forward:
class SignSTE(torch.autograd.Function):
    """Sign function with Straight-Through Estimator."""

    @staticmethod
    def forward(ctx, x):
        # Forward: actual sign function
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Backward: pretend it's identity (STE)
        return grad_output  # Just pass through!

# More sophisticated: clipped STE
class SignClippedSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Only pass gradient where |x| <= 1
        grad_input = grad_output.clone()
        grad_input[x.abs() > 1] = 0
        return grad_input

STE in Practice
- Binary Neural Networks - weights are +1/-1 but train with full precision
- Quantization-Aware Training - simulate low-bit inference
- Discrete VAEs - sample discrete codes but backprop through
- Gumbel-Softmax - a temperature-based STE variant
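A toy training step showing the effect (SignSTE is redefined here so the sketch is self-contained, and the input and target are made-up values): despite the non-differentiable sign() in the forward path, the full-precision weight still receives a usable gradient.

```python
import torch

class SignSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return torch.sign(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output  # identity in backward: the STE

w = torch.randn(3, requires_grad=True)  # full-precision weights
x = torch.tensor([1.0, -2.0, 0.5])     # made-up input
target = torch.tensor(1.0)

pred = (SignSTE.apply(w) * x).sum()    # forward uses +1/-1 weights
loss = (pred - target) ** 2
loss.backward()
print(w.grad)  # non-zero, so the optimizer can update w
```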
Interactive: STE Demo
Compare different approaches to handling non-differentiable functions. See how naive zero gradients block learning while STE enables gradient flow.
Straight-Through Estimator (STE)
Some functions like sign(x), round(x), or quantize(x) have zero gradient almost everywhere. This kills gradient flow during backpropagation. The Straight-Through Estimator solves this by using a different function during backward pass.
Zero gradient! No learning signal passes through.
class Sign(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return torch.sign(x)
@staticmethod
def backward(ctx, grad_output):
# ❌ True derivative is zero!
        return torch.zeros_like(grad_output)

Weights are binarized to +1/-1 for efficient inference, but use STE for training with full-precision gradients.
Simulate low-precision inference during training while maintaining gradient flow for optimization.
VAEs with discrete codes use STE (or Gumbel-Softmax) to enable backpropagation through sampling.
Key Insight: The Straight-Through Estimator is a "gradient hack" - it uses one function for forward (e.g., sign) and a different function for backward (e.g., identity). This biased gradient estimate is surprisingly effective in practice, enabling training of models that would otherwise have zero gradients.
Higher-Order Gradients
Sometimes you need gradients of gradients (Hessians, second derivatives). For this, your backward() must itself be differentiable:
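A minimal sketch: because this backward uses only differentiable tensor ops, autograd can differentiate through it again when create_graph=True is passed.

```python
import torch

class Cube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Built from differentiable ops, so double backward works
        return grad_output * 3 * x ** 2

x = torch.tensor(2.0, requires_grad=True)
y = Cube.apply(x)
g, = torch.autograd.grad(y, x, create_graph=True)  # dy/dx   = 3x^2 = 12
g2, = torch.autograd.grad(g, x)                    # d2y/dx2 = 6x   = 12
print(g.item(), g2.item())  # 12.0 12.0
```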
If your backward doesn't support higher-order gradients, use the @once_differentiable decorator:
from torch.autograd.function import once_differentiable

class MyFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    @once_differentiable  # Disable double backward
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Can use non-differentiable ops here
        return grad_output * 2 * x

Practical Examples
Example 1: Clipped ReLU
class ClippedReLU(torch.autograd.Function):
    """ReLU with maximum value clipping: min(max(0, x), clip_val)"""

    @staticmethod
    def forward(ctx, x, clip_val=6.0):
        ctx.save_for_backward(x)
        ctx.clip_val = clip_val
        return x.clamp(min=0, max=clip_val)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        clip_val = ctx.clip_val

        # Gradient is 1 where 0 < x < clip_val, else 0
        grad_input = grad_output.clone()
        grad_input[x < 0] = 0
        grad_input[x > clip_val] = 0

        return grad_input, None

# Usage
x = torch.tensor([-1., 2., 5., 8.], requires_grad=True)
y = ClippedReLU.apply(x, 6.0)
print(y)  # tensor([0., 2., 5., 6.])
y.sum().backward()
print(x.grad)  # tensor([0., 1., 1., 0.])

Example 2: Soft Threshold (Proximal Operator)
class SoftThreshold(torch.autograd.Function):
    """Soft thresholding: sign(x) * max(|x| - threshold, 0)
    Used in sparse optimization (LASSO, etc.)
    """

    @staticmethod
    def forward(ctx, x, threshold):
        ctx.save_for_backward(x)
        ctx.threshold = threshold
        return torch.sign(x) * torch.relu(x.abs() - threshold)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        threshold = ctx.threshold

        # Gradient is 1 where |x| > threshold, else 0
        grad_input = grad_output.clone()
        grad_input[x.abs() <= threshold] = 0

        return grad_input, None

Example 3: Gumbel-Softmax (Differentiable Sampling)
class GumbelSoftmax(torch.autograd.Function):
    """Gumbel-Softmax for differentiable discrete sampling."""

    @staticmethod
    def forward(ctx, logits, temperature, hard=True):
        # Sample from Gumbel(0, 1)
        gumbels = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10) + 1e-10)

        # Apply temperature-scaled softmax
        y_soft = torch.softmax((logits + gumbels) / temperature, dim=-1)

        if hard:
            # Straight-through: hard one-hot in forward, soft in backward
            index = y_soft.argmax(dim=-1, keepdim=True)
            y = torch.zeros_like(logits).scatter_(-1, index, 1.0)
        else:
            y = y_soft

        ctx.save_for_backward(y_soft)
        ctx.temperature = temperature
        return y

    @staticmethod
    def backward(ctx, grad_output):
        # Gradient flows through the soft version:
        # the softmax Jacobian-vector product, scaled by 1/temperature
        y_soft, = ctx.saved_tensors
        dot = (grad_output * y_soft).sum(dim=-1, keepdim=True)
        grad_logits = (grad_output - dot) * y_soft / ctx.temperature
        return grad_logits, None, None

Performance Optimization
Custom functions can be optimized for speed:
1. Use ctx.needs_input_grad
@staticmethod
def backward(ctx, grad_output):
    x, y = ctx.saved_tensors
    grad_x, grad_y = None, None

    # Only compute gradients that are needed
    if ctx.needs_input_grad[0]:
        grad_x = grad_output * y

    if ctx.needs_input_grad[1]:
        grad_y = grad_output * x

    return grad_x, grad_y

2. Fuse Operations
class FusedAddReLU(torch.autograd.Function):
    """Fused add + ReLU for efficiency."""

    @staticmethod
    def forward(ctx, x, y):
        z = x + y
        mask = z > 0
        ctx.save_for_backward(mask)
        return z.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        mask, = ctx.saved_tensors
        grad = grad_output * mask.float()
        return grad, grad  # Same gradient for both inputs

3. CUDA Kernels
For maximum performance, implement custom CUDA kernels:
// In custom_ops.cpp
#include <torch/extension.h>

torch::Tensor my_custom_forward_cuda(torch::Tensor x);
torch::Tensor my_custom_backward_cuda(torch::Tensor grad, torch::Tensor x);

# In Python
class MyCUDAFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return my_custom_extension.forward(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return my_custom_extension.backward(grad_output, x)

Debugging Custom Functions
1. Gradient Check
Use numerical gradient checking to verify your backward() is correct:
from torch.autograd import gradcheck

# Create input with double precision for accuracy
x = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)

# Check gradients
test_passed = gradcheck(MyFunction.apply, (x,), eps=1e-6, atol=1e-4)
print(f"Gradient check: {'PASSED' if test_passed else 'FAILED'}")

2. Common Errors and Fixes
| Error | Cause | Fix |
|---|---|---|
| RuntimeError: One of the differentiated Tensors... | In-place operation on saved tensor | Clone tensor before in-place ops |
| RuntimeError: Trying to backward through graph second time | Graph freed after first backward | Use retain_graph=True |
| Wrong number of gradients returned | backward() returns wrong count | Match count to forward() inputs |
| Gradient doesn't match numerical gradient | Bug in local gradient formula | Check math, use gradcheck |
# Tip 1: Print shapes during development
def backward(ctx, grad_output):
    print(f"grad_output shape: {grad_output.shape}")
    x, = ctx.saved_tensors
    print(f"saved x shape: {x.shape}")
    ...

# Tip 2: Check for NaN/Inf
def backward(ctx, grad_output):
    if torch.isnan(grad_output).any():
        print("WARNING: NaN in grad_output!")
    ...

# Tip 3: Test with simple inputs first
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = MyFunction.apply(x)
y.sum().backward()
print(x.grad)  # Should match manual calculation

Model Export and Deployment
Once you've built and trained your model, you need to deploy it efficiently. PyTorch provides several tools for model export and optimization.
TorchScript: JIT Compilation
TorchScript converts your PyTorch model to a serialized, optimized format that can run without Python.
import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 5)

    def forward(self, x):
        return self.linear(x).relu()

model = MyModel()
model.eval()

# Method 1: Tracing (records operations with example input)
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)

# Method 2: Scripting (parses Python code directly)
scripted_model = torch.jit.script(model)

# Save for deployment (no Python needed to load)
traced_model.save("model_traced.pt")
scripted_model.save("model_scripted.pt")

# Load in C++ or Python
loaded = torch.jit.load("model_traced.pt")
output = loaded(torch.randn(1, 10))

| Method | Pros | Cons | Use When |
|---|---|---|---|
| trace() | Simple, handles any Python | Misses control flow | Static models, no if/for on tensors |
| script() | Captures control flow | Limited Python support | Dynamic shapes, conditionals |
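A small demonstration of the trace pitfall, using a hypothetical Gate module with data-dependent control flow: tracing bakes in whichever branch the example input took, while scripting keeps both branches.

```python
import torch

class Gate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:      # data-dependent branch
            return x * 2
        return x * -1

m = Gate()
pos, neg = torch.ones(3), -torch.ones(3)

traced = torch.jit.trace(m, pos)   # records only the x * 2 branch
scripted = torch.jit.script(m)     # compiles both branches

print(traced(neg))    # wrong branch: tensor([-2., -2., -2.])
print(scripted(neg))  # correct:     tensor([1., 1., 1.])
```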
ONNX Export
ONNX (Open Neural Network Exchange) is an open format for representing ML models, enabling deployment across different frameworks and hardware.
import torch
import torch.onnx

model = MyModel()
model.eval()
dummy_input = torch.randn(1, 10)

# Export to ONNX
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={
        "input": {0: "batch_size"},  # Variable batch size
        "output": {0: "batch_size"},
    },
    opset_version=17,  # Use recent opset for better op support
)

# Verify the exported model
import onnx
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)

# Run with ONNX Runtime (pip install onnxruntime)
import onnxruntime as ort
session = ort.InferenceSession("model.onnx")
outputs = session.run(None, {"input": dummy_input.numpy()})

Quantization
Quantization reduces model size and increases inference speed by using lower-precision arithmetic (int8 instead of float32).
import torch
from torch.quantization import quantize_dynamic, prepare, prepare_qat, convert

model = MyModel()
model.eval()

# Dynamic Quantization (easiest, for CPU)
# Quantizes weights statically, activations dynamically
quantized_model = quantize_dynamic(
    model,
    {torch.nn.Linear},  # Layers to quantize
    dtype=torch.qint8
)

# Static Quantization (better performance, requires calibration)
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_prepared = prepare(model)

# Calibrate with representative data (calibration_loader: your DataLoader)
with torch.no_grad():
    for data in calibration_loader:
        model_prepared(data)

model_quantized = convert(model_prepared)

# Quantization-Aware Training (best accuracy)
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_qat = prepare_qat(model.train())
# Train as usual, then convert
model_quantized = convert(model_qat.eval())

# Compare sizes (get_model_size: your own helper, e.g. serialize the state_dict)
print(f"Original: {get_model_size(model):.2f} MB")
print(f"Quantized: {get_model_size(quantized_model):.2f} MB")

| Method | Accuracy | Speed | Effort |
|---|---|---|---|
| Dynamic | Good | 2-4x faster | One line of code |
| Static | Better | 2-4x faster | Needs calibration data |
| QAT | Best | 2-4x faster | Requires retraining |
Deployment Stack
Knowledge Check
Test your understanding of custom autograd functions with this comprehensive quiz.
Custom Autograd Knowledge Check
Question 1 of 10: What are the two required static methods in a custom autograd Function?
Summary
In this section, we learned how to extend PyTorch's autograd system with custom operations:
| Concept | Key Points |
|---|---|
| torch.autograd.Function | Base class for custom differentiable operations |
| forward(ctx, ...) | Compute output, save tensors for backward |
| backward(ctx, grad_output) | Compute grad_input = grad_output × local_gradient |
| ctx.save_for_backward() | Store tensors needed in backward pass |
| .apply() | How to call custom functions (not direct call) |
| Straight-Through Estimator | Enable gradients for non-differentiable ops |
| Higher-order gradients | Use differentiable ops in backward() |
| gradcheck() | Verify gradients numerically |
Key Takeaways
- Custom functions teach PyTorch new operations by defining forward (computation) and backward (gradient) passes.
- The chain rule is central: multiply upstream gradient by local gradient to get downstream gradient.
- Save efficiently: only save what you need for backward, and consider saving outputs vs inputs based on the gradient formula.
- STE enables impossible operations: use different functions for forward and backward to handle non-differentiable ops.
- Always verify with gradcheck: numerical gradient checking catches bugs in your gradient formulas.
Exercises
Conceptual Questions
- Explain why ctx.save_for_backward() only accepts tensors. How do you save non-tensor data?
- Why does backward() receive grad_output instead of computing the gradient from scratch?
- Describe a scenario where the Straight-Through Estimator produces biased gradients. Why might this still lead to successful training?
Coding Exercises
- GELU Activation: Implement the GELU activation function as a custom autograd function. GELU is defined as GELU(x) = x · Φ(x), where Φ is the standard normal CDF. Verify with gradcheck and compare to torch.nn.functional.gelu.
- Swish/SiLU Activation: Implement Swish: swish(x) = x · sigmoid(x). Implement it efficiently by saving the output in forward.
- Batch Normalization: Implement a simplified batch normalization as a custom function (without learnable parameters): y = (x − mean) / sqrt(var + eps). Compute gradients with respect to the input x.
- Differentiable Sorting: Implement a function that computes soft ranks of elements. Use a temperature parameter to control softness.
Challenge Exercise
Implement a Custom Attention Mechanism: Create a custom autograd function for scaled dot-product attention: Attention(Q, K, V) = softmax(Q Kᵀ / √d_k) V.
Requirements:
- Implement efficient forward pass that saves only necessary tensors
- Implement backward pass for all three inputs (Q, K, V)
- Support optional attention mask
- Verify gradients match PyTorch's built-in attention
class ScaledDotProductAttention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, Q, K, V, scale, mask=None):
        # Q: (batch, heads, seq_len, d_k)
        # K: (batch, heads, seq_len, d_k)
        # V: (batch, heads, seq_len, d_v)

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) * scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(attn_weights, V)

        # Save what you need for backward
        ctx.save_for_backward(Q, K, V, attn_weights)
        ctx.scale = scale

        return output

    @staticmethod
    def backward(ctx, grad_output):
        Q, K, V, attn_weights = ctx.saved_tensors
        scale = ctx.scale

        # TODO: Implement gradients for Q, K, V
        # Hint: Start with grad_V (easiest), then grad_attn_weights,
        # then propagate through softmax, then to Q and K

        grad_Q = ...
        grad_K = ...
        grad_V = ...

        return grad_Q, grad_K, grad_V, None, None

Implementation Hint
Congratulations on completing Chapter 4: PyTorch Fundamentals! You now have a deep understanding of tensors, operations, GPU computing, autograd, and custom functions. In the next chapter, we'll use these foundations to build neural network layers and explore the nn.Module API.