Chapter 4: PyTorch Fundamentals
Section 27: Indexing, Slicing, and Reshaping

Learning Objectives

By the end of this section, you will be able to:

  1. Master tensor indexing using single indices, slices, and multi-dimensional access
  2. Apply advanced slicing with negative indices, steps, and ellipsis notation
  3. Use boolean masking to filter tensors based on conditions
  4. Perform fancy indexing to select arbitrary elements in any order
  5. Understand views vs copies and avoid common memory-related bugs
  6. Reshape tensors using view, reshape, squeeze, unsqueeze, permute, and transpose
  7. Apply these operations in real deep learning scenarios like batch processing and attention mechanisms
Why This Matters: Every neural network forward pass involves extensive tensor manipulation. Whether you're preparing data batches, implementing attention mechanisms, or building custom layers, fluent tensor indexing and reshaping skills are essential. These operations are the “verbs” of deep learning programming.

The Big Picture

Imagine you have a dataset of 1000 images, each with 3 color channels (RGB) and dimensions 224×224 pixels. In PyTorch, this data lives in a 4D tensor with shape (1000, 3, 224, 224). To work with this data, you need to:

  • Extract a single image for visualization
  • Select a batch of 32 images for training
  • Get only the red channel across all images
  • Crop the center 100×100 pixels from each image
  • Reshape data for different layer types

These operations—indexing, slicing, and reshaping—are the fundamental tools that make this possible. They evolved from NumPy's powerful array operations, which in turn were inspired by MATLAB and APL. PyTorch tensors support nearly identical syntax, making it easy to transition from NumPy while gaining GPU acceleration.
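The bulleted operations above map directly onto indexing and reshaping calls. A minimal sketch, using scaled-down shapes in place of the full (1000, 3, 224, 224) dataset so it runs cheaply:

```python
import torch

# Scaled-down stand-in for the (1000, 3, 224, 224) image dataset
images = torch.randn(16, 3, 32, 32)   # 16 RGB images, 32x32 pixels

single = images[0]                    # one image for visualization: (3, 32, 32)
batch = images[:4]                    # a training batch: (4, 3, 32, 32)
red = images[:, 0]                    # red channel of every image: (16, 32, 32)
center = images[:, :, 8:24, 8:24]     # central 16x16 crop: (16, 3, 16, 16)
flat = images.flatten(start_dim=1)    # flattened for a linear layer: (16, 3072)
```

Each of these operations is covered in detail in the sections that follow.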

A Historical Perspective

The concept of array slicing originated in APL (A Programming Language) in the 1960s, designed by Kenneth Iverson. The notation was later refined in languages like MATLAB and Fortran, eventually becoming the start:stop:step syntax we use today. NumPy standardized this for Python, and PyTorch adopted it for deep learning workloads.

Understanding the distinction between views (memory-sharing references) and copies (independent data) is critical for both correctness and performance. This distinction becomes especially important when working with large tensors on GPUs, where unnecessary copies can exhaust memory.


Basic Indexing

Basic indexing in PyTorch follows Python's zero-based indexing convention. Each dimension can be accessed with an integer index, starting from 0 for the first element.

Single Element Access

To access a single element, provide an index for each dimension:

🐍basic_indexing.py

  • Zero-Based Indexing: row 1 is the second row (0-indexed), column 2 is the third column; the element at that position is 7.
  • Getting Python Values: use .item() to extract a Python scalar from a 0-dimensional tensor, essential when you need a regular Python number.
  • Chained Indexing: matrix[1][2] first returns row 1, then indexes column 2; it works but is slightly less efficient than matrix[1, 2].
  • Negative Indices: -1 refers to the last element, -2 to the second-to-last, and so on, for any dimension.
import torch

# Create a 2D tensor (matrix)
matrix = torch.tensor([[1, 2, 3, 4],
                       [5, 6, 7, 8],
                       [9, 10, 11, 12]])

# Access element at row 1, column 2
element = matrix[1, 2]
print(element)        # tensor(7)
print(element.item()) # 7 (Python int)

# Alternative chained syntax
element = matrix[1][2]
print(element)        # tensor(7)

# Access last element with negative indexing
last = matrix[-1, -1]
print(last)           # tensor(12)

Selecting Rows and Columns

Using a single index reduces the tensor's dimensionality by one:

🐍select_rows_cols.py
matrix = torch.arange(12).reshape(3, 4)
print(matrix)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# Select entire row (reduces to 1D)
row_1 = matrix[1]
print(row_1.shape)  # torch.Size([4])
print(row_1)        # tensor([4, 5, 6, 7])

# Select entire column (reduces to 1D)
col_2 = matrix[:, 2]  # : means "all rows"
print(col_2.shape)  # torch.Size([3])
print(col_2)        # tensor([2, 6, 10])

# Keep dimension with slicing
row_1_2d = matrix[1:2]  # Slice instead of index
print(row_1_2d.shape)   # torch.Size([1, 4]) - 2D!

Index vs Slice for Dimension

Using a single index [1] reduces dimensionality. Using a slice [1:2] keeps the dimension. This distinction is important when operations require specific tensor ranks.

Multi-Dimensional Indexing

For tensors with 3 or more dimensions, extend the same pattern. Each comma-separated index addresses a different dimension:

🐍multi_dim_indexing.py
# 3D tensor: (depth, rows, cols)
tensor = torch.arange(24).reshape(2, 3, 4)
print(tensor.shape)  # torch.Size([2, 3, 4])

# Select first "slice" (depth=0)
slice_0 = tensor[0]
print(slice_0.shape)  # torch.Size([3, 4])

# Select specific element
element = tensor[1, 2, 3]  # depth=1, row=2, col=3
print(element)        # tensor(23)

# Mixed: one slice, specific row, all columns
subset = tensor[0, 1, :]
print(subset.shape)   # torch.Size([4])
print(subset)         # tensor([4, 5, 6, 7])

Quick Check

Given tensor = torch.arange(60).reshape(3, 4, 5), what is the shape of tensor[1]?


Interactive: Explore Indexing

Use the interactive visualizer below to build intuition for tensor indexing. Adjust the indices and observe which elements are selected in real-time.

🎯Interactive Tensor Indexing

Example selection: tensor[1, 2] from a 5×6 tensor holding the values 0-29.

PyTorch Code

import torch

# Create a 5x6 tensor
tensor = torch.arange(30).reshape(5, 6)

# Access element at row 1, column 2
element = tensor[1, 2]
print(element)  # tensor(8)

# Alternative syntax
element = tensor[1][2]
print(element)  # tensor(8)

Result: tensor[1, 2] = 8


Slicing Operations

Slicing extracts a regularly spaced sub-tensor using the start:stop:step syntax. This powerful notation lets you select ranges and skip elements (to reverse order, use torch.flip(), since PyTorch slices do not accept negative steps).

The Slice Syntax

tensor[start:stop:step]

| Component | Description | Default | Example |
|---|---|---|---|
| start | Beginning index (inclusive) | 0 | 2: starts at index 2 |
| stop | Ending index (exclusive) | end | :5 ends before index 5 |
| step | Increment between elements | 1 | ::2 every other element |
Slice Syntax Variations
🐍slicing_basics.py

  • Standard Range: elements from index 2 up to (but not including) index 7.
  • From Start: omitting start defaults to 0, so :5 means 0:5.
  • To End: omitting stop goes to the end, so 5: includes everything from index 5.
  • Step: ::2 takes every 2nd element starting from 0.
  • Negative Start: -3: means "from the 3rd-to-last element to the end".
  • Reversing: unlike NumPy, PyTorch does not support a negative step in slices; use torch.flip() instead (note that flip returns a copy, not a view).

tensor = torch.arange(10)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Basic slices
print(tensor[2:7])     # tensor([2, 3, 4, 5, 6])
print(tensor[:5])      # tensor([0, 1, 2, 3, 4])
print(tensor[5:])      # tensor([5, 6, 7, 8, 9])
print(tensor[::2])     # tensor([0, 2, 4, 6, 8])
print(tensor[1::2])    # tensor([1, 3, 5, 7, 9])

# Negative indices in slices
print(tensor[-3:])     # tensor([7, 8, 9]) - last 3
print(tensor[:-3])     # tensor([0, 1, 2, 3, 4, 5, 6]) - all but last 3
print(tensor[-5:-2])   # tensor([5, 6, 7])

# Reversing: negative steps like tensor[::-1] raise an error in PyTorch
# print(tensor[::-1])      # ValueError: step must be greater than zero
print(tensor.flip(0))      # tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
print(tensor[3:8].flip(0)) # tensor([7, 6, 5, 4, 3])
print(tensor[1::2].flip(0))# tensor([9, 7, 5, 3, 1])

Multi-Dimensional Slicing

Apply slicing to multiple dimensions simultaneously:

🐍multidim_slicing.py
matrix = torch.arange(20).reshape(4, 5)
print(matrix)
# tensor([[ 0,  1,  2,  3,  4],
#         [ 5,  6,  7,  8,  9],
#         [10, 11, 12, 13, 14],
#         [15, 16, 17, 18, 19]])

# Rows 1-2, columns 2-4
submatrix = matrix[1:3, 2:5]
print(submatrix)
# tensor([[ 7,  8,  9],
#         [12, 13, 14]])

# Every other row, every other column
checkerboard = matrix[::2, ::2]
print(checkerboard)
# tensor([[ 0,  2,  4],
#         [10, 12, 14]])

# Last 2 rows, columns reversed (negative-step slicing is not supported)
flipped = matrix[-2:].flip(1)
print(flipped)
# tensor([[14, 13, 12, 11, 10],
#         [19, 18, 17, 16, 15]])

The Ellipsis (...)

The ellipsis notation ... is shorthand for “as many full slices as needed”:

🐍ellipsis.py
# 4D tensor: (batch, channels, height, width)
images = torch.randn(32, 3, 224, 224)

# These are equivalent for getting the last dimension
last_col = images[:, :, :, -1]
last_col = images[..., -1]    # Cleaner!
print(last_col.shape)  # torch.Size([32, 3, 224])

# First batch item, all other dims
first = images[0, ...]
print(first.shape)  # torch.Size([3, 224, 224])

# First and last channel only
subset = images[:, [0, 2], ...]
print(subset.shape)  # torch.Size([32, 2, 224, 224])

When to Use Ellipsis

Use ... when you want to index specific dimensions while keeping the rest unchanged, especially in generic functions that handle tensors of varying rank.
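A minimal sketch of such a rank-generic helper (last_feature is a hypothetical name used here for illustration):

```python
import torch

def last_feature(t):
    """Return the final element along the last dimension, for any tensor rank."""
    return t[..., -1]

# The same function handles 1D, 2D, and 4D inputs unchanged:
print(last_feature(torch.arange(6)))                # tensor(5)
print(last_feature(torch.zeros(2, 3)).shape)        # torch.Size([2])
print(last_feature(torch.zeros(4, 3, 8, 8)).shape)  # torch.Size([4, 3, 8])
```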

Interactive: Advanced Slicing

Experiment with different slicing patterns, boolean masks, and fancy indexing in the interactive demo below.

✂️Advanced Tensor Slicing

Example pattern: "All rows, column 2" selects an entire column of a 5×6 tensor.

PyTorch Code

import torch

tensor = torch.arange(30).reshape(5, 6)

# All rows, column 2
result = tensor[:, 2]
print(result)        # tensor([ 2,  8, 14, 20, 26])
print(result.shape)  # torch.Size([5])

Slice Syntax: start:stop:step

  • start - beginning index (default: 0)
  • stop - ending index, exclusive (default: end)
  • step - increment (default: 1; PyTorch tensors reject negative steps - use torch.flip() to reverse)
  • Omitting a value uses the default

Boolean (Mask) Indexing

Boolean indexing uses a tensor of True/False values to select elements. This is extremely powerful for filtering data based on conditions.

Creating and Applying Masks

Boolean Mask Operations
🐍boolean_indexing.py

  • Creating Masks: comparison operators create boolean tensors of the same shape.
  • Applying Masks: indexing with a mask of the tensor's full shape returns a 1D tensor of the matching elements.
  • Compound Conditions: use & (and), | (or), ~ (not), and wrap each condition in parentheses; the bitwise operators bind more tightly than the comparisons.
  • Direct Condition: the most common pattern is tensor[tensor > value] in one line.
tensor = torch.tensor([1, 4, 2, 7, 3, 9, 5])

# Create a boolean mask
mask = tensor > 4
print(mask)  # tensor([False, False, False,  True, False,  True,  True])

# Apply the mask - returns a 1D tensor of matching elements
filtered = tensor[mask]
print(filtered)  # tensor([7, 9, 5])

# Combine conditions (use & for AND, | for OR)
mask_range = (tensor > 2) & (tensor < 8)
print(tensor[mask_range])  # tensor([4, 7, 3, 5])

# Negate with ~
mask_not = ~(tensor > 5)
print(tensor[mask_not])  # tensor([1, 4, 2, 3, 5])

# Direct condition (most common pattern)
evens = tensor[tensor % 2 == 0]
print(evens)  # tensor([4, 2])

Boolean Indexing Creates Copies

Unlike basic slicing which returns views, boolean indexing always returns a copy. This is because the selected elements may not be contiguous in memory. Modifying the result will NOT affect the original tensor.
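The warning above can be verified directly; a minimal sketch:

```python
import torch

t = torch.tensor([1, 4, 2, 7])
picked = t[t > 3]   # boolean indexing gathers matches into new memory
picked[0] = 100     # modifies the copy only
print(t)            # tensor([1, 4, 2, 7]) - original untouched

# To write back through a mask, index the original tensor directly:
t[t > 3] = 0
print(t)            # tensor([1, 0, 2, 0])
```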

Multi-Dimensional Boolean Indexing

🐍multidim_boolean.py
matrix = torch.randn(4, 5)

# Mask on entire tensor (flattens result)
mask = matrix > 0
positive = matrix[mask]
print(positive.shape)  # 1D: number of positive elements

# Mask on specific dimension
# Select rows where first column > 0
row_mask = matrix[:, 0] > 0
selected_rows = matrix[row_mask]
print(selected_rows.shape)  # (num_matching_rows, 5)

# Modify elements in-place using mask
matrix[matrix < -1] = -1  # Clamp minimum to -1
matrix[matrix > 1] = 1    # Clamp maximum to 1

# torch.where for conditional assignment
result = torch.where(matrix > 0, matrix, torch.zeros_like(matrix))
# Like ReLU: keep positives, zero out negatives

Real-World Example: GPT-2's Attention Mask

In transformer models like GPT-2, boolean masking prevents the model from attending to future tokens during training:

🐍attention_mask.py
# Causal mask for attention
seq_len = 5
# Create lower triangular mask (True = can attend)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
print(causal_mask)
# tensor([[ True, False, False, False, False],
#         [ True,  True, False, False, False],
#         [ True,  True,  True, False, False],
#         [ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])

# In attention: mask out future positions
attention_scores = torch.randn(seq_len, seq_len)
attention_scores = attention_scores.masked_fill(~causal_mask, float('-inf'))
# After softmax, -inf becomes 0 probability

Fancy Indexing

Fancy indexing uses lists or tensors of indices to select elements in any order, potentially with repetition.

Fancy Indexing Examples
🐍fancy_indexing.py

  • Index Reordering: elements are returned in the order specified, and indices can repeat.
  • 2D Coordinate Access: providing matching-length index tensors for each dimension selects specific coordinate pairs.
  • Row Selection: indexing with a list of row indices returns a tensor of shape (num_indices, num_cols).
tensor = torch.tensor([10, 20, 30, 40, 50])

# Index with a list
indices = [0, 3, 1, 1]
selected = tensor[indices]
print(selected)  # tensor([10, 40, 20, 20])

# Index with a tensor
idx_tensor = torch.tensor([4, 2, 0])
selected = tensor[idx_tensor]
print(selected)  # tensor([50, 30, 10])

# 2D fancy indexing
matrix = torch.arange(12).reshape(3, 4)
# Select elements at (0,1), (1,2), (2,3)
row_idx = torch.tensor([0, 1, 2])
col_idx = torch.tensor([1, 2, 3])
elements = matrix[row_idx, col_idx]
print(elements)  # tensor([1, 6, 11])

# Select multiple rows in any order
row_indices = [2, 0, 2, 1]
rows = matrix[row_indices]
print(rows.shape)  # torch.Size([4, 4])

Advanced Fancy Indexing Patterns

🐍advanced_fancy.py
# Gather operation (common in embeddings)
embeddings = torch.randn(1000, 256)  # 1000 tokens, 256-dim
token_ids = torch.tensor([42, 7, 256, 999])
selected = embeddings[token_ids]
print(selected.shape)  # torch.Size([4, 256])

# torch.gather for more control
# Useful for selecting along a specific dimension
values = torch.tensor([[1, 2], [3, 4], [5, 6]])
indices = torch.tensor([[0, 0], [1, 0], [0, 1]])
gathered = torch.gather(values, dim=1, index=indices)
print(gathered)  # tensor([[1, 1], [4, 3], [5, 6]])

# torch.take (flat indexing)
flat_indices = torch.tensor([0, 3, 5])
taken = torch.take(values, flat_indices)
print(taken)  # tensor([1, 4, 6]) - flat indexing into all elements

Fancy Indexing Creates Copies

Like boolean indexing, fancy indexing always returns a copy. The selected elements are gathered into a new tensor with its own memory.
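The same caveat, verified in a short sketch:

```python
import torch

base = torch.tensor([10, 20, 30, 40])
picked = base[[0, 2]]   # fancy indexing gathers into new memory
picked += 1             # in-place on the copy only
print(base)             # tensor([10, 20, 30, 40]) - unchanged

# Assignment through fancy indices does hit the original:
base[[0, 2]] = torch.tensor([-1, -3])
print(base)             # tensor([-1, 20, -3, 40])
```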

Views vs Copies: Critical Concepts

Understanding when operations create views (sharing memory) vs copies (independent memory) is crucial for correctness and performance. This concept is fundamental to efficient tensor programming.

What is a View?

A view is a tensor that shares its underlying data storage with another tensor. Modifying one affects the other. Views are memory-efficient because no data is copied.

| Operation | Returns View? | Notes |
|---|---|---|
| Basic slicing tensor[1:5] | ✅ View | Contiguous subset |
| view(), reshape() (if contiguous) | ✅ View | Just changes shape metadata |
| transpose(), permute(), T | ✅ View | Changes strides only |
| squeeze(), unsqueeze() | ✅ View | Adds/removes dim of size 1 |
| expand() | ✅ View | Broadcast to larger size |
| Boolean indexing tensor[mask] | ❌ Copy | Non-contiguous selection |
| Fancy indexing tensor[[0,2,1]] | ❌ Copy | Arbitrary element selection |
| clone() | ❌ Copy | Explicit copy |
| contiguous() | Maybe copy | Copies only if not already contiguous |
| reshape() (if non-contiguous) | ❌ Copy | Must copy to make contiguous |

Detecting Views

🐍detecting_views.py
x = torch.arange(12).reshape(3, 4)

# Create a view via slicing
y = x[1:3, 1:3]

# Views share the same underlying storage...
print(x.storage().data_ptr() == y.storage().data_ptr())  # True
# (on recent PyTorch, prefer x.untyped_storage().data_ptr())

# ...but data_ptr() can differ, since y starts at an offset into that storage
print(x.data_ptr() == y.data_ptr())  # False (y starts 5 elements in)

# Modify the view - affects the original!
y[0, 0] = 999
print(x[1, 1])  # tensor(999) - original changed!

# Create a copy to avoid this
z = x[1:3, 1:3].clone()
z[0, 0] = -1
print(x[1, 1])  # tensor(999) - unchanged

The Contiguity Problem

A tensor is contiguous when its elements are stored in memory in row-major (C-style) order without gaps. Some operations break contiguity:

🐍contiguity.py
x = torch.arange(12).reshape(3, 4)
print(x.is_contiguous())  # True
print(x.stride())         # (4, 1)

# Transpose breaks contiguity
y = x.T
print(y.is_contiguous())  # False
print(y.stride())         # (1, 4) instead of (4, 1)

# view() fails on non-contiguous tensors
try:
    z = y.view(12)  # RuntimeError!
except RuntimeError as e:
    print("Error:", e)

# Solutions:
z = y.contiguous().view(12)  # Make contiguous, then view
z = y.reshape(12)            # reshape handles it automatically

# Strided access also breaks contiguity
a = x[:, ::2]
print(a.is_contiguous())  # False (gaps in memory)

Performance Impact

Non-contiguous tensors can be slower for some operations because memory access is not sequential. If you're doing heavy computation on a non-contiguous tensor, consider calling .contiguous() once upfront.

Quick Check

After x = torch.randn(4, 4); y = x.T[1:3, 1:3].clone(), does modifying y affect x?


Reshape Operations

Reshaping changes how tensor elements are organized into dimensions without changing the total number of elements. These operations are essential for preparing data for different layer types.

Core Reshape Functions

Reshape Operations
🐍reshape_ops.py

  • reshape(): most flexible; returns a view when possible, a copy when the tensor is non-contiguous.
  • The -1 trick: one dimension can be -1, meaning "infer from the total element count"; 24 elements with 6 columns gives 4 rows.
  • view(): more restrictive than reshape(); guarantees no copy but fails on non-contiguous tensors.
  • flatten(): shorthand for reshape(-1); can flatten a subset of dimensions with start_dim and end_dim.
  • squeeze(): removes all size-1 dims by default; pass dim= to remove a specific one.
  • unsqueeze(): adds a dim of size 1 at the specified position; essential for broadcasting and batch dims.
x = torch.arange(24)  # 24 elements

# reshape: flexible, may copy if needed
a = x.reshape(4, 6)
b = x.reshape(2, 3, 4)
c = x.reshape(-1, 6)   # -1 infers the dimension

# view: requires contiguous, never copies
v = x.view(4, 6)
# v = x.reshape(4, 6).T.view(-1)  # Error! T makes it non-contiguous

# flatten: collapse to 1D
matrix = x.reshape(4, 6)
flat = matrix.flatten()
# Partial flatten
partial = x.reshape(2, 3, 4).flatten(start_dim=1)  # (2, 12)

# squeeze: remove size-1 dimensions
y = torch.randn(1, 3, 1, 4, 1)
print(y.squeeze().shape)     # (3, 4)
print(y.squeeze(0).shape)    # (3, 1, 4, 1) - only dim 0
print(y.squeeze(-1).shape)   # (1, 3, 1, 4) - only last dim

# unsqueeze: add a size-1 dimension
z = torch.randn(3, 4)
print(z.unsqueeze(0).shape)  # (1, 3, 4) - batch dim
print(z.unsqueeze(-1).shape) # (3, 4, 1) - feature dim

Transpose and Permute

For reordering dimensions (not just reshaping):

🐍transpose_permute.py
# 2D: transpose swaps dimensions
matrix = torch.randn(3, 4)
transposed = matrix.T           # or matrix.transpose(0, 1)
print(transposed.shape)         # torch.Size([4, 3])

# 3D+: transpose swaps exactly two dimensions
tensor = torch.randn(2, 3, 4)
t = tensor.transpose(0, 2)      # swap dim 0 and dim 2
print(t.shape)                  # torch.Size([4, 3, 2])

# permute: reorder ALL dimensions at once
p = tensor.permute(2, 0, 1)     # (dim2, dim0, dim1)
print(p.shape)                  # torch.Size([4, 2, 3])

# Common use: NCHW to NHWC conversion
images = torch.randn(32, 3, 224, 224)  # PyTorch format
nhwc = images.permute(0, 2, 3, 1)      # TensorFlow format
print(nhwc.shape)               # torch.Size([32, 224, 224, 3])

Transpose Returns a View

Both transpose() and permute() return views. They only change how dimensions are accessed (via strides), not the underlying data. This makes them very efficient.
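A small demonstration of the shared storage (the strides swap, the data stays put):

```python
import torch

m = torch.arange(6).reshape(2, 3)
t = m.T                              # view with swapped strides
print(m.stride(), t.stride())        # (3, 1) (1, 3)
print(m.data_ptr() == t.data_ptr())  # True - same storage, same offset

t[0, 1] = 99                         # writes through to the original
print(m[1, 0])                       # tensor(99)
```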

Interactive: Reshape Visualizer

Explore how different reshape operations transform tensor shapes and understand when views vs copies are created.

🔄Tensor Reshape Operations

Example: tensor.flatten() collapses a (3, 4) input, values 0-11, into a single dimension of shape (12,) in row-major order.

import torch

# Input: shape (3, 4)
tensor = torch.arange(12).reshape(3, 4)

# Operation: tensor.flatten()
result = tensor.flatten()

print(result.shape)    # torch.Size([12])
print(result.numel())  # 12 (same total elements)

Real-World Applications

These tensor operations appear throughout deep learning. Here are some concrete examples from popular architectures:

1. CNN Feature Extraction (ResNet)

🐍resnet_features.py
# ResNet produces feature maps of shape (batch, channels, H, W);
# before the classifier we must reduce them to (batch, features)
import torch
import torch.nn as nn
import torchvision

class ResNetClassifier(nn.Module):
    def __init__(self, num_classes=1000):
        super().__init__()
        resnet = torchvision.models.resnet50(weights="IMAGENET1K_V1")
        # Everything except the final avgpool and fc layers
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        # x: (batch, 3, 224, 224)
        features = self.backbone(x)  # (batch, 2048, 7, 7)

        # Global average pooling over the spatial dims
        pooled = features.mean(dim=[2, 3])  # (batch, 2048)

        # Alternative: flatten
        # flat = features.flatten(start_dim=1)  # (batch, 2048*7*7)

        return self.fc(pooled)

2. Multi-Head Attention (Transformer)

One of the most important reshape patterns in modern AI, used in GPT, BERT, and Vision Transformers:

🐍multihead_attention.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # Project queries, keys, values
        Q = self.W_q(x)  # (batch, seq, d_model)
        K = self.W_k(x)
        V = self.W_v(x)

        # Split into heads: reshape + transpose
        # (batch, seq, d_model) -> (batch, seq, heads, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Transpose for attention: (batch, heads, seq, head_dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Compute attention (scaled dot-product)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)  # (batch, heads, seq, head_dim)

        # Merge heads back: transpose + reshape
        context = context.transpose(1, 2).contiguous()  # (batch, seq, heads, head_dim)
        context = context.view(batch_size, seq_len, self.d_model)  # (batch, seq, d_model)

        return self.W_o(context)

3. Batch Normalization Statistics

🐍batchnorm_stats.py
# Computing batch statistics across spatial dimensions
features = torch.randn(32, 64, 28, 28)  # (N, C, H, W)

# Mean/var per channel across batch and spatial dims
# Normalize over N, H, W (dims 0, 2, 3)
mean = features.mean(dim=(0, 2, 3), keepdim=True)  # (1, C, 1, 1)
var = features.var(dim=(0, 2, 3), keepdim=True, unbiased=False)  # BatchNorm uses the biased variance

# keepdim=True maintains shape for broadcasting
normalized = (features - mean) / torch.sqrt(var + 1e-5)

# Alternative: reshape for easier computation
# (N, C, H, W) -> (N, C, H*W) -> compute over (N, H*W)
reshaped = features.flatten(start_dim=2)  # (32, 64, 784)
mean_alt = reshaped.mean(dim=(0, 2))      # (64,)

4. Image Data Augmentation

🐍augmentation.py
# Random crop and flip using indexing
def random_crop_flip(images, crop_size=200):
    """
    images: (N, C, H, W) tensor
    Returns cropped and randomly flipped images
    """
    N, C, H, W = images.shape

    # Random crop coordinates
    top = torch.randint(0, H - crop_size, (N,))
    left = torch.randint(0, W - crop_size, (N,))

    # Use indexing for each image (could vectorize with gather)
    crops = []
    for i in range(N):
        crop = images[i, :, top[i]:top[i]+crop_size, left[i]:left[i]+crop_size]

        # Random horizontal flip (negative-step slicing is not
        # supported in PyTorch, so flip the width dimension instead)
        if torch.rand(1) > 0.5:
            crop = crop.flip(-1)

        crops.append(crop)

    return torch.stack(crops)  # (N, C, crop_size, crop_size)

Common Pitfalls and Debugging

1. Accidental In-Place Modification via Views

🐍pitfall_views.py
# ❌ Bug: Modifying a view affects the original
def bad_normalize(tensor):
    row = tensor[0]
    row /= row.sum()  # In-place! Modifies tensor[0]!
    return row

# ✅ Fix: Clone if you need independence
def good_normalize(tensor):
    row = tensor[0].clone()
    row /= row.sum()  # Only affects the copy
    return row

2. Shape Mismatch in Matrix Operations

🐍pitfall_shapes.py
# ❌ Bug: Forgetting to add a batch dimension
image = torch.randn(3, 224, 224)  # Single image
# output = model(image)  # Error! The model expects (N, C, H, W)

# ✅ Fix: Add a batch dimension
image = image.unsqueeze(0)        # (1, 3, 224, 224)
output = model(image)             # model defined elsewhere

# ❌ Bug: Flattening without matching the linear layer's input size
features = torch.randn(32, 7, 7, 256)  # (N, H, W, C)
linear = nn.Linear(256, 10)
# linear(features) applies to the last dim -> (32, 7, 7, 10),
# a per-position projection rather than one prediction per sample

# ✅ Fix: Flatten and size the layer to match
flat = features.flatten(start_dim=1)    # (N, H*W*C) = (32, 12544)
classifier = nn.Linear(7 * 7 * 256, 10)
output = classifier(flat)               # (32, 10)

3. Non-Contiguous Tensor Errors

🐍pitfall_contiguous.py
# ❌ Bug: view() on non-contiguous tensor
x = torch.randn(4, 4)
y = x.T  # Transpose creates a non-contiguous view
# z = y.view(16)  # RuntimeError!

# ✅ Fix options:
z = y.contiguous().view(16)  # Make contiguous first
z = y.reshape(16)            # reshape handles it
z = y.flatten()              # Same as reshape(-1)

# Check before view()
if y.is_contiguous():
    z = y.view(16)
else:
    z = y.reshape(16)

4. Broadcasting Surprises

🐍pitfall_broadcasting.py
# ❌ Bug: Unexpected broadcasting
a = torch.randn(3, 4)
b = torch.randn(4)  # 1D tensor

# This broadcasts b to each row - is that intended?
c = a + b  # Works: (3, 4) + (4,) -> (3, 4)

# But this might not be what you wanted:
b2 = torch.randn(3)  # Wrong dimension!
# c2 = a + b2  # Error! (3, 4) + (3,) doesn't broadcast

# ✅ Fix: Be explicit about dimensions
b2 = b2.unsqueeze(1)  # (3, 1)
c2 = a + b2           # Broadcasts per-row: (3, 4) + (3, 1) -> (3, 4)

5. Debugging Shape Issues

🐍debugging_shapes.py
def debug_shapes(tensor, name="tensor"):
    """Helper function to print tensor info"""
    print(f"{name}: shape={tensor.shape}, "
          f"dtype={tensor.dtype}, "
          f"device={tensor.device}, "
          f"contiguous={tensor.is_contiguous()}")

# Use throughout your code
x = torch.randn(32, 3, 224, 224)
debug_shapes(x, "input")

x = model.conv1(x)
debug_shapes(x, "after conv1")

# For complex pipelines, use forward hooks
def shape_hook(module, input, output):
    print(f"{module.__class__.__name__}: "
          f"{input[0].shape} -> {output.shape}")

for name, layer in model.named_modules():
    layer.register_forward_hook(shape_hook)

Knowledge Check

Test your understanding of tensor indexing, slicing, and reshaping with this comprehensive quiz.

Question 1 of 10: Given a tensor with shape (4, 5, 6), what does tensor[2, :, 3] return?

tensor = torch.randn(4, 5, 6)
result = tensor[2, :, 3]

Summary

This section covered the essential tensor manipulation operations that form the foundation of all deep learning code:

| Concept | Key Points | Common Use Cases |
|---|---|---|
| Basic indexing | Zero-based, reduces dimensionality | Extract rows, columns, elements |
| Slicing | start:stop:step, negative indices | Extract subregions, strided selection |
| Boolean indexing | Mask-based selection, returns copy | Filter by condition, thresholding |
| Fancy indexing | Arbitrary order, returns copy | Reorder elements, embeddings |
| Views | Share memory, no copy overhead | Slicing, transpose, reshape (if contiguous) |
| Copies | Independent memory | clone(), boolean/fancy indexing |
| reshape/view | Change dimensions, preserve elements | Flatten for FC, batch processing |
| transpose/permute | Reorder dimensions | NCHW↔NHWC, attention heads |
| squeeze/unsqueeze | Add/remove size-1 dims | Batch dimension, broadcasting |

Key Takeaways

  1. Views share memory: Basic slicing, transpose, and reshape (when contiguous) return views. Modifying a view modifies the original.
  2. Boolean and fancy indexing create copies: Always independent memory, safe to modify.
  3. view() requires contiguity: Use reshape() for safety, or call contiguous() first.
  4. -1 infers dimensions: In reshape, -1 means “calculate this dimension automatically.”
  5. Ellipsis (...) is your friend: Use it to index specific dimensions while keeping others.

Exercises

Conceptual Questions

  1. Explain why tensor.T.view(-1) might fail while tensor.T.reshape(-1) succeeds.
  2. Given a batch of images with shape (N, H, W, C) in NHWC format, write the permute operation to convert to NCHW format.
  3. Why does indexing with a boolean mask of the tensor's full shape always return a 1D tensor, regardless of the input's rank?

Coding Exercises

  1. Checkerboard Selection: Create an 8×8 tensor and use slicing to select only the “white squares” of a checkerboard pattern (every other element in a diagonal pattern).
  2. Batch Cropping: Given a batch of images torch.randn(16, 3, 224, 224), use slicing to extract the center 100×100 pixels from each image.
  3. Attention Reshape: Implement the head-splitting operation for multi-head attention: take a tensor of shape (batch, seq_len, d_model) and reshape it to (batch, num_heads, seq_len, head_dim) where d_model = num_heads × head_dim.
  4. Top-K Selection: Given a tensor of scores, use boolean indexing to select only the elements that are in the top 10% of values.

Challenge Exercise

Implement Efficient Batched Indexing: Write a function that takes a 2D tensor and a tensor of row indices, and efficiently extracts the corresponding rows using fancy indexing. Compare performance with a loop-based approach for large tensors.

🐍challenge.py
def batched_row_select(tensor, indices):
    """
    Args:
        tensor: shape (M, N)
        indices: shape (K,) with values in [0, M)
    Returns:
        shape (K, N) tensor with the selected rows
    """
    # Your implementation here
    pass

# Test
tensor = torch.randn(1000, 256)
indices = torch.randint(0, 1000, (100,))
result = batched_row_select(tensor, indices)
assert result.shape == (100, 256)

Performance Hint

For the challenge, compare tensor[indices] (fancy indexing) with torch.index_select(tensor, 0, indices). Benchmark both approaches with varying tensor and index sizes.
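A benchmark harness along the lines the hint suggests (a sketch; the sizes and iteration counts are arbitrary, and timings vary by machine):

```python
import time
import torch

tensor = torch.randn(10_000, 256)
indices = torch.randint(0, 10_000, (1_000,))

def bench(fn, n=100):
    """Time n calls to fn."""
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return time.perf_counter() - start

t_fancy = bench(lambda: tensor[indices])
t_select = bench(lambda: torch.index_select(tensor, 0, indices))
print(f"fancy indexing: {t_fancy:.4f}s, index_select: {t_select:.4f}s")

# Both approaches produce identical results
assert torch.equal(tensor[indices], torch.index_select(tensor, 0, indices))
```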

In the next section, we'll explore GPU computing in PyTorch, learning how to leverage CUDA for massive speedups in tensor operations.