Learning Objectives
By the end of this section, you will be able to:
- Understand tensors as the fundamental data structure underlying all deep learning computations
- Visualize tensor dimensions and understand the relationship between shape, rank, and total elements
- Create tensors using various PyTorch functions and understand when to use each
- Master memory layout including strides, contiguity, and the difference between views and copies
- Choose appropriate data types for different use cases, from training to deployment
- Manipulate tensor shapes using reshape, view, squeeze, unsqueeze, and transpose operations
Why This Matters: Tensors are to deep learning what atoms are to chemistry. Every neural network weight, every input image, every gradient—they are all tensors. Mastering tensor operations is the foundation upon which all deep learning is built.
What is a Tensor?
A tensor is a multi-dimensional array—a generalization of vectors and matrices to any number of dimensions. In deep learning, tensors are the universal data structure for representing everything from input data to model parameters to computed gradients.
The Mathematical Perspective
Mathematically, a tensor of rank $n$ is an element of the tensor product of $n$ vector spaces:

$$T \in V_1 \otimes V_2 \otimes \cdots \otimes V_n$$
But for practical deep learning, think of a tensor as simply a container for numbers organized in a specific shape, with optimized operations for numerical computing.
The Computational Perspective
In PyTorch, a tensor is:
- A contiguous block of memory storing numerical values
- A shape describing how to interpret that memory as a multi-dimensional array
- Metadata including data type, device (CPU/GPU), and gradient tracking information
- A set of operations for efficient mathematical computation
| Concept | NumPy Term | PyTorch Term | Example |
|---|---|---|---|
| Multi-dimensional array | ndarray | Tensor | Model weights, images, embeddings |
| Number of dimensions | ndim | dim() | RGB image has 3 (C, H, W) |
| Size of each dimension | shape | shape or size() | (3, 224, 224) for RGB image |
| Total number of elements | size | numel() | 3 × 224 × 224 = 150,528 |
How Computer Memory Works
Before diving deeper into tensors, let's understand a fundamental concept: how does a computer actually store data? This knowledge is crucial for understanding tensor operations, memory layout, and why certain operations are faster than others.
The Key Insight: Linear Memory
Computer memory (RAM) is essentially a long, linear sequence of bytes—like a very long row of numbered mailboxes. Each "mailbox" (memory address) can hold one byte of data. When we create arrays, matrices, or multi-dimensional tensors, the computer must store all that data in this one-dimensional linear structure.
The challenge is: how do we map multi-dimensional data (like a 2D image or 3D video) into this linear sequence? The answer is a systematic flattening process that preserves the ability to efficiently access any element.
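As a minimal sketch of that flattening process (the matrix values here are arbitrary), the row-major rule is: element [i, j] of a matrix with `cols` columns lives at flat offset `i * cols + j`.

```python
# Row-major flattening: map a 2D index to a linear memory offset.
rows, cols = 2, 3

def flat_index(i, j):
    # Skip i full rows, then j elements into the current row
    return i * cols + j

matrix = [[10, 11, 12],
          [13, 14, 15]]
flat = [x for row in matrix for x in row]  # [10, 11, 12, 13, 14, 15]
print(flat[flat_index(1, 2)])  # 15 == matrix[1][2]
```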
Why This Matters: Understanding memory layout helps you:
- Write more efficient code by accessing data in cache-friendly patterns
- Understand why some tensor operations are "free" (views) while others require copying data
- Debug issues with tensor strides and contiguity
- Optimize GPU memory usage in deep learning
Use the interactive visualizer below to explore how 1D strings, 2D matrices, and 3D tensors are stored in linear memory. Watch how the same data can be viewed in multiple dimensions while living in a single row of memory addresses.
Understanding Computer Memory
Computer memory (RAM) is a linear sequence of bytes. Each byte has a unique address (like a house number). When we store arrays or matrices, the data is placed contiguously (side by side) in memory.
Key Insight: Contiguous Storage
Each character is stored in adjacent memory locations. Address 00 → 01 → 02 → ... The string ends with a null character (\0) to mark its end. To access str[i], the computer simply goes to address: base + i
The Universal Pattern
No matter how many dimensions, data is always stored as a contiguous block in memory.
Quick Check
For a 2D matrix with 4 rows and 5 columns stored in row-major order, what is the flat memory index of element [2, 3]?
Tensor Dimensions and Shapes
Understanding tensor dimensions is crucial because operations in deep learning are defined in terms of tensor shapes. Let's explore the hierarchy from scalars to higher-dimensional tensors.
Scalar (0-dimensional tensor)
A scalar contains a single number. In PyTorch, it has shape () (an empty tuple) and dimension 0.
```python
import torch

# Creating a scalar
scalar = torch.tensor(42.0)
print(scalar.shape)   # torch.Size([])
print(scalar.dim())   # 0
print(scalar.item())  # 42.0 (extract Python number)
```

Scalars are commonly used for loss values, learning rates, and single predictions.
Vector (1-dimensional tensor)
A vector is a 1D array of numbers. Its shape is (n,), where n is the number of elements.
```python
vector = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
print(vector.shape)  # torch.Size([5])
print(vector.dim())  # 1
```

Vectors represent word embeddings, feature vectors, and biases in neural networks.
Matrix (2-dimensional tensor)
A matrix is a 2D array with shape (m, n) representing rows and columns.
```python
matrix = torch.tensor([[1, 2, 3],
                       [4, 5, 6]])
print(matrix.shape)  # torch.Size([2, 3])
print(matrix.dim())  # 2
```

Matrices represent weight matrices, batches of feature vectors, and grayscale images.
3D Tensor and Beyond
Higher-dimensional tensors extend this pattern. A 3D tensor with shape (d, m, n) can be thought of as stacked matrices.
| Rank | Name | Shape Example | Use Case |
|---|---|---|---|
| 0 | Scalar | () | Loss value, learning rate |
| 1 | Vector | (512,) | Word embedding, bias |
| 2 | Matrix | (64, 784) | Batch of flattened images |
| 3 | 3D Tensor | (32, 28, 28) | Batch of grayscale images |
| 4 | 4D Tensor | (32, 3, 224, 224) | Batch of RGB images (B, C, H, W) |
| 5 | 5D Tensor | (8, 16, 3, 224, 224) | Batch of video clips (B, T, C, H, W) |
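The ranks and shapes in this table can be checked directly; a quick sanity check using the example sizes above:

```python
import torch

loss = torch.tensor(0.25)              # rank 0: scalar
emb = torch.randn(512)                 # rank 1: word embedding
batch = torch.randn(64, 784)           # rank 2: batch of flattened images
images = torch.randn(32, 3, 224, 224)  # rank 4: batch of RGB images (B, C, H, W)

for t in (loss, emb, batch, images):
    print(t.dim(), tuple(t.shape))
```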
Convention Matters
Shape Cookbook
Here's a quick reference mapping common deep learning data types to their tensor shapes. Memorizing these conventions will help you debug shape mismatches and design architectures.
| Data Type | Shape | Dimension Names | Typical Values |
|---|---|---|---|
| Single Image (grayscale) | (H, W) | height, width | (28, 28) MNIST |
| Single Image (RGB) | (C, H, W) | channels, height, width | (3, 224, 224) ImageNet |
| Batch of Images | (N, C, H, W) | batch, channels, height, width | (32, 3, 224, 224) |
| Text Sequence (token IDs) | (T,) | sequence length | (512,) tokens |
| Batch of Sequences | (N, T) | batch, sequence | (32, 512) tokens |
| Word Embeddings | (T, D) | sequence, embedding dim | (512, 768) BERT |
| Batch of Embeddings | (N, T, D) | batch, sequence, embed | (32, 512, 768) |
| Attention Mask | (N, T) | batch, sequence | (32, 512) 0s and 1s |
| Attention Weights | (N, H, T, T) | batch, heads, query, key | (32, 12, 512, 512) |
| Video Clip | (N, T, C, H, W) | batch, frames, channels, h, w | (8, 16, 3, 224, 224) |
| Audio Waveform | (N, C, L) | batch, channels, length | (32, 1, 16000) 1s mono |
| Spectrogram | (N, C, F, T) | batch, channels, freq, time | (32, 1, 128, 256) |
| Point Cloud | (N, P, 3) | batch, points, xyz | (16, 1024, 3) |
| Graph Node Features | (N, F) | nodes, features | (1000, 128) |
Quick Shape Debug
Print tensor.shape at each step of your forward pass. The error usually tells you exactly which dimensions don't match.

Creating Tensors in PyTorch
PyTorch provides many ways to create tensors. Choosing the right method depends on your use case.
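A quick, non-exhaustive sampler of the most common creation functions:

```python
import numpy as np
import torch

zeros = torch.zeros(2, 3)        # all zeros, float32 by default
ones = torch.ones(2, 3)          # all ones
noise = torch.randn(2, 3)        # samples from a standard normal
seq = torch.arange(0, 10, 2)     # [0, 2, 4, 6, 8]
same = torch.zeros_like(noise)   # matches shape and dtype of noise
eye = torch.eye(3)               # 3x3 identity matrix

# From existing data
from_list = torch.tensor([[1, 2], [3, 4]])  # copies the data
arr = np.array([1.0, 2.0, 3.0])
shared = torch.from_numpy(arr)              # shares memory with arr (see the note below)
```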
Memory Sharing with NumPy
With torch.from_numpy(), the tensor and NumPy array share the same memory. Modifying one will modify the other! Use torch.tensor() if you need an independent copy.

Essential Tensor Attributes
Every PyTorch tensor carries important metadata that determines how it behaves.
```python
x = torch.randn(3, 4, dtype=torch.float32, device='cpu')

# Shape and size
print(x.shape)    # torch.Size([3, 4])
print(x.size())   # torch.Size([3, 4]) - same as shape
print(x.size(0))  # 3 - size of first dimension
print(x.size(-1)) # 4 - size of last dimension

# Number of dimensions and elements
print(x.dim())    # 2 (also called rank)
print(x.ndim)     # 2 (same as dim())
print(x.numel())  # 12 (total elements: 3 × 4)

# Data type
print(x.dtype)    # torch.float32

# Device (CPU or GPU)
print(x.device)   # cpu

# Memory layout
print(x.stride())         # (4, 1) - strides for each dimension
print(x.is_contiguous())  # True

# Gradient tracking
print(x.requires_grad)    # False (not tracking gradients)

# Underlying storage
print(x.storage())   # Raw memory storage
print(x.data_ptr())  # Memory address
```

| Attribute | Returns | Description |
|---|---|---|
| shape / size() | torch.Size | Dimensions of the tensor |
| dim() / ndim | int | Number of dimensions (rank) |
| numel() | int | Total number of elements |
| dtype | torch.dtype | Data type (float32, int64, etc.) |
| device | torch.device | Where tensor lives (cpu, cuda:0) |
| stride() | tuple | Memory jumps per dimension |
| requires_grad | bool | Whether gradients are tracked |
Data Types (dtypes)
Choosing the right data type affects memory usage, computation speed, and numerical precision. PyTorch supports many dtypes for different use cases.
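One way to see the memory cost directly is to query the bytes per element of each dtype:

```python
import torch

# Bytes per element for common dtypes
for dt in (torch.float64, torch.float32, torch.float16,
           torch.bfloat16, torch.int64, torch.int8):
    print(dt, torch.empty(0, dtype=dt).element_size(), "bytes")

# A (1000, 1000) float16 tensor uses half the memory of float32
half = torch.zeros(1000, 1000, dtype=torch.float16)
print(half.numel() * half.element_size())  # 2000000 bytes
```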
Tensor Data Types (dtypes)

torch.float32 — Use case: default for training; balances range and precision.

```python
# Create tensor with specific dtype
x = torch.tensor([1.0, 2.0, 3.0],
                 dtype=torch.float32)

# Convert existing tensor
y = x.to(torch.float32)

# Check dtype
print(x.dtype)  # torch.float32
```

As a rule of thumb: float32 for training, float16 or bfloat16 for mixed precision, and int8 for deployment quantization.

Type Conversion
```python
x = torch.tensor([1, 2, 3])  # Default: int64

# Convert to different types
x_float = x.float()    # torch.float32
x_half = x.half()      # torch.float16
x_double = x.double()  # torch.float64
x_int = x.int()        # torch.int32

# Using .to() for explicit conversion
x_bf16 = x.to(torch.bfloat16)
x_gpu = x.to('cuda')  # Also moves to GPU

# Check compatibility
x = torch.randn(3)              # float32
y = torch.randint(0, 10, (3,))  # int64
# z = x + y  # Works! y is automatically promoted to float32
```

Mixed Precision Training
Use torch.autocast for automatic mixed precision training. It intelligently uses float16 where safe and float32 where needed, often giving 2x speedup with minimal accuracy loss.

Named Tensors
Named tensors allow you to associate names with tensor dimensions, making code more readable and less error-prone. Instead of remembering that dimension 0 is batch and dimension 1 is channel, you can name them explicitly.
```python
import torch

# Create a named tensor
images = torch.randn(32, 3, 224, 224, names=('N', 'C', 'H', 'W'))
print(images.names)  # ('N', 'C', 'H', 'W')

# Operations preserve names
mean_per_channel = images.mean(dim='H').mean(dim='W')
print(mean_per_channel.names)  # ('N', 'C')

# Align tensors by name (broadcasting by name)
bias = torch.randn(3, names=('C',))
# This broadcasts correctly regardless of dimension order
result = images + bias.align_as(images)

# Rename dimensions
renamed = images.rename(N='batch', C='channels')
print(renamed.names)  # ('batch', 'channels', 'H', 'W')

# Remove names when needed
unnamed = images.rename(None)  # Back to regular tensor
```

Named Tensors Status
Named tensors are still a prototype feature, and many operations and third-party libraries don't support them. Call rename(None) when passing to libraries.

Sparse Tensors
Sparse tensors efficiently store tensors where most elements are zero. Instead of storing all values, they only store non-zero elements and their indices. This is crucial for graph neural networks, recommender systems, and NLP applications with sparse features.
```python
import torch

# COO format (Coordinate format) - most common
indices = torch.tensor([[0, 1, 2],   # row indices
                        [0, 2, 1]])  # col indices
values = torch.tensor([1.0, 2.0, 3.0])
sparse_coo = torch.sparse_coo_tensor(indices, values, size=(3, 3))

# Visualize: this represents
# [[1, 0, 0],
#  [0, 0, 2],
#  [0, 3, 0]]

# CSR format (Compressed Sparse Row) - efficient for row operations
sparse_csr = sparse_coo.to_sparse_csr()

# Convert between sparse and dense
dense = sparse_coo.to_dense()
back_to_sparse = dense.to_sparse()

# Sparse matrix multiplication
weight = torch.randn(3, 4)
result = torch.sparse.mm(sparse_coo, weight)  # Sparse @ Dense

# Create sparse tensor directly from adjacency list (GNN use case)
edge_index = torch.tensor([[0, 1, 1, 2],   # source nodes
                           [1, 0, 2, 1]])  # target nodes
edge_weight = torch.ones(4)
adj = torch.sparse_coo_tensor(edge_index, edge_weight, size=(3, 3))
```

| Format | Best For | Memory | Common Operations |
|---|---|---|---|
| COO | Construction, conversion | O(nnz) | Creation, to_dense() |
| CSR | Row slicing, matmul | O(nnz + rows) | mm(), mv() |
| CSC | Column slicing | O(nnz + cols) | Column access |
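A rough illustration of the memory savings (a sketch; the exact sparse overhead depends on the format and index dtype):

```python
import torch

n = 1000
dense = torch.zeros(n, n)
dense[0, :10] = 1.0         # only 10 non-zero elements
sparse = dense.to_sparse()  # COO format

dense_bytes = dense.numel() * dense.element_size()
# COO stores each non-zero value plus its (row, col) indices
sparse_bytes = (sparse.values().numel() * sparse.values().element_size() +
                sparse.indices().numel() * sparse.indices().element_size())
print(dense_bytes)   # 4000000 (one million float32 values)
print(sparse_bytes)  # 200 (10 float32 values + 20 int64 indices)
```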
When to Use Sparse Tensors
Memory Layout and Strides
Understanding how tensors are stored in memory is crucial for writing efficient code. PyTorch stores tensors in row-major (C-style) contiguous order by default.
Contiguous Memory
A tensor is contiguous if its elements are stored in a single, unbroken block of memory, arranged in row-major order. Consider the 2×3 matrix [[1, 2, 3], [4, 5, 6]].
In row-major order, memory contains: [1, 2, 3, 4, 5, 6]
Strides
Strides tell PyTorch how many elements to skip in memory to move one position along each dimension.
```python
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
print(x.stride())  # (3, 1)
# Stride (3, 1) means:
# - Move along rows (dim 0): skip 3 elements
# - Move along cols (dim 1): skip 1 element

# After transpose, strides change but data doesn't move!
y = x.T
print(y.shape)            # torch.Size([3, 2])
print(y.stride())         # (1, 3) - strides are swapped!
print(y.is_contiguous())  # False
```

Contiguity in Practice
Performance Corner
Memory layout directly impacts performance. Non-contiguous tensors require more cache misses and can't use optimized memory access patterns. Here's a micro-benchmark demonstrating the difference:
```python
import torch
import time

def benchmark(fn, name, warmup=10, runs=100):
    for _ in range(warmup):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(runs):
        fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / runs * 1000
    print(f"{name}: {elapsed:.3f} ms")

x = torch.randn(1000, 1000)

# Contiguous reshape
def contiguous_reshape():
    return x.reshape(100, 10000).sum()

# Non-contiguous: transpose then reshape (requires copy)
def noncontiguous_reshape():
    return x.T.reshape(100, 10000).sum()

# Fix: make contiguous first
def fixed_reshape():
    return x.T.contiguous().reshape(100, 10000).sum()

benchmark(contiguous_reshape, "Contiguous reshape")        # ~0.3 ms
benchmark(noncontiguous_reshape, "Non-contiguous reshape") # ~0.8 ms
benchmark(fixed_reshape, "Explicit contiguous()")          # ~0.5 ms
```

Inference Optimization
During inference, you don't need gradients. Disabling them reduces memory usage and speeds up computation:
```python
# Option 1: Context manager (recommended)
with torch.no_grad():
    output = model(input)

# Option 2: Inference mode (faster, stricter)
with torch.inference_mode():
    output = model(input)

# Option 3: Global setting (useful for evaluation loops)
torch.set_grad_enabled(False)
# ... evaluation code ...
torch.set_grad_enabled(True)  # Remember to re-enable!

# Option 4: Decorator for inference functions
@torch.inference_mode()
def predict(model, x):
    return model(x)
```

| Method | Gradient Tracking | In-place Safety | Speed |
|---|---|---|---|
| Default | Yes | Errors on conflict | Baseline |
| no_grad() | No | Allowed | Faster |
| inference_mode() | No | Stricter (no views) | Fastest |
When to Use What
Use inference_mode() for production inference (fastest). Use no_grad() when you need to create views of outputs or do mixed training/inference. Use set_grad_enabled(False) for evaluation loops where you toggle frequently.

More Performance Topics
Interactive: Memory Layout
Visualize how tensor elements are mapped to linear memory. Toggle between row-major and column-major layouts to understand how strides work.
Tensor Memory Layout
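The visualizer's row-major vs column-major toggle can also be reproduced in code by comparing strides; here column-major layout is simulated with a transpose round-trip:

```python
import torch

x = torch.arange(6).reshape(2, 3)  # row-major (the default)
print(x.stride())  # (3, 1): the next row is 3 elements away in memory

# Same logical values laid out column-major (simulated via transposes)
col = x.t().contiguous().t()
print(col.stride())         # (1, 2): the next row is 1 element away
print(torch.equal(x, col))  # True - same values, different memory layout
```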
Common Tensor Shapes in Deep Learning
| Data Type | Tensor Shape | Example |
|---|---|---|
| Grayscale Image | (H, W) | (28, 28) |
| RGB Image | (C, H, W) | (3, 224, 224) |
| Batch of Images | (N, C, H, W) | (32, 3, 224, 224) |
| Video Clip | (N, T, C, H, W) | (8, 16, 3, 224, 224) |
Quick Check
A tensor with shape (4, 3) in row-major order has what strides?
Shape Inspection
Before manipulating tensors, you need to inspect their shapes. PyTorch provides several methods to understand tensor dimensions.
```python
x = torch.randn(2, 3, 4)

# Shape and size (equivalent)
print(x.shape)    # torch.Size([2, 3, 4])
print(x.size())   # torch.Size([2, 3, 4])
print(x.size(0))  # 2 - size of first dimension
print(x.size(-1)) # 4 - size of last dimension

# Number of dimensions
print(x.dim())  # 3
print(x.ndim)   # 3 (same as dim())

# Total number of elements
print(x.numel())  # 24 (2 × 3 × 4)

# Quick reshape preview
print(x.reshape(6, 4).shape)  # torch.Size([6, 4])
print(x.reshape(-1).shape)    # torch.Size([24]) - flatten
```

Full Reshaping Coverage in Section 4.3
Reshaping operations (reshape, view, squeeze, unsqueeze, transpose, permute) are covered comprehensively in Section 4.3: Indexing, Slicing, and Reshaping.

Broadcasting Rules
Broadcasting allows operations between tensors of different shapes by automatically expanding dimensions. This is one of the most powerful features in PyTorch, enabling concise code without explicit loops.
The core rules are simple—PyTorch compares shapes from the rightmost dimension:
- Dimensions are compatible if they are equal, or one of them is 1
- Missing dimensions are treated as size 1
- The output shape is the element-wise maximum of input shapes
```python
# Simple broadcasting example
a = torch.randn(3, 4)  # Shape: (3, 4)
b = torch.randn(4)     # Shape: (4,) → broadcasts to (3, 4)
c = a + b              # Shape: (3, 4) - b added to each row

# Scalar broadcasts to any shape
x = torch.randn(2, 3, 4)
y = x * 2  # 2 broadcasts to (2, 3, 4)
```

Deep Dive in Next Section
Interactive: Tensor Playground
Experiment with tensor operations in real-time. Select an input shape, apply operations, and see how the tensor transforms. Switch between 2D grid view and immersive 3D visualization with mouse controls.
Interactive Tensor Operations
Example: reshape changes the shape while preserving data order in memory.

```python
x = torch.arange(1, 7).reshape(2, 3)
# x.shape = torch.Size([2, 3])
y = x.reshape(3, 2)
# y.shape = torch.Size([3, 2])
```
Views vs Copies
A critical concept in PyTorch: some operations return views (sharing memory with the original) while others return copies (independent data). This affects both correctness and performance.
```python
x = torch.tensor([1, 2, 3, 4])
y = x.view(2, 2)  # y is a VIEW of x - shares memory!

y[0, 0] = 99
print(x)  # tensor([99, 2, 3, 4]) - x changed too!

# Use clone() when you need independence
z = x.clone()
z[0] = 0
print(x)  # tensor([99, 2, 3, 4]) - x unchanged
```

Common Bug
Slicing and reshaping often return views that share memory with the original. Use .clone() when you need an independent copy.

Comprehensive Coverage in Section 4.3
Section 4.3 covers views vs copies in depth, including the detach() method.

Device Management
PyTorch tensors can live on different devices (CPU or GPU). Operations require tensors to be on the same device.
```python
# Check if GPU is available
print(torch.cuda.is_available())  # True if GPU available

# Create tensor on GPU
x_gpu = torch.randn(3, 4, device='cuda')

# Move tensor between devices
x_cpu = x_gpu.to('cpu')
x_back = x_cpu.to('cuda')

# Device-agnostic pattern (recommended)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
x = torch.randn(3, 4, device=device)
```

Comprehensive GPU Coverage in Section 4.4
Advanced Tips
Danger Zone: as_strided()
torch.as_strided() gives you direct control over tensor strides, enabling powerful operations like sliding windows. However, it's unsafe—you can easily read garbage memory or cause segfaults:
```python
# ⚠️ as_strided is powerful but dangerous
x = torch.arange(10)  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Create sliding window view (safe example)
# Window size 3, stride 1
windows = torch.as_strided(
    x,
    size=(8, 3),   # 8 windows of size 3
    stride=(1, 1)  # Move 1 element between windows
)
# Result: [[0,1,2], [1,2,3], [2,3,4], ...]

# ❌ DANGER: Wrong strides can read garbage memory!
# bad = torch.as_strided(x, size=(100, 100), stride=(1, 1))
# This reads way beyond x's actual memory!

# ✅ SAFER ALTERNATIVES:
# Use unfold() for sliding windows
windows = x.unfold(0, 3, 1)  # Same result, safer!

# Use tensor indexing for views
# Use reshape/view for dimension changes
```

as_strided Safety
Avoid as_strided() unless you truly understand memory layout. Prefer safer alternatives: unfold() for sliding windows, view() for reshaping, expand() for broadcasting.

Common Pitfalls
1. Modifying Views Unintentionally
```python
# ❌ Bug: Unintentionally modifying original
x = torch.randn(10)
y = x[:5]
y[0] = 999  # Oops! x[0] is now 999 too!

# ✅ Fix: Clone when you need independence
y = x[:5].clone()
y[0] = 999  # x is unchanged
```

2. Forgetting contiguous()
```python
# ❌ Bug: view() on non-contiguous tensor
x = torch.randn(3, 4).T  # Transpose makes it non-contiguous
y = x.view(-1)  # RuntimeError!

# ✅ Fix: Make contiguous first
y = x.contiguous().view(-1)
# Or use reshape, which handles this automatically
y = x.reshape(-1)
```

3. Device Mismatch
```python
# ❌ Bug: Operations between different devices
x_cpu = torch.randn(3)
y_gpu = torch.randn(3, device='cuda')
z = x_cpu + y_gpu  # RuntimeError!

# ✅ Fix: Move to same device first
z = x_cpu.to('cuda') + y_gpu
```

4. In-place Operations Breaking Autograd
```python
# ❌ Bug: In-place ops on leaf tensors with grad
x = torch.randn(3, requires_grad=True)
x.add_(1)  # RuntimeError!

# ✅ Fix: Use out-of-place operations
x = torch.randn(3, requires_grad=True)
y = x + 1  # Creates new tensor, preserves grad graph
```

Knowledge Check
Test your understanding of tensors with this quiz. Each question has a detailed explanation.
Tensor Quiz
What is the shape of a tensor created by torch.randn(3, 4, 5)?
Summary
Tensors are the fundamental data structure of deep learning. Here's what we covered:
| Concept | Key Point | PyTorch API |
|---|---|---|
| Tensor Basics | Multi-dimensional arrays with metadata | torch.tensor(), shape, dtype, device |
| Dimensions | Rank 0-4+ for scalar to batch data | dim(), size(), numel() |
| Data Types | float32 for training, int8 for deploy | .float(), .half(), .to(dtype) |
| Memory Layout | Row-major, strides describe access | stride(), is_contiguous() |
| Reshaping | Change shape without copying when possible | reshape(), view(), squeeze() |
| Views vs Copies | Views share memory, copies are independent | clone(), contiguous() |
| Devices | CPU vs GPU, must match for operations | .to(device), .cuda(), .cpu() |
Exercises
Conceptual Questions
- Explain the difference between a 1D tensor with shape (3,) and a 2D tensor with shape (3, 1). How do they differ in behavior for matrix multiplication?
- Why does transpose() not copy data? What changes instead?
- When would you use float16 instead of float32? What are the trade-offs?
Coding Exercises
- Tensor Manipulation: Create a 4D tensor of shape (2, 3, 4, 5). Reshape it to (6, 20), then back to (2, 3, 4, 5). Verify that the values are preserved.
- Broadcasting Practice: Given tensors of shapes (3, 1) and (1, 4), predict the result shape of their addition. Verify with PyTorch.
- Memory Investigation: Create a tensor, transpose it, and verify that is_contiguous() returns False. Then make it contiguous and check that the data is the same.
- View vs Clone: Write code that demonstrates the difference between view and clone by modifying one and checking if the other changes.
Challenge Exercise
Implement Einstein Summation: Using only basic tensor operations (no torch.einsum), implement batch matrix multiplication for tensors of shapes (B, M, K) and (B, K, N) to produce (B, M, N).
Solution Hints
- Use unsqueeze to add dimensions for broadcasting
- Element-wise multiply with broadcasting
- Sum over the contracted dimension K
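For reference, here is one way the hints combine into a working solution (try the exercise yourself before reading it):

```python
import torch

def batch_matmul(a, b):
    # a: (B, M, K), b: (B, K, N) -> (B, M, N)
    # Hints 1-2: unsqueeze so broadcasting forms all pairwise products
    # (B, M, K, 1) * (B, 1, K, N) -> (B, M, K, N)
    prod = a.unsqueeze(-1) * b.unsqueeze(1)
    # Hint 3: sum over the contracted dimension K
    return prod.sum(dim=2)

a = torch.randn(2, 3, 4)
b = torch.randn(2, 4, 5)
out = batch_matmul(a, b)
print(out.shape)  # torch.Size([2, 3, 5])
print(torch.allclose(out, torch.bmm(a, b), atol=1e-5))  # True
```

Note that this materializes the full (B, M, K, N) product tensor, so it uses far more memory than torch.bmm; it is an exercise in broadcasting, not a practical replacement.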
In the next section, we'll explore tensor operations in depth, including arithmetic, comparison, reduction, and linear algebra operations.