Learning Objectives
By the end of this section, you will be able to:
- Master broadcasting rules that automatically align tensor shapes for element-wise operations
- Understand matrix multiplication variants including mm, bmm, matmul, and when to use each
- Apply reduction operations like sum, mean, max along specific dimensions
- Use Einstein summation notation (einsum) for expressing complex tensor operations concisely
- Leverage linear algebra operations for solving systems, computing decompositions, and more
- Optimize performance through in-place operations and understanding memory patterns
Why This Matters: Neural networks are essentially sequences of tensor operations. A single forward pass through a transformer involves dozens of matrix multiplications, element-wise activations, and reductions. Mastering these operations is essential for understanding, implementing, and debugging deep learning models.
The Big Picture: Why Tensor Operations Matter
Every computation in deep learning can be expressed as a sequence of tensor operations. Consider what happens in a single attention layer:
- Matrix multiplication: `Q = X @ W_Q`, `K = X @ W_K`, `V = X @ W_V`
- Batched matrix multiplication: `scores = Q @ K.transpose(-2, -1)`
- Element-wise division: `scores / sqrt(d_k)`
- Reduction (softmax numerator): `exp(scores - scores.max(dim=-1, keepdim=True).values)`
- Reduction (softmax denominator): sum of the exponentials along the key dimension
- Broadcasting division: normalize across keys
- Batched matrix multiplication: `output = weights @ V`
Understanding each operation deeply—not just what it computes, but how shapes transform and how memory is accessed—separates practitioners who debug efficiently from those who struggle with shape mismatches and mysterious errors.
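As a concrete illustration, the whole attention pipeline above can be written with the operations this section covers. This is a sketch with made-up shapes (batch of 2, sequence length 5, head dimension 8), not a full attention implementation:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes for a single attention head
batch, seq_len, d_k = 2, 5, 8
Q = torch.randn(batch, seq_len, d_k)
K = torch.randn(batch, seq_len, d_k)
V = torch.randn(batch, seq_len, d_k)

# Batched matrix multiplication: (2, 5, 8) @ (2, 8, 5) -> (2, 5, 5)
scores = Q @ K.transpose(-2, -1)

# Element-wise division by sqrt(d_k)
scores = scores / (d_k ** 0.5)

# Reductions inside softmax: exp, sum, then a broadcast divide
weights = F.softmax(scores, dim=-1)  # each row sums to 1

# Batched matrix multiplication: (2, 5, 5) @ (2, 5, 8) -> (2, 5, 8)
output = weights @ V
```

Every step is a shape transformation; tracing those shapes is exactly the skill this section builds.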
Element-wise Operations
Element-wise operations apply a function independently to each element of one or more tensors. When operating on multiple tensors, they must have compatible shapes (same shape, or broadcastable).
Arithmetic Operations
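The basic arithmetic operators all work element-wise, and each has an equivalent `torch` function form. A minimal sketch (the example values are arbitrary):

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

# Operator and function forms are equivalent
add = a + b         # torch.add(a, b)  -> [5, 7, 9]
sub = a - b         # torch.sub(a, b)
mul = a * b         # torch.mul(a, b)  -> element-wise, NOT matrix multiply
div = a / b         # torch.div(a, b)
power = a ** 2      # torch.pow(a, 2)  -> [1, 4, 9]
mod = b % a         # torch.remainder(b, a) -> [0, 1, 0]
floor_div = b // a  # element-wise floor division -> [4, 2, 2]
```

Note that `*` is element-wise multiplication; matrix multiplication uses `@`, covered later in this section.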
Mathematical Functions
```python
import torch

x = torch.tensor([0.0, 1.0, 2.0, 3.0])
y = torch.tensor([1.0, -1.0, 1.0, -1.0])

# Exponentials and logarithms
exp_x = torch.exp(x)          # [1, e, e², e³]
log_x = torch.log(x + 1)      # Natural log (add 1 to avoid log(0))
log2_x = torch.log2(x + 1)    # Base-2 log
log10_x = torch.log10(x + 1)

# Trigonometric
sin_x = torch.sin(x)
cos_x = torch.cos(x)
tan_x = torch.tan(x)

# Inverse trigonometric
asin_x = torch.asin(torch.clamp(x / 4, -1, 1))
atan2_y_x = torch.atan2(y, x)  # 2-argument arctangent

# Hyperbolic
tanh_x = torch.tanh(x)  # Common activation function
sinh_x = torch.sinh(x)

# Powers and roots
sqrt_x = torch.sqrt(x)
rsqrt_x = torch.rsqrt(x + 1e-6)  # 1/sqrt(x), add epsilon for stability

# Absolute value and sign
abs_x = torch.abs(x)
sign_x = torch.sign(x)

# Clipping and rounding
clipped = torch.clamp(x, min=0.5, max=2.5)
rounded = torch.round(x)
floor_x = torch.floor(x)
ceil_x = torch.ceil(x)
```

Numerical Stability
torch.log() and torch.sqrt() can produce NaN or Inf for invalid inputs. Always add a small epsilon (e.g., 1e-6) or clamp inputs into the valid domain. PyTorch also provides torch.log1p(x) for a numerically stable log(1 + x) when x is small.

Activation Functions
While torch.nn provides activation modules, the functional versions are pure tensor operations:
```python
import torch
import torch.nn.functional as F

x = torch.randn(4)

# Common activations
relu = F.relu(x)                    # max(0, x)
leaky_relu = F.leaky_relu(x, 0.01)  # max(0.01x, x)
elu = F.elu(x)                      # x if x>0 else alpha*(exp(x)-1)
gelu = F.gelu(x)                    # x * Φ(x), used in transformers
silu = F.silu(x)                    # x * sigmoid(x), aka Swish

# Sigmoid family
sigmoid = torch.sigmoid(x)   # 1 / (1 + exp(-x))
tanh = torch.tanh(x)         # (exp(x) - exp(-x)) / (exp(x) + exp(-x))
softplus = F.softplus(x)     # log(1 + exp(x)), smooth ReLU

# Normalization (typically applied to vectors/matrices)
softmax = F.softmax(x, dim=0)          # exp(x) / sum(exp(x))
log_softmax = F.log_softmax(x, dim=0)  # More stable for loss computation
```

Broadcasting: The Magic of Shape Alignment
Broadcasting is a powerful mechanism that allows PyTorch to perform operations on tensors with different shapes by automatically expanding the smaller tensor to match the larger one—without actually copying data.
The Broadcasting Rules
When operating on two tensors, PyTorch compares their shapes element-wise, starting from the trailing (rightmost) dimensions:
- If one tensor has fewer dimensions, prepend 1s to its shape until the ranks match
- Right-align the shapes and compare dimensions from right to left
- Two dimensions are compatible if they are equal, or one of them is 1
- The result shape is the element-wise maximum of each dimension
Broadcasting is conceptual—no data is copied. PyTorch uses stride tricks to make the smaller tensor appear larger, reading the same values multiple times.
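To see that no copy actually happens, `expand()` exposes the same stride trick that broadcasting uses internally (a small sketch):

```python
import torch

row = torch.tensor([[1.0, 2.0, 3.0]])  # Shape: (1, 3)
expanded = row.expand(4, 3)            # Shape: (4, 3), no data copied

# Stride 0 along dim 0 means every "row" reads the same memory
print(expanded.stride())                      # (0, 1)
print(row.data_ptr() == expanded.data_ptr())  # True: shared storage
```

A stride of 0 tells PyTorch to reuse the same elements for every index along that dimension, which is exactly how a `(1, 3)` tensor behaves like a `(4, 3)` one.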
```python
import torch

# PyTorch broadcasts automatically
a = torch.randn(4, 1)
b = torch.randn(1, 3)
c = a + b  # Shape: (4, 3)
```
Matrix Multiplication and Variants
Matrix multiplication is the workhorse of deep learning. PyTorch provides several functions for different use cases, each with specific shape requirements.
The Core Operations
| Function | Input Shapes | Output Shape | Use Case |
|---|---|---|---|
| torch.mm(A, B) | (m, k), (k, n) | (m, n) | 2D matrix multiply |
| torch.mv(A, v) | (m, n), (n,) | (m,) | Matrix-vector multiply |
| torch.dot(a, b) | (n,), (n,) | scalar | Vector dot product |
| torch.bmm(A, B) | (b, m, k), (b, k, n) | (b, m, n) | Batched matmul |
| torch.matmul(A, B) | Various | Broadcasting | General matmul with broadcasting |
| A @ B | Various | Broadcasting | Operator syntax for matmul |
Understanding @ (matmul)
The @ operator and torch.matmul() are the most flexible, supporting:
- 1D × 1D: dot product (returns scalar)
- 2D × 2D: standard matrix multiplication
- 1D × 2D: vector treated as row, multiply, then squeeze
- 2D × 1D: multiply, result is vector
- ND × ND: batched matmul with broadcasting on batch dimensions
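The differences are easiest to see side by side. A quick sketch using arbitrary example shapes:

```python
import torch

A2 = torch.randn(3, 4)
B2 = torch.randn(4, 5)
v = torch.randn(4)

# Strict 2D: errors if inputs are not matrices
out_mm = torch.mm(A2, B2)  # (3, 5)
out_mv = torch.mv(A2, v)   # (3,)

# Batched: both inputs must be 3D with matching batch size
A3 = torch.randn(10, 3, 4)
B3 = torch.randn(10, 4, 5)
out_bmm = torch.bmm(A3, B3)  # (10, 3, 5)

# matmul / @ broadcasts batch dimensions
B_shared = torch.randn(4, 5)
out_broadcast = A3 @ B_shared  # (10, 3, 5): B_shared reused for every batch
```

The last line is something `bmm` cannot do: `matmul` broadcasts the missing batch dimension of `B_shared` across all 10 matrices.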
mm vs matmul
Use torch.mm() when you specifically need 2D × 2D and want to catch dimension errors early. Use @ or torch.matmul() for flexibility and broadcasting support.
Reduction Operations
Reduction operations aggregate tensor values along one or more dimensions, producing a smaller tensor (or scalar). They are fundamental for computing statistics, losses, and attention weights.
Common Reductions
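A sketch of the most frequently used reductions, using a small hand-picked matrix so the results are easy to verify:

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])  # Shape: (2, 3)

total = x.sum()           # Scalar: 21.0
col_sums = x.sum(dim=0)   # (3,): [5, 7, 9] -- dim 0 is collapsed
row_means = x.mean(dim=1) # (2,): [2.0, 5.0]

# max returns both values and indices
max_vals, max_idx = x.max(dim=1)  # values [3, 6], indices [2, 2]

# keepdim preserves the reduced dim as size 1, which enables broadcasting
row_max = x.max(dim=1, keepdim=True).values  # Shape: (2, 1)
normalized = x - row_max                     # (2, 3) - (2, 1) broadcasts
```

The `keepdim=True` pattern is worth memorizing: it lets the reduced result broadcast straight back against the original tensor.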
Multiple Dimension Reduction
```python
import torch

x = torch.randn(2, 3, 4, 5)  # Shape: (2, 3, 4, 5)

# Reduce multiple dimensions at once
spatial_mean = x.mean(dim=(2, 3))  # Shape: (2, 3)
batch_channel = x.sum(dim=(0, 1))  # Shape: (4, 5)

# All dimensions except one
# Method 1: List all dims
global_except_batch = x.mean(dim=(1, 2, 3))  # Shape: (2,)

# Method 2: Use flatten and mean
global_except_batch = x.flatten(1).mean(dim=1)  # Shape: (2,)

# Reduce to just batch dim (useful for per-sample losses)
per_sample = x.sum(dim=list(range(1, x.ndim)))  # Shape: (2,)
```

Softmax Reduction Pattern
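Softmax is itself a composition of reductions: subtract the max for numerical stability, exponentiate, sum along the target dimension, and divide with broadcasting. A manual sketch (the (4, 10) input shape is arbitrary):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 10)

# Manual softmax: max reduction, element-wise exp, sum reduction, broadcast divide
x_max = x.max(dim=-1, keepdim=True).values  # (4, 1): stabilizes exp
exp_x = torch.exp(x - x_max)                # Broadcasting subtract
softmax_manual = exp_x / exp_x.sum(dim=-1, keepdim=True)

# Matches the built-in
assert torch.allclose(softmax_manual, F.softmax(x, dim=-1))
```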
Subtracting the max before exponentiating prevents overflow without changing the result. F.softmax(x, dim=-1) handles this efficiently in a single call.
```python
import torch

# Shape: (3, 4) -> (3,)
x = torch.tensor([[5, 4, 8, 7], [9, 1, 3, 3], [7, 2, 5, 9]])
result = x.sum(dim=1)
# result = [24, 16, 23]
```
Comparison and Logical Operations
Comparison operations return boolean tensors, while logical operations combine them. These are essential for masking, filtering, and conditional computation.
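A short sketch of the main comparison and logical primitives, with arbitrary example values:

```python
import torch

x = torch.tensor([1.0, -2.0, 3.0, 0.0])
y = torch.tensor([1.0, 2.0, -3.0, 0.0])

# Comparisons return boolean tensors
eq = x == y  # [True, False, False, True]
gt = x > 0   # [True, False, True, False]

# Logical operations combine boolean tensors
both = gt & (y > 0)    # element-wise AND, same as torch.logical_and
either = gt | (y > 0)  # element-wise OR

# where: element-wise select between two tensors
relu_like = torch.where(x > 0, x, torch.zeros_like(x))  # [1, 0, 3, 0]

# Reductions on boolean tensors
num_positive = gt.sum()  # 2 (True counts as 1)
any_match = eq.any()     # True
```

Boolean tensors reduce like numeric ones, which makes counting matches (`mask.sum()`) and existence checks (`mask.any()`) one-liners.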
Masking in Practice
```python
import torch

# Attention masking
seq_len = 10
# Create causal mask (lower triangular)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()

# Apply mask to attention scores
scores = torch.randn(seq_len, seq_len)
masked_scores = scores.masked_fill(~causal_mask, float('-inf'))
# Positions where mask is False become -inf
# After softmax, these become 0

# Padding mask (variable-length sequences)
lengths = torch.tensor([7, 5, 10])  # Actual lengths in batch
max_len = 10
# Create mask: True for valid positions
padding_mask = torch.arange(max_len).expand(3, -1) < lengths.unsqueeze(1)
# Shape: (3, 10) with True for valid tokens

# Using masks for selective operations
x = torch.randn(3, 10, 64)  # (batch, seq, features)
# Zero out padded positions
x = x * padding_mask.unsqueeze(-1)  # Broadcasting (3, 10, 1)
```

Einstein Summation: The Universal Notation
Einstein summation notation (einsum) is a powerful and concise way to express many tensor operations. It uses index labels to describe how dimensions combine and contract.
The Einsum Rules
An einsum expression like "ij,jk->ik" means:
- Each letter (i, j, k) represents a dimension index
- Repeated indices across inputs are summed (contracted)
- Output indices after -> define the result shape
- Indices not in the output are implicitly summed over
Common Patterns
| Einsum | Operation | Equivalent |
|---|---|---|
| 'ij->ji' | Transpose | A.T |
| 'ii->i' | Diagonal | torch.diag(A) |
| 'ii->' | Trace | torch.trace(A) |
| 'ij->' | Sum all | A.sum() |
| 'ij->i' | Sum rows | A.sum(dim=1) |
| 'i,i->' | Dot product | torch.dot(a, b) |
| 'i,j->ij' | Outer product | torch.outer(a, b) |
| 'ij,jk->ik' | Matrix multiply | A @ B |
| 'bij,bjk->bik' | Batch matmul | torch.bmm(A, B) |
| 'bhqd,bhkd->bhqk' | Attention scores | Q @ K.transpose(-2,-1) |
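These equivalences are easy to verify directly (a sketch using random inputs):

```python
import torch

A = torch.randn(3, 4)
B = torch.randn(4, 5)
a = torch.randn(6)
b = torch.randn(6)

# Each einsum matches its conventional equivalent
assert torch.allclose(torch.einsum('ij->ji', A), A.T)
assert torch.allclose(torch.einsum('ij,jk->ik', A, B), A @ B, atol=1e-5)
assert torch.allclose(torch.einsum('i,i->', a, b), torch.dot(a, b), atol=1e-5)
assert torch.allclose(torch.einsum('i,j->ij', a, b), torch.outer(a, b))

# Batched matmul: 'b' appears in both inputs and the output, so it is preserved
X = torch.randn(8, 3, 4)
Y = torch.randn(8, 4, 5)
assert torch.allclose(torch.einsum('bij,bjk->bik', X, Y), torch.bmm(X, Y), atol=1e-5)
```

Reading the subscripts tells you everything: `j` appears in both inputs but not the output of `'ij,jk->ik'`, so it is the contracted (summed) dimension.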
When to Use Einsum
For a plain matrix multiply, prefer @—it's equally efficient and more readable. Einsum shines for attention mechanisms, tensor decompositions, and custom contractions.
Linear Algebra Operations
PyTorch includes a rich set of linear algebra operations in torch.linalg, essential for solving systems, computing decompositions, and understanding data structure.
Matrix Properties and Norms
```python
import torch

A = torch.randn(3, 3)

# Matrix properties
det = torch.linalg.det(A)           # Determinant
rank = torch.linalg.matrix_rank(A)  # Rank
trace = torch.trace(A)              # Sum of diagonal

# Vector norms
vec = torch.randn(5)
l1_norm = torch.linalg.norm(vec, ord=1)               # Sum of absolute values
l2_norm = torch.linalg.norm(vec, ord=2)               # Euclidean norm
linf_norm = torch.linalg.norm(vec, ord=float('inf'))  # Max absolute value

# Matrix norms
fro_norm = torch.linalg.norm(A, ord='fro')  # Frobenius (L2 on flattened)
nuc_norm = torch.linalg.norm(A, ord='nuc')  # Nuclear (sum of singular values)
spectral = torch.linalg.norm(A, ord=2)      # Spectral (largest singular value)

# Condition number (sensitivity to perturbation)
cond = torch.linalg.cond(A)
```

Decompositions
```python
import torch

A = torch.randn(4, 3)
B = torch.randn(3, 3)

# SVD: A = U @ diag(S) @ Vh
U, S, Vh = torch.linalg.svd(A)
# U: (4, 4), S: (3,), Vh: (3, 3)

# Eigendecomposition (square matrices)
eigenvalues, eigenvectors = torch.linalg.eig(B)
# For symmetric matrices, use eigh (real eigenvalues)
sym = B @ B.T
eigenvalues, eigenvectors = torch.linalg.eigh(sym)

# QR decomposition: A = Q @ R
Q, R = torch.linalg.qr(A)
# Q: orthogonal, R: upper triangular

# Cholesky decomposition (positive definite)
pos_def = sym + 3 * torch.eye(3)  # Make positive definite
L = torch.linalg.cholesky(pos_def)
# L @ L.T = pos_def

# LU decomposition
P, L, U = torch.linalg.lu(B)
# P @ L @ U = B
```

Solving Linear Systems
```python
import torch

# Solve Ax = b
A = torch.randn(3, 3)
b = torch.randn(3, 1)

# Direct solve (uses LU decomposition internally)
x = torch.linalg.solve(A, b)  # x such that Ax = b

# Least squares (overdetermined system)
A = torch.randn(5, 3)  # More equations than unknowns
b = torch.randn(5, 1)
x, residuals, rank, s = torch.linalg.lstsq(A, b)

# Matrix inverse (avoid when possible - solve is more stable)
A_inv = torch.linalg.inv(A[:3, :3])  # Square matrix required

# Pseudoinverse (works for any shape)
A_pinv = torch.linalg.pinv(A)

# Batched solve
A_batch = torch.randn(10, 3, 3)  # 10 systems
b_batch = torch.randn(10, 3, 1)
x_batch = torch.linalg.solve(A_batch, b_batch)  # Solves all 10
```

Numerical Stability
Explicitly inverting a matrix (torch.linalg.inv) is rarely the best approach. Use torch.linalg.solve(A, b) instead of inv(A) @ b—it's more numerically stable and often faster.

In-place Operations
In-place operations modify a tensor directly without creating a new one. They're identified by a trailing underscore (e.g., add_()).
```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])

# In-place operations (note the underscore)
x.add_(1)   # x is now [2, 3, 4]
x.mul_(2)   # x is now [4, 6, 8]
x.fill_(0)  # x is now [0, 0, 0]
x.zero_()   # Same as fill_(0)

# Clamp in-place
x = torch.randn(5)
x.clamp_(min=0)  # ReLU in-place

# In-place random fill
x.uniform_(0, 1)  # Fill with uniform random
x.normal_(0, 1)   # Fill with normal random

# Copy in-place (shapes must match or broadcast)
src = torch.tensor([1., 2., 3.])
x = torch.zeros(3)
x.copy_(src)  # Copy src into x

# Indexing with in-place modification
y = torch.zeros(5)
y[2:4] = torch.tensor([7., 8.])  # Modifies y in-place
```

In-place and Autograd
```python
import torch

# This raises an error immediately — in-place ops are forbidden
# on leaf tensors that require grad:
x = torch.randn(3, requires_grad=True)
x.add_(1)  # RuntimeError!

# Safe alternative:
x = torch.randn(3, requires_grad=True)
y = x + 1  # Creates a new tensor; the autograd graph is preserved

# In-place is OK for non-leaf tensors in some cases,
# but it's safest to avoid it during training
```

Performance Considerations
Understanding how tensor operations use memory and compute resources helps you write efficient deep learning code.
Memory Layout and Contiguity
```python
import torch

x = torch.randn(1000, 1000)

# Transpose creates a non-contiguous view
y = x.T
print(y.is_contiguous())  # False

# Some operations require contiguous memory;
# contiguous() copies data into a contiguous layout
y_contig = y.contiguous()

# reshape() vs view()
# view() requires contiguous input; reshape() handles either
z = x.T.reshape(-1)  # Works (copies internally)
# z = x.T.view(-1)   # Error: not contiguous

# Tip: Check contiguity before performance-critical loops
if not x.is_contiguous():
    x = x.contiguous()
```

Fused Operations
Some operations are optimized to execute together, avoiding intermediate memory allocation:
```python
import torch
import torch.nn.functional as F

inp = torch.randn(8, 16)
weight = torch.randn(16, 32)
bias = torch.randn(32)

# Fused multiply-add (more efficient)
result = torch.addmm(bias, inp, weight)  # bias + inp @ weight

# Fused operations in F.linear (expects weight as (out_features, in_features))
output = F.linear(inp, weight.T, bias)  # Optimized internally

# Fused attention (PyTorch 2.0+)
from torch.nn.functional import scaled_dot_product_attention
Q = K = V = torch.randn(2, 4, 8, 64)  # (batch, heads, seq, head_dim)
attn = scaled_dot_product_attention(Q, K, V)  # Fused, memory efficient

# torch.compile for automatic fusion (PyTorch 2.0+)
@torch.compile
def my_function(x, y):
    return (x * y).sum(dim=-1).mean()
# The compiler fuses these operations automatically
```

Memory-Efficient Patterns
```python
import torch

# Bad: Creates intermediate tensors
def inefficient(x, y, z):
    temp1 = x * y      # Allocates temp1
    temp2 = temp1 + z  # Allocates temp2
    return temp2.sum()

# Better: Use in-place where safe (but not with autograd)
def efficient_no_grad(x, y, z):
    with torch.no_grad():
        x.mul_(y)
        x.add_(z)
        return x.sum()

# Best: Let torch.compile optimize
@torch.compile
def optimized(x, y, z):
    return ((x * y) + z).sum()  # Compiler fuses

# Gradient checkpointing for large models
from torch.utils.checkpoint import checkpoint
# Recomputes activations during backward instead of storing them
# (layer and input defined elsewhere)
output = checkpoint(layer, input)  # Trades compute for memory
```

torch.compile: The PyTorch 2.0 Compiler
torch.compile is PyTorch's JIT compiler that can dramatically speed up your models by automatically optimizing and fusing operations.
```python
import torch

# Basic usage - wrap your model or function
model = MyModel()  # Any nn.Module
compiled_model = torch.compile(model)

# Or use as a decorator
@torch.compile
def my_function(x, y):
    return (x.matmul(y) + x).relu()

# Compilation modes
model_default = torch.compile(model)                       # Balanced
model_fast = torch.compile(model, mode="reduce-overhead")  # Minimize latency
model_max = torch.compile(model, mode="max-autotune")      # Best perf, slow compile

# Disable for debugging
model_debug = torch.compile(model, disable=True)

# Full options for production
compiled = torch.compile(
    model,
    mode="max-autotune",
    fullgraph=True,  # Error if the whole graph can't be compiled
    dynamic=False,   # Static shapes (faster)
)
```

| Mode | Compile Time | Runtime | Use Case |
|---|---|---|---|
| default | Medium | Good | General use, first try this |
| reduce-overhead | Fast | Good | Many small operations, inference |
| max-autotune | Slow | Best | Production training, benchmarking |
When torch.compile May Not Help
- Dynamic shapes: Variable batch sizes or sequence lengths trigger recompilation
- Python control flow: if/else based on tensor values breaks the graph
- Small models: Compilation overhead may exceed speedup
- Custom C++ ops: May not be fusible by the compiler
Profiling
Use torch.profiler to identify bottlenecks. GPU operations are asynchronous, so wall-clock time can be misleading. The profiler shows actual GPU utilization and memory usage.
Summary
Tensor operations are the building blocks of all deep learning computation. Here's what we covered:
| Category | Key Operations | When to Use |
|---|---|---|
| Element-wise | +, *, exp, relu, tanh | Activations, normalization, feature transforms |
| Broadcasting | Automatic shape alignment | Per-channel ops, batched operations, masking |
| Matrix Multiply | @, mm, bmm, matmul | Linear layers, attention, projections |
| Reductions | sum, mean, max, softmax | Loss computation, pooling, normalization |
| Comparison | ==, <, >, where, masked_fill | Masking, conditional logic, filtering |
| Einsum | torch.einsum('ij,jk->ik', A, B) | Complex contractions, custom operations |
| Linear Algebra | solve, svd, eig, norm | Optimization, analysis, preprocessing |
Key Takeaways
- Broadcasting aligns shapes by right-padding with 1s and expanding. Understand it deeply to avoid bugs.
- Use @ for matrix multiplication—it handles broadcasting and is the most Pythonic syntax.
- Einsum is powerful for expressing complex tensor operations concisely, especially in attention mechanisms.
- Pay attention to dimension ordering in reductions—dim=0 is usually batch, dim=-1 is features.
- Avoid in-place operations during training to prevent autograd issues.
Exercises
Conceptual Questions
- Explain why `torch.mm(A, B)` fails but `A @ B` works when A has shape (1, 3, 4) and B has shape (4, 5).
- Why does einsum `'ij,jk->ik'` represent matrix multiplication? What would `'ij,jk->ijk'` compute?
- How would you use broadcasting to subtract the mean from each row of a matrix without explicitly reshaping?
Coding Exercises
- Broadcasting Practice: Given a batch of images with shape (32, 3, 224, 224), write code to normalize each channel independently using per-channel mean and std vectors.
- Attention Without Loops: Implement scaled dot-product attention using only tensor operations (no loops). Input shapes: Q, K, V all (batch, seq_len, d_model).
- Einsum Mastery: Rewrite these operations using einsum:
- Batch matrix trace: sum of diagonals for each matrix in a batch (B, N, N) → (B,)
- Batched outer product: (B, M) and (B, N) → (B, M, N)
- Multi-head attention scores: Q(B,H,Q,D) × K(B,H,K,D)^T → (B,H,Q,K)
- Memory Investigation: Create a 3D tensor, perform various reshape and transpose operations, and determine which create views vs copies by checking `data_ptr()` and `is_contiguous()`.
Challenge Exercise
Implement Layer Normalization from Scratch: Using only tensor operations covered in this section, implement layer normalization that normalizes over the last D dimensions with learnable affine parameters.
```python
def layer_norm(x, normalized_shape, weight, bias, eps=1e-5):
    """
    Args:
        x: Input tensor of any shape
        normalized_shape: tuple of ints (last D dimensions to normalize)
        weight: Scale parameter (same shape as normalized_shape)
        bias: Shift parameter (same shape as normalized_shape)
        eps: Small constant for numerical stability
    Returns:
        Normalized tensor with same shape as x
    """
    # Your implementation here
    # Hint: Use reduction over last D dims, broadcasting, and element-wise ops
    pass
```

In the next section, we'll dive into advanced indexing, slicing, and reshaping operations that give you precise control over tensor data.