Learning Objectives
By the end of this section, you will be able to:
- Master tensor indexing using single indices, slices, and multi-dimensional access
- Apply advanced slicing with negative indices, steps, and ellipsis notation
- Use boolean masking to filter tensors based on conditions
- Perform fancy indexing to select arbitrary elements in any order
- Understand views vs copies and avoid common memory-related bugs
- Reshape tensors using view, reshape, squeeze, unsqueeze, permute, and transpose
- Apply these operations in real deep learning scenarios like batch processing and attention mechanisms
Why This Matters: Every neural network forward pass involves extensive tensor manipulation. Whether you're preparing data batches, implementing attention mechanisms, or building custom layers, fluent tensor indexing and reshaping skills are essential. These operations are the “verbs” of deep learning programming.
The Big Picture
Imagine you have a dataset of 1000 images, each with 3 color channels (RGB) and dimensions 224×224 pixels. In PyTorch, this data lives in a 4D tensor with shape (1000, 3, 224, 224). To work with this data, you need to:
- Extract a single image for visualization
- Select a batch of 32 images for training
- Get only the red channel across all images
- Crop the center 100×100 pixels from each image
- Reshape data for different layer types
These operations—indexing, slicing, and reshaping—are the fundamental tools that make this possible. They evolved from NumPy's powerful array operations, which in turn were inspired by MATLAB and APL. PyTorch tensors support nearly identical syntax, making it easy to transition from NumPy while gaining GPU acceleration.
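In PyTorch code, those five operations look like this (a sketch using 100 random tensors in place of the real images, to keep the example small; the offset 62 centers a 100×100 crop within 224 pixels):

```python
import torch

# Stand-in for the dataset: RGB images of 224x224 pixels
images = torch.randn(100, 3, 224, 224)

# Extract a single image for visualization
first_image = images[0]                   # (3, 224, 224)

# Select a batch of 32 images for training
batch = images[:32]                       # (32, 3, 224, 224)

# Get only the red channel across all images
red = images[:, 0]                        # (100, 224, 224)

# Crop the center 100x100 pixels from each image
center = images[:, :, 62:162, 62:162]     # (100, 3, 100, 100)

# Flatten each image into a vector for a fully connected layer
flat = images.reshape(100, -1)            # (100, 150528)

print(first_image.shape, batch.shape, red.shape, center.shape, flat.shape)
```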
A Historical Perspective
The concept of array slicing originated in APL (A Programming Language) in the 1960s, designed by Kenneth Iverson. The notation was later refined in languages like MATLAB and Fortran, eventually becoming the start:stop:step syntax we use today. NumPy standardized this for Python, and PyTorch adopted it for deep learning workloads.
Understanding the distinction between views (memory-sharing references) and copies (independent data) is critical for both correctness and performance. It becomes especially important when working with large tensors on GPUs, where unnecessary copies can exhaust memory.
Basic Indexing
Basic indexing in PyTorch follows Python's zero-based indexing convention. Each dimension can be accessed with an integer index, starting from 0 for the first element.
Single Element Access
To access a single element, provide an index for each dimension:
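A minimal example, with a shape chosen for illustration:

```python
import torch

# 2D tensor: 3 rows, 4 columns
matrix = torch.arange(12).reshape(3, 4)

# One index per dimension: row 1, column 2
element = matrix[1, 2]
print(element)         # tensor(6)
print(element.item())  # 6 - extract the plain Python number

# Equivalent chained form (less idiomatic)
print(matrix[1][2])    # tensor(6)
```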
Selecting Rows and Columns
Using a single index reduces the tensor's dimensionality by one:
matrix = torch.arange(12).reshape(3, 4)
print(matrix)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# Select entire row (reduces to 1D)
row_1 = matrix[1]
print(row_1.shape)  # torch.Size([4])
print(row_1)        # tensor([4, 5, 6, 7])

# Select entire column (reduces to 1D)
col_2 = matrix[:, 2]  # : means "all rows"
print(col_2.shape)  # torch.Size([3])
print(col_2)        # tensor([ 2,  6, 10])

# Keep dimension with slicing
row_1_2d = matrix[1:2]  # Slice instead of index
print(row_1_2d.shape)  # torch.Size([1, 4]) - 2D!

Index vs Slice for Dimension
Indexing with [1] reduces dimensionality; slicing with [1:2] keeps the dimension. This distinction is important when operations require specific tensor ranks.

Multi-Dimensional Indexing
For tensors with 3 or more dimensions, extend the same pattern. Each comma-separated index addresses a different dimension:
# 3D tensor: (depth, rows, cols)
tensor = torch.arange(24).reshape(2, 3, 4)
print(tensor.shape)  # torch.Size([2, 3, 4])

# Select first "slice" (depth=0)
slice_0 = tensor[0]
print(slice_0.shape)  # torch.Size([3, 4])

# Select specific element
element = tensor[1, 2, 3]  # depth=1, row=2, col=3
print(element)  # tensor(23)

# Mixed: one slice, specific row, all columns
subset = tensor[0, 1, :]
print(subset.shape)  # torch.Size([4])
print(subset)  # tensor([4, 5, 6, 7])

Quick Check
Given tensor = torch.arange(60).reshape(3, 4, 5), what is the shape of tensor[1]?
Interactive: Explore Indexing
Use the interactive visualizer below to build intuition for tensor indexing. Adjust the indices and observe which elements are selected in real time.

Example: Select tensor[1, 2]

import torch

# Create a 5x6 tensor
tensor = torch.arange(30).reshape(5, 6)

# Access element at row 1, column 2
element = tensor[1, 2]
print(element)  # tensor(8)

# Alternative syntax
element = tensor[1][2]
print(element)  # tensor(8)
Slicing Operations
Slicing extracts a contiguous sub-tensor using the start:stop:step syntax. This powerful notation lets you select ranges, reverse order, and skip elements.
The Slice Syntax
| Component | Description | Default | Example |
|---|---|---|---|
| start | Beginning index (inclusive) | 0 | 2: starts at index 2 |
| stop | Ending index (exclusive) | end | :5 ends before index 5 |
| step | Increment between elements | 1 | ::2 every other element |
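Applied to a 1D tensor, each component of the syntax behaves as the table describes:

```python
import torch

vec = torch.arange(10)  # tensor([0, 1, 2, ..., 9])

print(vec[2:])     # start=2:         tensor([2, 3, 4, 5, 6, 7, 8, 9])
print(vec[:5])     # stop=5:          tensor([0, 1, 2, 3, 4])
print(vec[::2])    # step=2:          tensor([0, 2, 4, 6, 8])
print(vec[1:8:3])  # all three:       tensor([1, 4, 7])
print(vec[-3:])    # negative start:  tensor([7, 8, 9])
```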
Multi-Dimensional Slicing
Apply slicing to multiple dimensions simultaneously:
matrix = torch.arange(20).reshape(4, 5)
print(matrix)
# tensor([[ 0,  1,  2,  3,  4],
#         [ 5,  6,  7,  8,  9],
#         [10, 11, 12, 13, 14],
#         [15, 16, 17, 18, 19]])

# Rows 1-2, columns 2-4
submatrix = matrix[1:3, 2:5]
print(submatrix)
# tensor([[ 7,  8,  9],
#         [12, 13, 14]])

# Every other row, every other column
checkerboard = matrix[::2, ::2]
print(checkerboard)
# tensor([[ 0,  2,  4],
#         [10, 12, 14]])

# Last 2 rows, all columns reversed
# Note: unlike NumPy, PyTorch does not support negative slice steps
# such as ::-1, so use torch.flip instead
flipped = torch.flip(matrix[-2:], dims=[1])
print(flipped)
# tensor([[14, 13, 12, 11, 10],
#         [19, 18, 17, 16, 15]])

The Ellipsis (...)
The ellipsis notation ... is shorthand for “as many full slices as needed”:
# 4D tensor: (batch, channels, height, width)
images = torch.randn(32, 3, 224, 224)

# These are equivalent for getting the last dimension
last_col = images[:, :, :, -1]
last_col = images[..., -1]  # Cleaner!
print(last_col.shape)  # torch.Size([32, 3, 224])

# First batch item, all other dims
first = images[0, ...]
print(first.shape)  # torch.Size([3, 224, 224])

# First and last channel only
subset = images[:, [0, 2], ...]
print(subset.shape)  # torch.Size([32, 2, 224, 224])

When to Use Ellipsis
Use ... when you want to index specific dimensions while keeping the rest unchanged, especially in generic functions that handle tensors of varying rank.

Interactive: Advanced Slicing
Experiment with different slicing patterns, boolean masks, and fancy indexing in the interactive demo below.

Example: All rows, column 2

import torch

tensor = torch.arange(30).reshape(5, 6)

# All rows, column 2
result = tensor[:, 2]
print(result)        # tensor([ 2,  8, 14, 20, 26])
print(result.shape)  # torch.Size([5])

Slice Syntax: start:stop:step
- start - beginning index (default: 0)
- stop - ending index, exclusive (default: end)
- step - increment (default: 1; PyTorch does not support negative steps, so use torch.flip to reverse)
- Omitting a value uses the default
Boolean (Mask) Indexing
Boolean indexing uses a tensor of True/False values to select elements. This is extremely powerful for filtering data based on conditions.
Creating and Applying Masks
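A sketch of the typical pattern, with made-up score values: comparison operators build a boolean mask, and indexing with it filters the tensor:

```python
import torch

scores = torch.tensor([0.9, -0.3, 2.5, 1.1, -1.7])

# Comparison operators produce a boolean tensor of the same shape
mask = scores > 0
print(mask)  # tensor([ True, False,  True,  True, False])

# Indexing with the mask keeps only the True positions (returns a copy)
positive = scores[mask]
print(positive)  # tensor([0.9000, 2.5000, 1.1000])

# Masks combine with & (and), | (or), ~ (not)
in_range = scores[(scores > 0) & (scores < 2)]
print(in_range)  # tensor([0.9000, 1.1000])

# Count matching elements
print(mask.sum().item())  # 3
```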
Boolean Indexing Creates Copies
The result of tensor[mask] is always a new tensor with its own memory; modifying it never affects the original.
Multi-Dimensional Boolean Indexing
matrix = torch.randn(4, 5)

# Mask on entire tensor (flattens result)
mask = matrix > 0
positive = matrix[mask]
print(positive.shape)  # 1D: number of positive elements

# Mask on specific dimension
# Select rows where first column > 0
row_mask = matrix[:, 0] > 0
selected_rows = matrix[row_mask]
print(selected_rows.shape)  # (num_matching_rows, 5)

# Modify elements in-place using mask
matrix[matrix < -1] = -1  # Clamp minimum to -1
matrix[matrix > 1] = 1    # Clamp maximum to 1

# torch.where for conditional assignment
result = torch.where(matrix > 0, matrix, torch.zeros_like(matrix))
# Like ReLU: keep positives, zero out negatives

Real-World Example: GPT-2's Attention Mask
In transformer models like GPT-2, boolean masking prevents the model from attending to future tokens during training:
# Causal mask for attention
seq_len = 5
# Create lower triangular mask (True = can attend)
causal_mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
print(causal_mask)
# tensor([[ True, False, False, False, False],
#         [ True,  True, False, False, False],
#         [ True,  True,  True, False, False],
#         [ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])

# In attention: mask out future positions
attention_scores = torch.randn(seq_len, seq_len)
attention_scores = attention_scores.masked_fill(~causal_mask, float('-inf'))
# After softmax, -inf becomes 0 probability

Fancy Indexing
Fancy indexing uses lists or tensors of indices to select elements in any order, potentially with repetition.
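The basic pattern, with arbitrary example values:

```python
import torch

x = torch.tensor([10, 20, 30, 40, 50])

# Select elements in any order, with repetition
idx = torch.tensor([3, 0, 3, 1])
print(x[idx])  # tensor([40, 10, 40, 20])

# Works on rows of a 2D tensor too
m = torch.arange(12).reshape(3, 4)
print(m[[2, 0]])  # rows 2 and 0
# tensor([[ 8,  9, 10, 11],
#         [ 0,  1,  2,  3]])

# Paired index lists pick individual elements: (0, 1) and (2, 3)
print(m[[0, 2], [1, 3]])  # tensor([ 1, 11])
```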
Advanced Fancy Indexing Patterns
# Gather operation (common in embeddings)
embeddings = torch.randn(1000, 256)  # 1000 tokens, 256-dim
token_ids = torch.tensor([42, 7, 256, 999])
selected = embeddings[token_ids]
print(selected.shape)  # torch.Size([4, 256])

# torch.gather for more control
# Useful for selecting along a specific dimension
values = torch.tensor([[1, 2], [3, 4], [5, 6]])
indices = torch.tensor([[0, 0], [1, 0], [0, 1]])
gathered = torch.gather(values, dim=1, index=indices)
print(gathered)  # tensor([[1, 1], [4, 3], [5, 6]])

# torch.take (flat indexing)
flat_indices = torch.tensor([0, 3, 5])
taken = torch.take(values, flat_indices)
print(taken)  # tensor([1, 4, 6]) - flat indexing into all elements

Fancy Indexing Creates Copies
Like boolean indexing, fancy indexing always returns a copy: the result owns its memory, so modifying it leaves the source tensor untouched.
Views vs Copies: Critical Concepts
Understanding when operations create views (sharing memory) vs copies (independent memory) is crucial for correctness and performance. This concept is fundamental to efficient tensor programming.
What is a View?
A view is a tensor that shares its underlying data storage with another tensor. Modifying one affects the other. Views are memory-efficient because no data is copied.
| Operation | Returns View? | Notes |
|---|---|---|
| Basic slicing tensor[1:5] | ✅ View | Contiguous subset |
| view(), reshape() (if contiguous) | ✅ View | Just changes shape metadata |
| transpose(), permute(), T | ✅ View | Changes strides only |
| squeeze(), unsqueeze() | ✅ View | Adds/removes dim of size 1 |
| expand() | ✅ View | Broadcast to larger size |
| Boolean indexing tensor[mask] | ❌ Copy | Non-contiguous selection |
| Fancy indexing tensor[[0,2,1]] | ❌ Copy | Arbitrary element selection |
| clone() | ❌ Copy | Explicit copy |
| contiguous() | Maybe Copy | Copy if not already contiguous |
| reshape() (if non-contiguous) | ❌ Copy | Must copy to make contiguous |
Detecting Views
x = torch.arange(12).reshape(3, 4)

# Create a view via slicing
y = x[1:3, 1:3]

# Check if they share storage
print(x.data_ptr() == y.data_ptr())  # False - y starts at an offset
print(x.storage().data_ptr() == y.storage().data_ptr())  # True

# Views share the underlying storage,
# but data_ptr differs when the view starts at an offset

# Modify the view - affects original!
y[0, 0] = 999
print(x[1, 1])  # tensor(999) - original changed!

# Create a copy to avoid this
z = x[1:3, 1:3].clone()
z[0, 0] = -1
print(x[1, 1])  # tensor(999) - unchanged by modifying the clone

The Contiguity Problem
A tensor is contiguous when its elements are stored in memory in row-major (C-style) order without gaps. Some operations break contiguity:
x = torch.arange(12).reshape(3, 4)
print(x.is_contiguous())  # True

# Transpose breaks contiguity
y = x.T
print(y.is_contiguous())  # False
print(y.stride())  # (1, 4) instead of (3, 1)

# view() fails on non-contiguous tensors
try:
    z = y.view(12)  # RuntimeError!
except RuntimeError as e:
    print("Error:", e)

# Solutions:
z = y.contiguous().view(12)  # Make contiguous, then view
z = y.reshape(12)  # reshape handles it automatically

# Check if operations preserve contiguity
a = x[:, ::2]  # Strided access
print(a.is_contiguous())  # False (gaps in memory)

Performance Impact
Many kernels run faster on contiguous memory, and some operations trigger hidden copies on non-contiguous inputs. If you will perform many operations on a non-contiguous tensor, call .contiguous() once upfront.

Quick Check
After x = torch.randn(4, 4); y = x.T[1:3, 1:3].clone(), does modifying y affect x?
Reshape Operations
Reshaping changes how tensor elements are organized into dimensions without changing the total number of elements. These operations are essential for preparing data for different layer types.
Core Reshape Functions
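As a quick reference, a sketch of these functions in action (shapes chosen for illustration):

```python
import torch

x = torch.arange(24)

# view/reshape: reorganize into any compatible shape
a = x.view(2, 3, 4)    # (2, 3, 4)
b = x.reshape(4, 6)    # (4, 6)
c = x.reshape(2, -1)   # (2, 12); -1 infers the dimension

# unsqueeze: insert a size-1 dimension (e.g., a batch dim)
img = torch.randn(3, 32, 32)
batched = img.unsqueeze(0)        # (1, 3, 32, 32)

# squeeze: drop size-1 dimensions
print(batched.squeeze(0).shape)   # torch.Size([3, 32, 32])

# flatten: collapse a range of dims into one
feats = torch.randn(8, 16, 4, 4)
print(feats.flatten(start_dim=1).shape)  # torch.Size([8, 256])
```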
Transpose and Permute
For reordering dimensions (not just reshaping):
# 2D: transpose swaps dimensions
matrix = torch.randn(3, 4)
transposed = matrix.T  # or matrix.transpose(0, 1)
print(transposed.shape)  # torch.Size([4, 3])

# 3D+: transpose swaps exactly two dimensions
tensor = torch.randn(2, 3, 4)
t = tensor.transpose(0, 2)  # swap dim 0 and dim 2
print(t.shape)  # torch.Size([4, 3, 2])

# permute: reorder ALL dimensions at once
# Specify new order as arguments
p = tensor.permute(2, 0, 1)  # (dim2, dim0, dim1)
print(p.shape)  # torch.Size([4, 2, 3])

# Common use: NCHW to NHWC conversion
images = torch.randn(32, 3, 224, 224)  # PyTorch format
nhwc = images.permute(0, 2, 3, 1)  # TensorFlow format
print(nhwc.shape)  # torch.Size([32, 224, 224, 3])

Transpose Returns a View
Both transpose() and permute() return views that share memory with the original tensor; no data moves until you call .contiguous().
Interactive: Reshape Visualizer
Explore how different reshape operations transform tensor shapes and understand when views vs copies are created.

Example: Flatten collapses all dimensions into one, turning shape (3, 4) into (12,)

import torch

# Input: shape (3, 4)
tensor = torch.arange(12).reshape(3, 4)

# Operation: tensor.flatten()
result = tensor.flatten()
print(result.shape)    # torch.Size([12])
print(result.numel())  # 12 (same total elements)
Real-World Applications
These tensor operations appear throughout deep learning. Here are some concrete examples from popular architectures:
1. CNN Feature Extraction (ResNet)
# ResNet outputs feature maps of shape (batch, channels, H, W)
# Before the fully connected layer, we need to pool or flatten

import torch
import torch.nn as nn
import torchvision

class ResNetClassifier(nn.Module):
    def __init__(self, num_classes=1000):
        super().__init__()
        # newer torchvision versions use weights= instead of pretrained=
        self.resnet = torchvision.models.resnet50(pretrained=True)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        # x: (batch, 3, 224, 224)
        # Run the backbone up to layer4 (layer4 alone cannot take raw images)
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)
        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        features = self.resnet.layer4(x)  # (batch, 2048, 7, 7)

        # Global average pooling
        pooled = features.mean(dim=[2, 3])  # (batch, 2048)

        # Alternative: flatten
        # flat = features.flatten(start_dim=1)  # (batch, 2048*7*7)

        return self.fc(pooled)

2. Multi-Head Attention (Transformer)
One of the most important reshape patterns in modern AI, used in GPT, BERT, and Vision Transformers:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size, seq_len, _ = x.shape

        # Project queries, keys, values
        Q = self.W_q(x)  # (batch, seq, d_model)
        K = self.W_k(x)
        V = self.W_v(x)

        # Split into heads: reshape + transpose
        # (batch, seq, d_model) -> (batch, seq, heads, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Transpose for attention: (batch, heads, seq, head_dim)
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)

        # Compute attention (scaled dot-product)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)  # (batch, heads, seq, head_dim)

        # Merge heads back: transpose + reshape
        context = context.transpose(1, 2).contiguous()  # (batch, seq, heads, head_dim)
        context = context.view(batch_size, seq_len, self.d_model)  # (batch, seq, d_model)

        return self.W_o(context)

3. Batch Normalization Statistics
# Computing batch statistics across spatial dimensions
features = torch.randn(32, 64, 28, 28)  # (N, C, H, W)

# Mean/var per channel across batch and spatial dims
# Need to normalize over N, H, W (dims 0, 2, 3)
mean = features.mean(dim=(0, 2, 3), keepdim=True)  # (1, C, 1, 1)
var = features.var(dim=(0, 2, 3), keepdim=True)

# keepdim=True maintains shape for broadcasting
normalized = (features - mean) / torch.sqrt(var + 1e-5)

# Alternative: reshape for easier computation
# (N, C, H, W) -> (N, C, H*W) -> compute over (N, H*W)
reshaped = features.flatten(start_dim=2)  # (32, 64, 784)
mean_alt = reshaped.mean(dim=(0, 2))  # (64,)

4. Image Data Augmentation
# Random crop and flip using indexing
def random_crop_flip(images, crop_size=200):
    """
    images: (N, C, H, W) tensor
    Returns cropped and randomly flipped images
    """
    N, C, H, W = images.shape

    # Random crop coordinates
    top = torch.randint(0, H - crop_size, (N,))
    left = torch.randint(0, W - crop_size, (N,))

    # Use indexing for each image (could vectorize with gather)
    crops = []
    for i in range(N):
        crop = images[i, :, top[i]:top[i]+crop_size, left[i]:left[i]+crop_size]

        # Random horizontal flip
        # (PyTorch has no negative slice steps, so use torch.flip)
        if torch.rand(1) > 0.5:
            crop = torch.flip(crop, dims=[2])  # Flip width dimension

        crops.append(crop)

    return torch.stack(crops)  # (N, C, crop_size, crop_size)

Common Pitfalls and Debugging
1. Accidental In-Place Modification via Views
# ❌ Bug: Modifying a view affects the original
def bad_normalize(tensor):
    row = tensor[0]
    row /= row.sum()  # In-place! Modifies tensor[0]!
    return row

# ✅ Fix: Clone if you need independence
def good_normalize(tensor):
    row = tensor[0].clone()
    row /= row.sum()  # Only affects the copy
    return row

2. Shape Mismatch in Matrix Operations
# ❌ Bug: Forgetting to add batch dimension
image = torch.randn(3, 224, 224)  # Single image
output = model(image)  # Error! Model expects (N, C, H, W)

# ✅ Fix: Add batch dimension
image = image.unsqueeze(0)  # (1, 3, 224, 224)
output = model(image)

# ❌ Bug: Layer size doesn't match the flattened features
features = torch.randn(32, 7, 7, 256)  # (N, H, W, C)
linear = nn.Linear(256, 10)
# linear(features) applies over the last dim -> (32, 7, 7, 10),
# which silently gives per-position outputs, not per-image logits

# ✅ Fix: Flatten and size the layer to match
features = features.permute(0, 3, 1, 2)  # (N, C, H, W)
features = features.flatten(1)  # (N, C*H*W) = (32, 12544)
linear = nn.Linear(256 * 7 * 7, 10)  # in_features must match 12544
output = linear(features)  # (32, 10)

3. Non-Contiguous Tensor Errors
# ❌ Bug: view() on non-contiguous tensor
x = torch.randn(4, 4)
y = x.T  # Transpose creates non-contiguous view
z = y.view(16)  # RuntimeError!

# ✅ Fix options:
z = y.contiguous().view(16)  # Make contiguous first
z = y.reshape(16)  # reshape handles it
z = y.flatten()  # Same as reshape(-1)

# Check before view()
if y.is_contiguous():
    z = y.view(16)
else:
    z = y.reshape(16)

4. Broadcasting Surprises
# ❌ Bug: Unexpected broadcasting
a = torch.randn(3, 4)
b = torch.randn(4)  # 1D tensor

# This broadcasts b to each row - is this intended?
c = a + b  # Works: (3, 4) + (4,) -> (3, 4)

# But this might not be what you wanted:
b2 = torch.randn(3)  # Wrong dimension!
# c2 = a + b2  # Error! (3, 4) + (3,) doesn't broadcast

# ✅ Fix: Be explicit about dimensions
b2 = b2.unsqueeze(1)  # (3, 1)
c2 = a + b2  # Now broadcasts correctly

5. Debugging Shape Issues
def debug_shapes(tensor, name="tensor"):
    """Helper function to print tensor info"""
    print(f"{name}: shape={tensor.shape}, "
          f"dtype={tensor.dtype}, "
          f"device={tensor.device}, "
          f"contiguous={tensor.is_contiguous()}")

# Use throughout your code
x = torch.randn(32, 3, 224, 224)
debug_shapes(x, "input")

x = model.conv1(x)
debug_shapes(x, "after conv1")

# For complex pipelines, use hooks
def shape_hook(module, input, output):
    print(f"{module.__class__.__name__}: "
          f"{input[0].shape} -> {output.shape}")

for name, layer in model.named_modules():
    layer.register_forward_hook(shape_hook)

Knowledge Check
Test your understanding of tensor indexing, slicing, and reshaping with this comprehensive quiz.
Given a tensor with shape (4, 5, 6), what does tensor[2, :, 3] return?
tensor = torch.randn(4, 5, 6)
result = tensor[2, :, 3]
Summary
This section covered the essential tensor manipulation operations that form the foundation of all deep learning code:
| Concept | Key Points | Common Use Cases |
|---|---|---|
| Basic Indexing | Zero-based, reduces dimensionality | Extract rows, columns, elements |
| Slicing | start:stop:step, negative indices | Extract subregions, reverse order |
| Boolean Indexing | Mask-based selection, returns copy | Filter by condition, thresholding |
| Fancy Indexing | Arbitrary order, returns copy | Reorder elements, embeddings |
| Views | Share memory, no copy overhead | Slicing, transpose, reshape (if contiguous) |
| Copies | Independent memory | clone(), boolean/fancy indexing |
| reshape/view | Change dimensions, preserve elements | Flatten for FC, batch processing |
| transpose/permute | Reorder dimensions | NCHW↔NHWC, attention heads |
| squeeze/unsqueeze | Add/remove size-1 dims | Batch dimension, broadcasting |
Key Takeaways
- Views share memory: Basic slicing, transpose, and reshape (when contiguous) return views. Modifying a view modifies the original.
- Boolean and fancy indexing create copies: Always independent memory, safe to modify.
- view() requires contiguity: Use reshape() for safety, or call contiguous() first.
- -1 infers dimensions: In reshape, -1 means “calculate this dimension automatically.”
- Ellipsis (...) is your friend: Use it to index specific dimensions while keeping others.
Exercises
Conceptual Questions
- Explain why tensor.T.view(-1) might fail while tensor.T.reshape(-1) succeeds.
- Given a batch of images with shape (N, H, W, C) in NHWC format, write the permute operation to convert to NCHW format.
- Why does boolean indexing with a full-shape mask always return a 1D tensor, regardless of the input shape?
Coding Exercises
- Checkerboard Selection: Create an 8×8 tensor and use slicing to select only the "white squares" of a checkerboard pattern (every other element in a diagonal pattern).
- Batch Cropping: Given a batch of images torch.randn(16, 3, 224, 224), use slicing to extract the center 100×100 pixels from each image.
- Attention Reshape: Implement the head-splitting operation for multi-head attention: take a tensor of shape (batch, seq_len, d_model) and reshape it to (batch, num_heads, seq_len, head_dim), where d_model = num_heads × head_dim.
- Top-K Selection: Given a tensor of scores, use boolean indexing to select only the elements that are in the top 10% of values.
Challenge Exercise
Implement Efficient Batched Indexing: Write a function that takes a 2D tensor and a tensor of row indices, and efficiently extracts the corresponding rows using fancy indexing. Compare performance with a loop-based approach for large tensors.
def batched_row_select(tensor, indices):
    """
    Args:
        tensor: shape (M, N)
        indices: shape (K,) with values in [0, M)
    Returns:
        shape (K, N) tensor with selected rows
    """
    # Your implementation here
    pass

# Test
tensor = torch.randn(1000, 256)
indices = torch.randint(0, 1000, (100,))
result = batched_row_select(tensor, indices)
assert result.shape == (100, 256)

Performance Hint
Compare tensor[indices] (fancy indexing) with torch.index_select(tensor, 0, indices). Benchmark both approaches with varying tensor and index sizes.

In the next section, we'll explore GPU computing in PyTorch, learning how to leverage CUDA for massive speedups in tensor operations.