Chapter 4
Section 26 of 178

Tensor Operations Deep Dive

PyTorch Fundamentals

Learning Objectives

By the end of this section, you will be able to:

  1. Master broadcasting rules that automatically align tensor shapes for element-wise operations
  2. Understand matrix multiplication variants including mm, bmm, matmul, and when to use each
  3. Apply reduction operations like sum, mean, max along specific dimensions
  4. Use Einstein summation notation (einsum) for expressing complex tensor operations concisely
  5. Leverage linear algebra operations for solving systems, computing decompositions, and more
  6. Optimize performance through in-place operations and understanding memory patterns
Why This Matters: Neural networks are essentially sequences of tensor operations. A single forward pass through a transformer involves dozens of matrix multiplications, element-wise activations, and reductions. Mastering these operations is essential for understanding, implementing, and debugging deep learning models.

The Big Picture: Why Tensor Operations Matter

Every computation in deep learning can be expressed as a sequence of tensor operations. Consider what happens in a single attention layer:

  1. Matrix multiplication: Q = XW_Q, K = XW_K, V = XW_V
  2. Batched matrix multiplication: scores = QKᵀ
  3. Element-wise division: scaled = scores / √d_k
  4. Element-wise exponential (softmax numerator): exp(scaled)
  5. Reduction (softmax denominator): Σₖ exp(scaled)
  6. Broadcasting division: normalize across keys
  7. Batched matrix multiplication: output = attn · V

Understanding each operation deeply—not just what it computes, but how shapes transform and how memory is accessed—separates practitioners who debug efficiently from those who struggle with shape mismatches and mysterious errors.


Element-wise Operations

Element-wise operations apply a function independently to each element of one or more tensors. When operating on multiple tensors, they must have compatible shapes (same shape, or broadcastable).

Arithmetic Operations

Element-wise Arithmetic Operations

🐍element_wise_arithmetic.py

Addition (+)

Adds corresponding elements. PyTorch's operator overloading makes the syntax intuitive.

Multiplication (*)

Element-wise product, NOT matrix multiplication. For matrices, use @ or torch.mm().

Power (**)

Raises each element to the given power. a ** 2 squares each element.

Floor Division (//)

Integer division, rounding toward negative infinity. This differs from truncation for negative numbers.

Scalar Broadcasting

Scalars automatically broadcast to match the tensor's shape. No explicit expansion needed.
import torch

a = torch.tensor([1.0, 2.0, 3.0, 4.0])
b = torch.tensor([2.0, 2.0, 2.0, 2.0])

# Basic arithmetic
add = a + b           # [3, 4, 5, 6]
sub = a - b           # [-1, 0, 1, 2]
mul = a * b           # [2, 4, 6, 8]
div = a / b           # [0.5, 1.0, 1.5, 2.0]
pow_result = a ** b   # [1, 4, 9, 16]
mod = a % 2           # [1, 0, 1, 0]

# Floor and true division
floor_div = a // b    # [0, 1, 1, 2] (integer part)
true_div = a / b      # [0.5, 1.0, 1.5, 2.0]

# Negation
neg = -a              # [-1, -2, -3, -4]

# Method equivalents
add2 = torch.add(a, b)
mul2 = torch.mul(a, b)

# Scalar operations broadcast automatically
scaled = a * 10       # [10, 20, 30, 40]
shifted = a + 5       # [6, 7, 8, 9]

Mathematical Functions

🐍math_functions.py
import torch

x = torch.tensor([0.0, 1.0, 2.0, 3.0])
y = torch.tensor([1.0, 1.0, 1.0, 1.0])  # second operand for atan2 below

# Exponentials and logarithms
exp_x = torch.exp(x)        # [1, e, e², e³]
log_x = torch.log(x + 1)    # Natural log (add 1 to avoid log(0))
log2_x = torch.log2(x + 1)  # Base-2 log
log10_x = torch.log10(x + 1)

# Trigonometric
sin_x = torch.sin(x)
cos_x = torch.cos(x)
tan_x = torch.tan(x)

# Inverse trigonometric
asin_x = torch.asin(torch.clamp(x / 4, -1, 1))
atan2_y_x = torch.atan2(y, x)  # 2-argument arctangent

# Hyperbolic
tanh_x = torch.tanh(x)      # Common activation function
sinh_x = torch.sinh(x)

# Powers and roots
sqrt_x = torch.sqrt(x)
rsqrt_x = torch.rsqrt(x + 1e-6)  # 1/sqrt(x), add epsilon for stability

# Absolute value and sign
abs_x = torch.abs(x)
sign_x = torch.sign(x)

# Clipping and rounding
clipped = torch.clamp(x, min=0.5, max=2.5)
rounded = torch.round(x)
floor_x = torch.floor(x)
ceil_x = torch.ceil(x)

Numerical Stability

Functions like torch.log() and torch.sqrt() produce NaN or Inf for inputs outside their domains. Add a small epsilon (e.g., 1e-6) or clamp inputs to keep them in a valid range. PyTorch also provides torch.log1p(x) for a numerically stable log(1+x) when x is small.
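The log1p point is easy to demonstrate: in float32, adding a tiny value to 1 rounds away the value entirely, so the naive formulation loses all information while log1p preserves it. A minimal sketch:

```python
import torch

x = torch.tensor([1e-8])            # float32 by default

naive = torch.log(1 + x)            # 1 + 1e-8 rounds to exactly 1.0 in float32
stable = torch.log1p(x)             # computes log(1 + x) without forming 1 + x

print(naive.item())    # 0.0 -- the small value was lost
print(stable.item())   # ~1e-8 -- precision preserved
```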

Activation Functions

While torch.nn provides activation modules, the functional versions are pure tensor operations:

🐍activations.py
import torch
import torch.nn.functional as F

x = torch.randn(4)

# Common activations
relu = F.relu(x)                    # max(0, x)
leaky_relu = F.leaky_relu(x, 0.01)  # max(0.01x, x)
elu = F.elu(x)                      # x if x>0 else alpha*(exp(x)-1)
gelu = F.gelu(x)                    # x * Φ(x), used in transformers
silu = F.silu(x)                    # x * sigmoid(x), aka Swish

# Sigmoid family
sigmoid = torch.sigmoid(x)          # 1 / (1 + exp(-x))
tanh = torch.tanh(x)                # (exp(x) - exp(-x)) / (exp(x) + exp(-x))
softplus = F.softplus(x)            # log(1 + exp(x)), smooth ReLU

# Normalization (typically applied to vectors/matrices)
softmax = F.softmax(x, dim=0)       # exp(x) / sum(exp(x))
log_softmax = F.log_softmax(x, dim=0)  # More stable for loss computation

Broadcasting: The Magic of Shape Alignment

Broadcasting is a powerful mechanism that allows PyTorch to perform operations on tensors with different shapes by automatically expanding the smaller tensor to match the larger one—without actually copying data.

The Broadcasting Rules

When operating on two tensors, PyTorch compares their shapes element-wise, starting from the trailing (rightmost) dimensions:

  1. Right-align the shapes and compare dimensions from right to left
  2. Two dimensions are compatible if they are equal, or one of them is 1
  3. If one tensor has fewer dimensions, prepend 1s to its shape
  4. The result shape is the element-wise maximum of each dimension
Shape A: (2, 3, 4), Shape B: (3, 1) ⟹ Result: (2, 3, 4)

Broadcasting is conceptual—no data is copied. PyTorch uses stride tricks to make the smaller tensor appear larger, reading the same values multiple times.
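The rules can be checked without allocating any tensors. A quick sketch (assuming a PyTorch recent enough to ship torch.broadcast_shapes, i.e., 1.8+):

```python
import torch

# torch.broadcast_shapes applies the rules without touching any data
print(torch.broadcast_shapes((2, 3, 4), (3, 1)))   # torch.Size([2, 3, 4])
print(torch.broadcast_shapes((8, 1, 6), (7, 1)))   # torch.Size([8, 7, 6])

# Incompatible shapes raise immediately -- useful as a cheap sanity check
try:
    torch.broadcast_shapes((3, 4), (2, 4, 5))      # trailing dims 4 vs 5 clash
except RuntimeError as e:
    print("incompatible:", e)
```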

Broadcasting Examples

🐍broadcasting.py

Scalar Broadcasting

A scalar (shape ()) broadcasts to any shape. The single value is used for every element.

Vector Along Axis

Shape (4,) becomes (1, 4), then broadcasts to (3, 4). Each row gets the same bias vector added.

Outer Sum Pattern

(3, 1) + (1, 4) = (3, 4). This creates an outer sum where result[i,j] = a[i] + b[j]. Useful for computing distance matrices.

Per-Channel Normalization

(3, 1, 1) broadcasts across the batch (32) and spatial dims (224, 224). Each channel gets its own mean subtracted.

Attention Masking

Mask shape (1, 1, Q, K) broadcasts to all batches and heads. A common pattern in transformers.

Explicit Expand

expand() returns a view with virtual copies. No memory is allocated; it just changes strides.
import torch

# Example 1: Scalar broadcasts everywhere
a = torch.randn(3, 4)
b = torch.tensor(2.0)
c = a + b  # Shape: (3, 4)

# Example 2: Vector broadcasts along first dimension
a = torch.randn(3, 4)  # (3, 4)
b = torch.randn(4)     # (4,) -> broadcasts to (3, 4)
c = a + b              # Shape: (3, 4)

# Example 3: Classic outer sum pattern
a = torch.randn(3, 1)  # Column vector
b = torch.randn(1, 4)  # Row vector
c = a + b              # Shape: (3, 4) - outer sum!

# Example 4: Batch normalization pattern
batch = torch.randn(32, 3, 224, 224)  # (N, C, H, W)
mean = torch.randn(3, 1, 1)            # Per-channel mean
normalized = batch - mean              # (32, 3, 224, 224)

# Example 5: Attention score masking
scores = torch.randn(8, 12, 64, 64)   # (B, H, Q, K)
mask = torch.randn(1, 1, 64, 64)      # Broadcast across batch and heads
masked = scores + mask                 # (8, 12, 64, 64)

# Explicit broadcasting with expand (no memory copy)
x = torch.randn(1, 4)
expanded = x.expand(3, 4)  # Shape (3, 4), shares memory with x

Broadcasting Gotchas

Broadcasting can hide bugs. If shapes accidentally align in unexpected ways, you may get wrong results without errors. Always verify shapes match your expectations, especially when debugging.
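One classic hidden bug: a 1D statistic subtracts cleanly from a square matrix whether you meant per-row or per-column centering, and broadcasting accepts both without complaint. A sketch of the failure mode and the keepdim=True habit that makes intent explicit:

```python
import torch

x = torch.randn(4, 4)

# A (4,) vector subtracts from a (4, 4) matrix without error --
# but it always aligns with the LAST dim, i.e. per-column centering
col_means = x.mean(dim=0)            # shape (4,)
centered_cols = x - col_means        # broadcasts over rows

# keepdim=True keeps the reduced axis, so the shape states your intent
row_means = x.mean(dim=1, keepdim=True)   # shape (4, 1) -- unambiguous
centered_rows = x - row_means             # per-row centering

print(centered_cols.mean(dim=0))  # ~0 per column
print(centered_rows.mean(dim=1))  # ~0 per row
```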

Interactive: Broadcasting Visualizer

Explore how broadcasting works by selecting different shape combinations. See how dimensions are aligned, which are broadcasted, and what the resulting shape will be.


Matrix Multiplication and Variants

Matrix multiplication is the workhorse of deep learning. PyTorch provides several functions for different use cases, each with specific shape requirements.

The Core Operations

| Function | Input Shapes | Output Shape | Use Case |
|---|---|---|---|
| torch.mm(A, B) | (m, k), (k, n) | (m, n) | 2D matrix multiply |
| torch.mv(A, v) | (m, n), (n,) | (m,) | Matrix-vector multiply |
| torch.dot(a, b) | (n,), (n,) | scalar | Vector dot product |
| torch.bmm(A, B) | (b, m, k), (b, k, n) | (b, m, n) | Batched matmul |
| torch.matmul(A, B) | various | broadcast result | General matmul with broadcasting |
| A @ B | various | broadcast result | Operator syntax for matmul |

Understanding @ (matmul)

The @ operator and torch.matmul() are the most flexible, supporting:

  • 1D × 1D: dot product (returns scalar)
  • 2D × 2D: standard matrix multiplication
  • 1D × 2D: vector treated as row, multiply, then squeeze
  • 2D × 1D: multiply, result is vector
  • ND × ND: batched matmul with broadcasting on batch dimensions
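Each of those cases is quickest to internalize by checking shapes directly; a minimal sketch (tensor names are arbitrary):

```python
import torch

v = torch.randn(3)
u = torch.randn(4)
A = torch.randn(3, 4)
B = torch.randn(2, 3, 4)   # batch of two 3x4 matrices

print((v @ v).shape)   # torch.Size([])     -- 1D x 1D: dot product (0-d scalar)
print((v @ A).shape)   # torch.Size([4])    -- 1D x 2D: prepend 1, multiply, squeeze
print((A @ u).shape)   # torch.Size([3])    -- 2D x 1D: result is a vector
print((v @ B).shape)   # torch.Size([2, 4]) -- 1D broadcasts over the batch dim
print((B @ u).shape)   # torch.Size([2, 3]) -- ND x 1D: vector applied per batch
```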
Matrix Multiplication Variants

🐍matrix_multiplication.py

Standard MatMul

The inner dimensions (4) must match. The result has the outer dimensions (3, 5).

Batched MatMul

When both inputs have batch dimensions, PyTorch performs an independent matrix multiply for each batch element.

Batch Broadcasting

Like element-wise ops, batch dimensions can broadcast. Here 1 broadcasts to 5.

Linear Layer Pattern

y = xW + b is the core computation of nn.Linear. Note that W here is (in, out), whereas nn.Linear stores its weight as (out, in).

Attention Scores

QK^T requires transposing K's last two dims. The result is (B, H, Q, K) attention scores.
import torch

# 2D Matrix Multiplication: (m, k) @ (k, n) -> (m, n)
A = torch.randn(3, 4)  # 3x4 matrix
B = torch.randn(4, 5)  # 4x5 matrix
C = A @ B              # 3x5 matrix

# Batched Matrix Multiplication
# All batch dimensions must match or broadcast
batch_A = torch.randn(32, 8, 64, 128)   # (B, H, Q, D)
batch_B = torch.randn(32, 8, 128, 64)   # (B, H, D, K)
batch_C = batch_A @ batch_B              # (32, 8, 64, 64)

# Broadcasting in batch dimensions
A = torch.randn(1, 3, 4)    # (1, 3, 4)
B = torch.randn(5, 4, 2)    # (5, 4, 2)
C = A @ B                    # (5, 3, 2) - batch dim broadcasts

# Linear layer forward pass
x = torch.randn(32, 784)     # Batch of 32, 784 features
W = torch.randn(784, 256)    # Weight matrix
b = torch.randn(256)         # Bias vector
y = x @ W + b                # (32, 256)

# Attention: QK^T computation
Q = torch.randn(8, 12, 64, 32)  # (B, H, Q, D)
K = torch.randn(8, 12, 64, 32)  # (B, H, K, D)
# K.transpose(-2, -1) -> (8, 12, 32, 64)
scores = Q @ K.transpose(-2, -1)  # (8, 12, 64, 64)

mm vs matmul

Use torch.mm() when you specifically need 2D × 2D and want to catch dimension errors. Use @ or torch.matmul() for flexibility and broadcasting support.
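The fail-fast behavior is the point of preferring torch.mm() for strictly 2D work; a short sketch of both behaviors side by side:

```python
import torch

A = torch.randn(1, 3, 4)
B = torch.randn(4, 5)

# torch.mm is strictly 2D x 2D and rejects anything else immediately
try:
    torch.mm(A, B)
except RuntimeError as e:
    print("mm rejected 3D input:", e)

# matmul treats the leading dim of A as a batch dimension instead
C = A @ B
print(C.shape)   # torch.Size([1, 3, 5])
```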

Interactive: Matrix Multiplication

Visualize different forms of matrix multiplication. See how elements combine, watch the computation step by step, and understand the shape transformations.

Mathematical Formulation

C[i,j] = Σₖ A[i,k] × B[k,j]

Reduction Operations

Reduction operations aggregate tensor values along one or more dimensions, producing a smaller tensor (or scalar). They are fundamental for computing statistics, losses, and attention weights.

Common Reductions

Reduction Operations

🐍reductions.py

Global Sum

With no dim argument, sum() reduces all elements to a scalar.

Sum Along Dimension

dim=1 sums across the columns within each row (horizontally). Each row collapses to a single value.

Max Returns Two Tensors

When dim is specified, max/min return (values, indices). Use .values or unpack.

LogSumExp

Computes log(sum(exp(x))) in a numerically stable way. Used in softmax and loss functions.

keepdim=True

Preserves the reduced dimension as size 1. Essential for broadcasting the result back.
import torch

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])  # Shape: (2, 3)

# Sum
total = x.sum()              # Scalar: 21
row_sums = x.sum(dim=1)      # Shape: (2,) -> [6, 15]
col_sums = x.sum(dim=0)      # Shape: (3,) -> [5, 7, 9]

# Mean
mean_all = x.mean()          # Scalar: 3.5
mean_rows = x.mean(dim=1)    # Shape: (2,) -> [2, 5]

# Max and Argmax
max_val = x.max()            # Scalar: 6
max_per_row, indices = x.max(dim=1)  # Values and indices
# max_per_row: [3, 6], indices: [2, 2]

# Min
min_val, min_idx = x.min(dim=0)  # Min per column

# Product
prod = x.prod()              # Scalar: 720 (1*2*3*4*5*6)

# Standard deviation and Variance
std = x.std()                # Standard deviation
var = x.var()                # Variance

# LogSumExp (numerically stable)
logsumexp = x.logsumexp(dim=1)  # log(sum(exp(x))) per row

# Keeping dimensions with keepdim=True
row_mean = x.mean(dim=1, keepdim=True)  # Shape: (2, 1)
centered = x - row_mean                  # Broadcasting works!

Multiple Dimension Reduction

🐍multi_dim_reduction.py
import torch

x = torch.randn(2, 3, 4, 5)  # Shape: (2, 3, 4, 5)

# Reduce multiple dimensions at once
spatial_mean = x.mean(dim=(2, 3))  # Shape: (2, 3)
batch_channel = x.sum(dim=(0, 1))  # Shape: (4, 5)

# All dimensions except one
# Method 1: List all dims
global_except_batch = x.mean(dim=(1, 2, 3))  # Shape: (2,)

# Method 2: Use flatten and mean
global_except_batch = x.flatten(1).mean(dim=1)  # Shape: (2,)

# Reduce to just batch dim (useful for per-sample losses)
per_sample = x.sum(dim=list(range(1, x.ndim)))  # Shape: (2,)

Softmax Reduction Pattern

Softmax combines several operations: subtract max (for stability), exponentiate, sum, divide. In PyTorch: F.softmax(x, dim=-1) handles this efficiently.
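Those four steps can be written out by hand to see how reduction and broadcasting interlock; a sketch that checks itself against F.softmax:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 5)

# 1) subtract the per-row max for numerical stability (keepdim enables broadcasting)
shifted = x - x.max(dim=-1, keepdim=True).values
# 2) exponentiate element-wise, 3) reduce with sum, 4) broadcast the division
manual = shifted.exp() / shifted.exp().sum(dim=-1, keepdim=True)

print(torch.allclose(manual, F.softmax(x, dim=-1)))  # True
```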

Interactive: Reduction Operations

Experiment with different reduction operations and dimensions. See how the output shape changes and watch the values being aggregated.



Comparison and Logical Operations

Comparison operations return boolean tensors, while logical operations combine them. These are essential for masking, filtering, and conditional computation.

Comparison and Logical Operations

🐍comparison_ops.py

Comparison Operators

Return a BoolTensor where each element is True/False based on the comparison.

Logical AND (&)

Element-wise AND. Both conditions must be True.

any() and all()

Reduce a boolean tensor to a scalar. any() = at least one True, all() = all are True.

torch.where()

Ternary operation: selects from the first tensor where the condition is True, from the second tensor otherwise.
import torch

a = torch.tensor([1, 2, 3, 4, 5])
b = torch.tensor([3, 3, 3, 3, 3])

# Element-wise comparisons (return BoolTensor)
eq = a == b        # [False, False, True, False, False]
ne = a != b        # [True, True, False, True, True]
lt = a < b         # [True, True, False, False, False]
le = a <= b        # [True, True, True, False, False]
gt = a > b         # [False, False, False, True, True]
ge = a >= b        # [False, False, True, True, True]

# Method equivalents
eq2 = torch.eq(a, b)
lt2 = torch.lt(a, b)

# Logical operations
mask1 = a > 2
mask2 = a < 5
combined = mask1 & mask2   # AND: [False, False, True, True, False]
combined = mask1 | mask2   # OR:  [True, True, True, True, True]
inverted = ~mask1          # NOT: [True, True, False, False, False]
xor = mask1 ^ mask2        # XOR: [True, True, False, False, True]

# Any and All (reduction to scalar)
any_true = mask1.any()     # True (at least one True)
all_true = mask1.all()     # False (not all True)

# Counting
count = mask1.sum()        # 3 (number of True values)
ratio = mask1.float().mean()  # 0.6 (proportion True)

# Where (conditional selection)
result = torch.where(mask1, a, b)  # [3, 3, 3, 4, 5]
# If mask1[i] is True, take a[i], else take b[i]

Masking in Practice

🐍masking.py
import torch

# Attention masking
seq_len = 10
# Create causal mask (lower triangular)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()

# Apply mask to attention scores
scores = torch.randn(seq_len, seq_len)
masked_scores = scores.masked_fill(~causal_mask, float('-inf'))
# Positions where mask is False become -inf
# After softmax, these become 0

# Padding mask (variable length sequences)
lengths = torch.tensor([7, 5, 10])  # Actual lengths in batch
max_len = 10
# Create mask: True for valid positions
padding_mask = torch.arange(max_len).expand(3, -1) < lengths.unsqueeze(1)
# Shape: (3, 10) with True for valid tokens

# Using masks for selective operations
x = torch.randn(3, 10, 64)  # (batch, seq, features)
# Zero out padded positions
x = x * padding_mask.unsqueeze(-1)  # Broadcasting (3, 10, 1)

Einstein Summation: The Universal Notation

Einstein summation notation (einsum) is a powerful and concise way to express many tensor operations. It uses index labels to describe how dimensions combine and contract.

The Einsum Rules

An einsum expression like "ij,jk->ik" means:

  • Each letter (i, j, k) represents a dimension index
  • Repeated indices across inputs are summed (contracted)
  • Output indices after -> define the result shape
  • Indices not in output are implicitly summed over
'ij,jk->ik':  C[i,k] = Σⱼ A[i,j] × B[j,k]

Common Patterns

| Einsum | Operation | Equivalent |
|---|---|---|
| 'ij->ji' | Transpose | A.T |
| 'ii->i' | Diagonal | torch.diag(A) |
| 'ii->' | Trace | torch.trace(A) |
| 'ij->' | Sum all | A.sum() |
| 'ij->i' | Sum rows | A.sum(dim=1) |
| 'i,i->' | Dot product | torch.dot(a, b) |
| 'i,j->ij' | Outer product | torch.outer(a, b) |
| 'ij,jk->ik' | Matrix multiply | A @ B |
| 'bij,bjk->bik' | Batch matmul | torch.bmm(A, B) |
| 'bhqd,bhkd->bhqk' | Attention scores | Q @ K.transpose(-2,-1) |
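Each pattern in the table can be verified numerically against its named equivalent; a quick sketch (tensor names are arbitrary):

```python
import torch

A = torch.randn(4, 4)
a = torch.randn(4)
b = torch.randn(4)

# Each line checks one einsum pattern against its standard equivalent
assert torch.allclose(torch.einsum('ij->ji', A), A.T)
assert torch.allclose(torch.einsum('ii->i', A), torch.diag(A))
assert torch.allclose(torch.einsum('ii->', A), torch.trace(A))
assert torch.allclose(torch.einsum('ij->', A), A.sum())
assert torch.allclose(torch.einsum('i,i->', a, b), torch.dot(a, b))
assert torch.allclose(torch.einsum('i,j->ij', a, b), torch.outer(a, b))
print("all einsum identities hold")
```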
Einstein Summation Examples

🐍einsum_examples.py

Basic MatMul

'ij,jk->ik': j is repeated across inputs but absent from the output, so it's contracted (summed).

Batch MatMul

'bij,bjk->bik': b appears in both inputs and the output, so it's a batch dimension.

Attention Scores

'bhqd,bhkd->bhqk': d is contracted, giving Q-K dot products for each head and batch.

Attention Output

'bhqk,bhkd->bhqd': k is contracted, weighting values by attention weights.

Bilinear Form

'i,ij,j->': Both i and j are contracted, resulting in the scalar x^T A y.
import torch

A = torch.randn(3, 4)
B = torch.randn(4, 5)

# Matrix multiplication
result = torch.einsum('ij,jk->ik', A, B)  # Same as A @ B

# Batch matrix multiply
batch_A = torch.randn(8, 3, 4)
batch_B = torch.randn(8, 4, 5)
batch_C = torch.einsum('bij,bjk->bik', batch_A, batch_B)

# Attention computation in transformers
Q = torch.randn(2, 8, 64, 32)  # (B, H, Q, D)
K = torch.randn(2, 8, 64, 32)  # (B, H, K, D)
V = torch.randn(2, 8, 64, 32)  # (B, H, K, D)

# QK^T: attention scores
scores = torch.einsum('bhqd,bhkd->bhqk', Q, K)

# Attention(Q,K,V) = softmax(QK^T / sqrt(d)) V
attn_weights = torch.softmax(scores / (32 ** 0.5), dim=-1)
output = torch.einsum('bhqk,bhkd->bhqd', attn_weights, V)

# Bilinear form: x^T M y
x = torch.randn(3)
y = torch.randn(4)
M = torch.randn(3, 4)
bilinear = torch.einsum('i,ij,j->', x, M, y)  # Scalar

# Hadamard product (element-wise) with reduction
P = torch.randn(3, 4)
element_sum = torch.einsum('ij,ij->', M, P)  # Sum of M*P

When to Use Einsum

Use einsum when the operation is complex or involves unusual contractions. For standard operations like matmul, use @—it's equally efficient and more readable. Einsum shines for attention mechanisms, tensor decompositions, and custom contractions.

Interactive: Einsum Explorer

Explore Einstein summation notation interactively. Select from common patterns, see how indices map to operations, and understand what gets contracted versus preserved.


Linear Algebra Operations

PyTorch includes a rich set of linear algebra operations in torch.linalg, essential for solving systems, computing decompositions, and understanding data structure.

Matrix Properties and Norms

🐍linalg_basics.py
import torch

A = torch.randn(3, 3)

# Matrix properties
det = torch.linalg.det(A)           # Determinant
rank = torch.linalg.matrix_rank(A)  # Rank
trace = torch.trace(A)              # Sum of diagonal

# Norms
vec = torch.randn(5)
l1_norm = torch.linalg.norm(vec, ord=1)    # Sum of absolute values
l2_norm = torch.linalg.norm(vec, ord=2)    # Euclidean norm
linf_norm = torch.linalg.norm(vec, ord=float('inf'))  # Max absolute

# Matrix norms
fro_norm = torch.linalg.norm(A, ord='fro')  # Frobenius (L2 on flattened)
nuc_norm = torch.linalg.norm(A, ord='nuc')  # Nuclear (sum of singular values)
spectral = torch.linalg.norm(A, ord=2)      # Spectral (largest singular value)

# Condition number (sensitivity to perturbation)
cond = torch.linalg.cond(A)

Decompositions

🐍decompositions.py
import torch

A = torch.randn(4, 3)
B = torch.randn(3, 3)

# SVD: A = U @ S @ Vh
U, S, Vh = torch.linalg.svd(A)
# U: (4, 4), S: (3,), Vh: (3, 3)

# Eigendecomposition (square matrices)
eigenvalues, eigenvectors = torch.linalg.eig(B)
# For symmetric matrices, use eigh (real eigenvalues)
sym = B @ B.T
eigenvalues, eigenvectors = torch.linalg.eigh(sym)

# QR decomposition: A = Q @ R
Q, R = torch.linalg.qr(A)
# Q: orthogonal, R: upper triangular

# Cholesky decomposition (positive definite)
pos_def = sym + 3 * torch.eye(3)  # Make positive definite
L = torch.linalg.cholesky(pos_def)
# L @ L.T = pos_def

# LU decomposition
P, L, U = torch.linalg.lu(B)
# P @ L @ U = B
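A good habit with any decomposition is to multiply the factors back together and compare against the original; a sketch using the reduced SVD (full_matrices=False):

```python
import torch

A = torch.randn(4, 3)
U, S, Vh = torch.linalg.svd(A, full_matrices=False)  # U: (4,3), S: (3,), Vh: (3,3)

# Scaling the columns of U by S via broadcasting is equivalent to U @ diag(S)
reconstructed = (U * S) @ Vh
print(torch.allclose(reconstructed, A, atol=1e-5))  # True
```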

Solving Linear Systems

🐍solve_systems.py
import torch

# Solve Ax = b
A = torch.randn(3, 3)
b = torch.randn(3, 1)

# Direct solve (uses LU decomposition internally)
x = torch.linalg.solve(A, b)  # x such that Ax = b

# Least squares (overdetermined system)
A = torch.randn(5, 3)  # More equations than unknowns
b = torch.randn(5, 1)
x, residuals, rank, s = torch.linalg.lstsq(A, b)

# Matrix inverse (avoid when possible - solve is more stable)
A_inv = torch.linalg.inv(A[:3, :3])  # Square matrix required

# Pseudoinverse (works for any shape)
A_pinv = torch.linalg.pinv(A)

# Batched solve
A_batch = torch.randn(10, 3, 3)  # 10 systems
b_batch = torch.randn(10, 3, 1)
x_batch = torch.linalg.solve(A_batch, b_batch)  # Solves all 10

Numerical Stability

Direct matrix inversion (torch.linalg.inv) is rarely the best approach. Use torch.linalg.solve(A, b) instead of inv(A) @ b—it's more numerically stable and often faster.
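A small experiment makes the comparison concrete. This sketch builds a system with a known solution and measures how far each method lands from it; with a random, reasonably conditioned matrix both do well, but solve is the habit worth building for ill-conditioned problems:

```python
import torch

torch.manual_seed(0)
A = torch.randn(100, 100, dtype=torch.float64)
x_true = torch.randn(100, 1, dtype=torch.float64)
b = A @ x_true                          # construct b so the true solution is known

x_solve = torch.linalg.solve(A, b)      # factorize-and-solve
x_inv = torch.linalg.inv(A) @ b         # explicit inverse (extra work, extra error)

print(f"solve error: {(x_solve - x_true).abs().max():.2e}")
print(f"inv   error: {(x_inv - x_true).abs().max():.2e}")
```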

In-place Operations

In-place operations modify a tensor directly without creating a new one. They're identified by a trailing underscore (e.g., add_()).

🐍in_place.py
import torch

x = torch.tensor([1.0, 2.0, 3.0])

# In-place operations (note the underscore)
x.add_(1)       # x is now [2, 3, 4]
x.mul_(2)       # x is now [4, 6, 8]
x.fill_(0)      # x is now [0, 0, 0]
x.zero_()       # Same as fill_(0)

# Clamp in-place
x = torch.randn(5)
x.clamp_(min=0)  # ReLU in-place

# In-place random fill
x.uniform_(0, 1)  # Fill with uniform random
x.normal_(0, 1)   # Fill with normal random

# Copy in-place (shapes must match or broadcast)
src = torch.tensor([1., 2., 3., 4., 5.])
x.copy_(src)     # Copy src into x

# Indexing with in-place modification
y = torch.zeros(5)
y[2:4] = torch.tensor([7., 8.])  # Modifies y in-place

In-place and Autograd

In-place operations on tensors that require gradients can break the computation graph! PyTorch raises an error immediately when you modify a leaf tensor that requires grad, and during backward if a tensor needed for gradient computation was overwritten. Use out-of-place operations during training.
🐍autograd_inplace.py
import torch

# In-place on a leaf tensor that requires grad errors immediately:
x = torch.randn(3, requires_grad=True)
# x.add_(1)  # RuntimeError: a leaf Variable that requires grad
#            # is being used in an in-place operation

# Safe alternative:
x = torch.randn(3, requires_grad=True)
y = x + 1  # Creates a new tensor; the grad graph is preserved

# In-place is OK for non-leaf tensors in some cases,
# but it's safest to avoid it during training

Performance Considerations

Understanding how tensor operations use memory and compute resources helps you write efficient deep learning code.

Memory Layout and Contiguity

🐍memory_performance.py
import torch

x = torch.randn(1000, 1000)

# Transpose creates a non-contiguous view
y = x.T
print(y.is_contiguous())  # False

# Some operations require contiguous memory
# contiguous() copies data into a contiguous layout
y_contig = y.contiguous()

# reshape() vs view()
# view() requires contiguous input, reshape() handles either
z = x.T.reshape(-1)  # Works (copies internally)
# z = x.T.view(-1)   # Error: not contiguous

# Tip: Check contiguity before performance-critical loops
if not x.is_contiguous():
    x = x.contiguous()

Fused Operations

Some operations are optimized to execute together, avoiding intermediate memory allocation:

🐍fused_ops.py
import torch
import torch.nn.functional as F

x = torch.randn(32, 64)
W = torch.randn(64, 128)
b = torch.randn(128)

# Fused multiply-add (more efficient than separate matmul + add)
result = torch.addmm(b, x, W)   # b + x @ W

# Fused operations in F.linear (note: its weight argument is (out, in))
output = F.linear(x, W.T, b)    # x @ W + b, optimized internally

# Fused attention (PyTorch 2.0+)
from torch.nn.functional import scaled_dot_product_attention
Q = K = V = torch.randn(2, 8, 16, 32)
attn_out = scaled_dot_product_attention(Q, K, V)  # Fused, memory efficient

# torch.compile for automatic fusion (PyTorch 2.0+)
@torch.compile
def my_function(a, b):
    return (a * b).sum(dim=-1).mean()
# Compiler fuses operations automatically

Memory-Efficient Patterns

🐍memory_efficient.py
import torch

# Bad: Creates intermediate tensors
def inefficient(x, y, z):
    temp1 = x * y      # Allocates temp1
    temp2 = temp1 + z  # Allocates temp2
    return temp2.sum()

# Better: Use in-place where safe (but not with autograd)
def efficient_no_grad(x, y, z):
    with torch.no_grad():
        x.mul_(y)
        x.add_(z)
        return x.sum()

# Best: Let torch.compile optimize
@torch.compile
def optimized(x, y, z):
    return ((x * y) + z).sum()  # Compiler fuses

# Gradient checkpointing for large models
from torch.utils.checkpoint import checkpoint
# Recomputes activations during backward instead of storing them:
# output = checkpoint(layer, input)  # Trades compute for memory

torch.compile: The PyTorch 2.0 Compiler

torch.compile is PyTorch's JIT compiler that can dramatically speed up your models by automatically optimizing and fusing operations.

🐍torch_compile.py
import torch
import torch.nn as nn

# Basic usage - wrap your model or function
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU())  # stand-in for your model
compiled_model = torch.compile(model)

# Or use as a decorator
@torch.compile
def my_function(x, y):
    return (x.matmul(y) + x).relu()

# Compilation modes
model_default = torch.compile(model)  # Balanced
model_fast = torch.compile(model, mode="reduce-overhead")  # Minimize latency
model_max = torch.compile(model, mode="max-autotune")  # Best perf, slow compile

# Disable for debugging
model_debug = torch.compile(model, disable=True)

# Full options for production
compiled = torch.compile(
    model,
    mode="max-autotune",
    fullgraph=True,  # Error if can't compile entire graph
    dynamic=False,   # Static shapes (faster)
)
| Mode | Compile Time | Runtime | Use Case |
|---|---|---|---|
| default | Medium | Good | General use, first try this |
| reduce-overhead | Fast | Good | Many small operations, inference |
| max-autotune | Slow | Best | Production training, benchmarking |

When torch.compile May Not Help

  • Dynamic shapes: Variable batch sizes or sequence lengths trigger recompilation
  • Python control flow: if/else based on tensor values breaks the graph
  • Small models: Compilation overhead may exceed speedup
  • Custom C++ ops: May not be fusible by the compiler

Profiling

Use torch.profiler to identify bottlenecks. GPU operations are asynchronous, so wall-clock time can be misleading. The profiler shows actual GPU utilization and memory usage.
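A minimal profiling sketch using the torch.profiler API (CPU-only here; on GPU you would add ProfilerActivity.CUDA to the activities list):

```python
import torch
from torch.profiler import profile, ProfilerActivity

def step(x, w):
    return torch.relu(x @ w).sum()

x = torch.randn(256, 256)
w = torch.randn(256, 256)

# Record op-level timing and memory for everything inside the context
with profile(activities=[ProfilerActivity.CPU], profile_memory=True) as prof:
    step(x, w)

# Table of the most expensive ops, sorted by total CPU time
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```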

Knowledge Check

Test your understanding of tensor operations with this comprehensive quiz covering broadcasting, matrix multiplication, einsum, and reductions.



Summary

Tensor operations are the building blocks of all deep learning computation. Here's what we covered:

| Category | Key Operations | When to Use |
|---|---|---|
| Element-wise | +, *, exp, relu, tanh | Activations, normalization, feature transforms |
| Broadcasting | Automatic shape alignment | Per-channel ops, batched operations, masking |
| Matrix Multiply | @, mm, bmm, matmul | Linear layers, attention, projections |
| Reductions | sum, mean, max, softmax | Loss computation, pooling, normalization |
| Comparison | ==, <, >, where, masked_fill | Masking, conditional logic, filtering |
| Einsum | torch.einsum('...') | Complex contractions, custom operations |
| Linear Algebra | solve, svd, eig, norm | Optimization, analysis, preprocessing |

Key Takeaways

  1. Broadcasting aligns shapes by right-aligning them, prepending 1s, and expanding size-1 dimensions. Understand it deeply to avoid bugs.
  2. Use @ for matrix multiplication—it handles broadcasting and is the most Pythonic syntax.
  3. Einsum is powerful for expressing complex tensor operations concisely, especially in attention mechanisms.
  4. Pay attention to dimension ordering in reductions—dim=0 is usually batch, dim=-1 is features.
  5. Avoid in-place operations during training to prevent autograd issues.

Exercises

Conceptual Questions

  1. Explain why torch.mm(A, B) fails but A @ B works when A has shape (1, 3, 4) and B has shape (4, 5).
  2. Why does einsum 'ij,jk->ik' represent matrix multiplication? What would 'ij,jk->ijk' compute?
  3. How would you use broadcasting to subtract the mean from each row of a matrix without explicitly reshaping?

Coding Exercises

  1. Broadcasting Practice: Given a batch of images with shape (32, 3, 224, 224), write code to normalize each channel independently using per-channel mean and std vectors.
  2. Attention Without Loops: Implement scaled dot-product attention using only tensor operations (no loops). Input shapes: Q, K, V all (batch, seq_len, d_model).
  3. Einsum Mastery: Rewrite these operations using einsum:
    • Batch matrix trace: sum of diagonals for each matrix in a batch (B, N, N) → (B,)
    • Batched outer product: (B, M) and (B, N) → (B, M, N)
    • Multi-head attention scores: Q(B,H,Q,D) × K(B,H,K,D)^T → (B,H,Q,K)
  4. Memory Investigation: Create a 3D tensor, perform various reshape and transpose operations, and determine which create views vs copies by checking data_ptr() and is_contiguous().

Challenge Exercise

Implement Layer Normalization from Scratch: Using only tensor operations covered in this section, implement layer normalization that normalizes over the last D dimensions with learnable affine parameters.

🐍layer_norm_challenge.py
def layer_norm(x, normalized_shape, weight, bias, eps=1e-5):
    """
    Args:
        x: Input tensor of any shape
        normalized_shape: tuple of ints (last D dimensions to normalize)
        weight: Scale parameter (same shape as normalized_shape)
        bias: Shift parameter (same shape as normalized_shape)
        eps: Small constant for numerical stability
    Returns:
        Normalized tensor with same shape as x
    """
    # Your implementation here
    # Hint: Use reduction over last D dims, broadcasting, and element-wise ops
    pass

In the next section, we'll dive into advanced indexing, slicing, and reshaping operations that give you precise control over tensor data.