Learning Objectives
By the end of this section, you will be able to:
- Explain why we split training data into mini-batches instead of using the full dataset at once
- Derive the key mathematical property: mini-batch gradients are unbiased estimates of the true gradient with variance $\sigma^2/B$
- Implement a complete mini-batch training loop from scratch in NumPy
- Use PyTorch's Dataset and DataLoader to handle batching, shuffling, and parallel loading
- Choose appropriate batch sizes for different hardware and training scenarios
Where We Left Off
In Chapter 10, we built multi-layer perceptrons (MLPs) that can learn complex functions by stacking layers. In Chapters 7-9, we learned how data flows forward through a network, how gradients flow backward via backpropagation, and how optimizers like SGD and Adam use those gradients to update weights.
But there is a question we have been quietly glossing over: how do we actually feed data to the network? In all our examples so far, we computed gradients on the entire dataset at once. For 8 or 16 samples, that is fine. But real datasets have thousands to billions of samples. Computing the gradient on the entire dataset before making a single weight update is both computationally wasteful and, surprisingly, mathematically suboptimal.
This section introduces the concept of mini-batch training — the practical engine that powers every modern neural network, from small classifiers to GPT-4.
The Training Data Challenge
Consider a concrete scenario. You are training a model on ImageNet, which has 1.2 million images. Each image is 224×224×3 = 150,528 floats. At 4 bytes per float, the raw input data alone is 720 GB. This does not fit in GPU memory (typically 16-80 GB), and computing the gradient over all 1.2 million samples before making a single weight update would be absurdly slow.
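Even if memory were unlimited, there is a deeper mathematical reason not to use the full dataset for every gradient computation. Recall from Chapter 8 that the gradient of the loss over the entire dataset is:

$$\nabla_\theta L(\theta) = \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta \ell_i(\theta)$$

where $\ell_i(\theta)$ is the loss on sample $i$ and $N$ is the number of samples.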
This is an average of per-sample gradients. And here is the key insight: you do not need to compute every single term in this sum to get a useful estimate of which direction to move. A random subset of samples gives you a noisy but unbiased approximation of the true gradient — and that approximation is often good enough.
The Core Insight: If you are standing on a mountain in fog and ask 100 people which way is downhill, you get a reliable answer. You do not need to ask all 10,000 people on the mountain. Asking a small random group is faster and points you in roughly the same direction.
Three Flavors of Gradient Descent
Based on how many samples we use to compute each gradient update, there are three regimes.
Full-Batch Gradient Descent
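Use the entire dataset for every gradient computation (with learning rate $\eta$):

$$\theta \leftarrow \theta - \eta \, \frac{1}{N} \sum_{i=1}^{N} \nabla_\theta \ell_i(\theta)$$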
- Advantage: The gradient is exact — no noise, smooth convergence
- Disadvantage: Must process all N samples before one update. For N = 1,000,000 this is extremely slow
- Disadvantage: Requires all data in memory at once
- Disadvantage: Smooth gradients can get trapped in sharp local minima
Stochastic Gradient Descent (B = 1)
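Use a single random sample for each gradient:

$$\theta \leftarrow \theta - \eta \, \nabla_\theta \ell_i(\theta), \qquad i \text{ drawn uniformly at random from } \{1, \dots, N\}$$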
- Advantage: Updates after every sample — N updates per epoch instead of 1
- Advantage: The noise can help escape sharp local minima (acts like a regularizer)
- Disadvantage: Very noisy gradients — the path toward the optimum is erratic
- Disadvantage: Cannot exploit GPU parallelism (processing one sample at a time wastes GPU compute)
Mini-Batch Gradient Descent: The Sweet Spot
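Use a random subset of samples:

$$\theta \leftarrow \theta - \eta \, \frac{1}{B} \sum_{i \in \mathcal{B}} \nabla_\theta \ell_i(\theta)$$

where $\mathcal{B}$ is a random mini-batch of size $B$. This is what everyone actually uses. When people say “SGD” in deep learning, they almost always mean mini-batch SGD.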
| Property | Full Batch (B=N) | Mini-Batch (B=32-512) | SGD (B=1) |
|---|---|---|---|
| Updates per epoch | 1 | N/B (many) | N (most) |
| Gradient noise | Zero | Moderate | High |
| GPU utilization | High | High | Low (waste) |
| Memory per step | All data | B samples | 1 sample |
| Convergence path | Smooth | Slightly noisy | Very noisy |
| Generalization | May overfit | Good (noise helps) | Good but slow |
The Mathematics of Mini-Batch Gradients
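Let us formalize why mini-batch training works. The true gradient over the full dataset is:

$$g = \nabla_\theta L(\theta) = \frac{1}{N} \sum_{i=1}^{N} g_i$$

where $g_i = \nabla_\theta \ell_i(\theta)$ is the per-sample gradient for sample $i$. When we draw a random mini-batch $\mathcal{B}$ of size $B$, the mini-batch gradient is:

$$\hat{g}_{\mathcal{B}} = \frac{1}{B} \sum_{i \in \mathcal{B}} g_i$$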
Property 1: Unbiased Estimation
The mini-batch gradient is an unbiased estimator of the true gradient. That is, its expected value equals the true gradient:

$$\mathbb{E}_{\mathcal{B}}\left[\hat{g}_{\mathcal{B}}\right] = g$$
This follows directly from the linearity of expectation. Each sample is drawn uniformly at random from the dataset, so $\mathbb{E}[g_i] = \frac{1}{N}\sum_{j=1}^{N} g_j = g$ for any randomly chosen sample $i$. Averaging $B$ of them still has the same expected value.
In plain English: on average, the mini-batch gradient points in the same direction as the full gradient. Any single mini-batch may be off, but there is no systematic bias.
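A quick numerical check of Property 1, sketched in plain Python. It uses the first gradient component ($dw_1$) of the 8 samples in the running example from the next subsection, enumerates every possible batch of size B drawn without repeated samples (what a shuffled epoch produces), and averages their gradients. The helper name `expected_minibatch_grad` is hypothetical.

```python
from itertools import combinations
from statistics import mean

# First gradient component (dw_1) of each of the 8 samples in the running example.
per_sample = [-5.0, -16.0, 0.0, -10.5, -33.0, -1.5, -20.0, -3.0]
true_grad = mean(per_sample)   # full-batch gradient: -11.125

def expected_minibatch_grad(values, batch_size):
    """Exact E[g_B] when every size-B batch (no repeated samples) is equally likely."""
    return mean(mean(batch) for batch in combinations(values, batch_size))

for B in (1, 2, 4, 8):
    print(B, expected_minibatch_grad(per_sample, B), true_grad)
```

Every batch size gives the same expected gradient; only the spread around it changes, which is the subject of Property 2.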
Property 2: Variance Reduction
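The variance of the mini-batch gradient estimator is:

$$\mathrm{Var}\left[\hat{g}_{\mathcal{B}}\right] = \frac{\sigma^2}{B}$$

where $\sigma^2$ is the variance of individual per-sample gradients (computed per component). This is a basic property of averages rather than the central limit theorem: averaging $B$ independent, identically distributed random variables divides the variance by $B$. Strictly, it assumes the $B$ samples are drawn independently (with replacement). Drawing $B$ distinct samples, as a shuffled epoch does, multiplies the variance by $(N - B)/(N - 1)$, a factor close to 1 when $B \ll N$.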
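Let us verify this with our running example. At the initial weights, each sample produces a different gradient. Here the error is prediction minus target, and with the squared-error loss each per-sample gradient is $2 \cdot \text{error} \cdot x$ for the weights and $2 \cdot \text{error}$ for the bias: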
| Sample | Error | Per-sample gradient (dw) | db |
|---|---|---|---|
| 0: x=[1.0, 0.5] | -2.50 | [-5.000, -2.500] | -5.000 |
| 1: x=[2.0, 1.0] | -4.00 | [-16.000, -8.000] | -8.000 |
| 2: x=[0.5, 2.0] | 0.00 | [0.000, 0.000] | 0.000 |
| 3: x=[1.5, 0.5] | -3.50 | [-10.500, -3.500] | -7.000 |
| 4: x=[3.0, 1.5] | -5.50 | [-33.000, -16.500] | -11.000 |
| 5: x=[0.5, 0.5] | -1.50 | [-1.500, -1.500] | -3.000 |
| 6: x=[2.5, 2.0] | -4.00 | [-20.000, -16.000] | -8.000 |
| 7: x=[1.0, 1.5] | -1.50 | [-3.000, -4.500] | -3.000 |
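The true gradient (average of all 8) is $dw = [-11.125, -6.5625]$, $db = -5.625$. The per-sample variance in the first component ($dw_1$) is $\sigma^2 \approx 112.67$. For different batch sizes, assuming the samples in a batch are drawn independently (with replacement):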
| Batch Size B | Variance of dw₁ | Std Dev of dw₁ | Relative to B=1 |
|---|---|---|---|
| 1 (SGD) | 112.67 | ±10.61 | 1.00× |
| 2 | 56.34 | ±7.51 | 0.50× |
| 4 | 28.17 | ±5.31 | 0.25× |
| 8 (= N, with replacement) | 14.08 | ±3.75 | 0.125× |
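Doubling the batch size halves the variance. But here is the crucial trade-off: doubling the batch size also halves the number of updates per epoch. With $B = 1$ you get 8 noisy updates per epoch. With $B = N = 8$ you get 1 exact update (a true full batch contains each of the 8 samples exactly once, so its variance is zero; the last table row assumes sampling with replacement). In practice, the many noisy updates often lead to faster convergence than one exact update.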
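The variance table can be checked exactly by enumerating batches. This plain-Python sketch (helper name `minibatch_variance` is hypothetical) computes the variance of the batch-mean $dw_1$ over all equally likely batches, once with independent draws (repeats allowed) and once with distinct samples, and compares them with $\sigma^2/B$ and with $\sigma^2/B \cdot (N-B)/(N-1)$.

```python
from itertools import combinations, product
from statistics import mean, pvariance

per_sample = [-5.0, -16.0, 0.0, -10.5, -33.0, -1.5, -20.0, -3.0]  # dw_1 column
N = len(per_sample)
sigma2 = pvariance(per_sample)   # per-sample variance: 112.671875

def minibatch_variance(values, batch_size, replace):
    """Exact variance of the batch-mean gradient over all equally likely batches."""
    if replace:
        batches = product(values, repeat=batch_size)   # independent draws, repeats allowed
    else:
        batches = combinations(values, batch_size)     # distinct samples, like a shuffled epoch
    return pvariance([mean(batch) for batch in batches])

for B in (1, 2, 4, 8):
    fpc = (N - B) / (N - 1)                            # finite-population correction
    without = minibatch_variance(per_sample, B, replace=False)
    print(f"B={B}: sigma^2/B = {sigma2 / B:6.2f}   without replacement = {without:6.2f}"
          f"   (sigma^2/B * fpc = {sigma2 / B * fpc:6.2f})")
print("with replacement, B=4:", minibatch_variance(per_sample, 4, replace=True))
```

With replacement the variance is exactly $\sigma^2/B$; without replacement it shrinks faster and reaches zero at $B = N$.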
The Noise-Progress Trade-off
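Think of the trade-off this way. At each step, your gradient estimate has two components:

$$\hat{g}_{\mathcal{B}} = \underbrace{g}_{\text{signal}} + \underbrace{\left(\hat{g}_{\mathcal{B}} - g\right)}_{\text{noise}}$$

The signal is the true gradient, so stepping against it always heads downhill on the full-data loss. The noise has zero mean and shrinks as $B$ grows, with standard deviation $\sigma/\sqrt{B}$. With a high enough signal-to-noise ratio ($\|g\| / (\sigma/\sqrt{B})$), each step makes progress. The question is: do you take many small noisy steps or few large clean steps? The answer depends on the loss landscape.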
Visualizing Batch Size Effects
Consider gradient descent paths on a 2D loss surface shaped like an elongated valley, with every run starting from the same point and differing only in batch size. The batch size controls the amount of noise in the gradient, and several effects stand out:
- B=1 takes a wild, erratic path but may reach the minimum faster, because it gets 8 updates per “epoch” while full-batch gets only 1
- B=N follows the smoothest, most direct path but takes fewer steps per epoch
- The elongated valley reveals an important effect: in the narrow, steep direction, even small noise causes overshooting. Larger batches help stability along such sensitive dimensions
- Increasing the learning rate amplifies the noise effect; a large learning rate combined with a small batch can diverge
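The following NumPy sketch reproduces this kind of experiment so you can plot the paths yourself. Everything in it is an assumption chosen for illustration: 64 noise-free synthetic samples, the second feature scaled 3× so the loss curves about 9× more steeply along $w_2$ (the elongated valley), a starting point of (−3, 3), and a learning rate of 0.05. The helper names `make_valley_data` and `sgd_path` are hypothetical.

```python
import numpy as np

def make_valley_data(n=64, seed=0):
    """Noise-free least-squares data whose loss surface is an elongated valley."""
    rng = np.random.default_rng(seed)
    X = rng.uniform(-1.0, 1.0, size=(n, 2)) * np.array([1.0, 3.0])  # steep along w_2
    w_true = np.array([1.0, -2.0])
    return X, X @ w_true, w_true

def sgd_path(X, y, batch_size, lr=0.05, epochs=3, seed=0):
    """Run mini-batch SGD on the mean squared error and record every iterate."""
    rng = np.random.default_rng(seed)
    w = np.array([-3.0, 3.0])                        # common starting point for all runs
    path = [w.copy()]
    for _ in range(epochs):
        order = rng.permutation(len(X))              # reshuffle every epoch
        for start in range(0, len(X), batch_size):
            idx = order[start:start + batch_size]
            err = X[idx] @ w - y[idx]
            grad = 2.0 * X[idx].T @ err / len(idx)   # gradient of the batch MSE
            w = w - lr * grad
            path.append(w.copy())
    return np.array(path)

X, y, w_true = make_valley_data()
for B in (1, 8, 64):
    path = sgd_path(X, y, B)
    dist = np.linalg.norm(path[-1] - w_true)
    print(f"B={B:2d}: {len(path) - 1:3d} updates, final distance to optimum {dist:.4f}")
```

Plotting `path[:, 0]` against `path[:, 1]` for each batch size shows how erratic each trajectory is. With these features, a learning rate of 0.05 is small enough that no single step can move the weights farther from the optimum; raising it is how you can explore the overshooting and divergence described above.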
Building Mini-Batch Training from Scratch
Before using PyTorch's DataLoader, let us build the entire mini-batch pipeline from scratch in NumPy. This makes the mechanism completely transparent: you will see exactly how data is shuffled, split into batches, and used to compute gradients.
We reuse the 8 data points of the running example, generated by a linear function, and start from the same initial weights and bias as in the per-sample gradient table. The code performs one complete epoch with batch size 2, giving us 4 mini-batch gradient updates.
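A minimal NumPy sketch of that epoch follows. It assumes initial weights w = [0, 0] and b = 0; together with the per-sample errors in the table above, that implies targets y = 2x₁ − x₂ + 1. The learning rate (0.05) and shuffle seed are also assumptions, so the batch losses it prints will generally differ from the values quoted below, although the initial full-dataset loss (10.66) and the per-sample gradients match the table. The function names are hypothetical.

```python
import numpy as np

# Running-example features. Assumption: w = [0, 0], b = 0 initially, so the
# targets are the negated table errors, i.e. y = 2*x1 - x2 + 1.
X = np.array([[1.0, 0.5], [2.0, 1.0], [0.5, 2.0], [1.5, 0.5],
              [3.0, 1.5], [0.5, 0.5], [2.5, 2.0], [1.0, 1.5]])
y = 2.0 * X[:, 0] - X[:, 1] + 1.0

def mse_and_grads(Xb, yb, w, b):
    """Mean squared error on a batch and its gradients with respect to w and b."""
    err = Xb @ w + b - yb                 # prediction minus target
    loss = np.mean(err ** 2)
    dw = 2.0 * Xb.T @ err / len(yb)       # average of per-sample 2 * err * x
    db = 2.0 * np.mean(err)
    return loss, dw, db

def train_one_epoch(X, y, w, b, batch_size=2, lr=0.05, seed=0):
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(X))     # 1. shuffle once per epoch
    for start in range(0, len(X), batch_size):
        batch = indices[start:start + batch_size]                # 2. slice a mini-batch
        loss, dw, db = mse_and_grads(X[batch], y[batch], w, b)   # 3. batch gradient
        w, b = w - lr * dw, b - lr * db                          # 4. update immediately
        print(f"batch {start // batch_size}: batch loss {loss:.2f}")
    return w, b

w0, b0 = np.zeros(2), 0.0
print("full loss before:", round(mse_and_grads(X, y, w0, b0)[0], 2))   # 10.66
w1, b1 = train_one_epoch(X, y, w0, b0)
print("full loss after: ", round(mse_and_grads(X, y, w1, b1)[0], 2))
```

Changing `lr` or `seed` changes the sequence of batch losses; the structure (shuffle, slice, gradient, update) stays the same for any batch size.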
Look at how the loss fluctuates across batches: 9.13 → 1.49 → 5.39 → 1.00. This is completely normal. Each batch loss is measured on a different random subset of the data, so it differs from the overall loss. The spike at batch 2 (counting from 0) does not mean training went wrong: that particular pair of samples, with features [0.5, 2.0] and [3.0, 1.5], happened to be fit poorly by the weights at that point in the epoch.
Key Observation: After just one epoch of mini-batch training (4 gradient updates), the full-dataset loss dropped from 10.66 to 1.00, a 90.6% reduction. Full-batch gradient descent with the same learning rate would have made only 1 update from the same 8 per-sample gradient computations, achieving much less progress.
This is the fundamental advantage of mini-batch training: more frequent updates lead to faster learning, even though each individual update is noisier.
PyTorch Dataset and DataLoader
PyTorch provides two classes that handle all the batching machinery we just built by hand:
- Dataset: defines how to access individual samples. You implement two methods: `__len__()` returns the total number of samples, and `__getitem__(idx)` returns one (features, target) pair
- DataLoader: wraps a Dataset and handles batching, shuffling, and parallel loading. You iterate over it and get ready-made batches
This is a clean separation of concerns. The Dataset knows about your data format (images? text? tabular?). The DataLoader knows about training logistics (batch size, shuffling, workers). You can swap either independently.
The Dataset Contract
Every PyTorch Dataset must implement two methods:
| Method | What It Returns | When PyTorch Calls It |
|---|---|---|
| __len__(self) | int — total number of samples | Once, to determine how many batches per epoch |
| __getitem__(self, idx) | (features, target) tuple | B times per batch — once per sample |
The DataLoader calls __getitem__ for each index in the current batch, then stacks the results into tensors. If each __getitem__ returns a tuple of (tensor(2,), tensor()), the batch becomes (tensor(B, 2), tensor(B,)).
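A minimal Dataset that satisfies this contract, sketched for the 8-sample running example. The class name `LinearDataset` is hypothetical, and the targets assume y = 2x₁ − x₂ + 1 (an illustrative choice). The printed shapes show the stacking described above.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class LinearDataset(Dataset):
    """Hypothetical map-style dataset holding (features, target) pairs in memory."""

    def __init__(self, X, y):
        self.X = torch.as_tensor(X, dtype=torch.float32)
        self.y = torch.as_tensor(y, dtype=torch.float32)

    def __len__(self):
        return len(self.X)                 # total number of samples N

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]    # shapes (2,) and ()

X = [[1.0, 0.5], [2.0, 1.0], [0.5, 2.0], [1.5, 0.5],
     [3.0, 1.5], [0.5, 0.5], [2.5, 2.0], [1.0, 1.5]]
y = [2.5, 4.0, 0.0, 3.5, 5.5, 1.5, 4.0, 1.5]   # assumed targets, y = 2*x1 - x2 + 1
dataset = LinearDataset(X, y)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

features, target = dataset[0]
X_batch, y_batch = next(iter(loader))
print(features.shape, target.shape)    # torch.Size([2]) torch.Size([])
print(X_batch.shape, y_batch.shape)    # torch.Size([4, 2]) torch.Size([4])
```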
DataLoader: The Batch Machine
When you write for X_batch, y_batch in loader, the DataLoader does the following each iteration:
- Generate indices: Pick B indices from [0, N-1]. If `shuffle=True`, the order is randomized at the start of each epoch
- Fetch samples: Call `dataset[idx]` for each index. With `num_workers > 0`, this happens in parallel worker processes
- Collate: Stack individual sample tensors into batch tensors. The default collate function handles most cases
- Transfer: With `pin_memory=True`, data is placed in pinned (page-locked) memory for faster GPU transfer
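To make the first three steps concrete, here is a simplified, single-process sketch of one epoch of that loop in plain Python. It is not PyTorch's implementation: `simple_loader` is a hypothetical helper, it collates into Python lists rather than stacked tensors, and it has no worker processes or pinned-memory transfer.

```python
import random

def simple_loader(dataset, batch_size, shuffle=True, seed=None):
    """Hypothetical single-process sketch of one epoch of DataLoader iteration.

    dataset: anything with len() and integer indexing that returns
    (features, target) pairs. Yields (features_list, targets_list) per batch.
    """
    indices = list(range(len(dataset)))          # 1. indices 0 .. N-1
    if shuffle:
        random.Random(seed).shuffle(indices)     #    fresh order for this epoch
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]
        samples = [dataset[i] for i in batch_idx]   # 2. fetch one sample at a time
        features, targets = zip(*samples)           # 3. collate by field
        yield list(features), list(targets)

data = [((float(i), 2.0 * i), float(i)) for i in range(10)]   # toy (features, target) pairs
for epoch in range(2):
    # A different seed per epoch stands in for reshuffling at each epoch start.
    print(epoch, [targets for _, targets in simple_loader(data, batch_size=4, seed=epoch)])
print("no shuffle:", [t for _, t in simple_loader(data, batch_size=4, shuffle=False)])
```

Without shuffling, every epoch yields the same three batches; with a new order per epoch, batch compositions change while every sample is still visited exactly once.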
Why Shuffling Matters
Without shuffling, the network sees the same batch compositions every epoch. If samples 0-3 happen to be easy and samples 4-7 are hard, the first batches always underestimate the true gradient while the last batches always overestimate it. This creates a systematic oscillation in the weight trajectory.
Shuffling each epoch ensures that batch compositions are random, so each mini-batch is a uniformly random subset of the data and its gradient is an unbiased estimate of the full gradient. This is a practical requirement for the theoretical guarantees (unbiasedness, variance reduction) we derived earlier.
Set `shuffle=True` for training loaders. For validation and test loaders, set `shuffle=False` so results are reproducible.

Here is the complete PyTorch implementation, the same problem as the NumPy version but using Dataset and DataLoader:
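A sketch of such an implementation, under the same assumptions as a from-scratch version would need: targets y = 2x₁ − x₂ + 1, initial weights and bias of zero, learning rate 0.05, batch size 2, and one epoch. The class name `LinearDataset` is hypothetical, and the printed losses depend on the random shuffle.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class LinearDataset(Dataset):
    """Hypothetical in-memory dataset of (features, target) pairs."""
    def __init__(self, X, y):
        self.X, self.y = X, y
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

torch.manual_seed(0)
X = torch.tensor([[1.0, 0.5], [2.0, 1.0], [0.5, 2.0], [1.5, 0.5],
                  [3.0, 1.5], [0.5, 0.5], [2.5, 2.0], [1.0, 1.5]])
y = 2.0 * X[:, 0] - X[:, 1] + 1.0            # assumed target function
loader = DataLoader(LinearDataset(X, y), batch_size=2, shuffle=True)

w = torch.zeros(2, requires_grad=True)       # assumed initial weights
b = torch.zeros(1, requires_grad=True)
lr = 0.05                                    # assumed learning rate

for epoch in range(1):
    for step, (X_batch, y_batch) in enumerate(loader):
        loss = ((X_batch @ w + b - y_batch) ** 2).mean()   # batch MSE
        loss.backward()                                     # autograd fills .grad
        with torch.no_grad():                               # manual SGD update
            w -= lr * w.grad
            b -= lr * b.grad
            w.grad.zero_()                                  # reset before the next batch
            b.grad.zero_()
        print(f"epoch {epoch} batch {step}: loss {loss.item():.2f}")

with torch.no_grad():
    full_loss = ((X @ w + b - y) ** 2).mean().item()
print(f"full-dataset loss after one epoch: {full_loss:.2f}")
```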
Compare the two implementations. The PyTorch version replaces our manual index slicing with DataLoader iteration, our manual gradient formulas with loss.backward(), and our manual weight updates with direct tensor operations. The core training rhythm is identical: iterate batches, compute loss, compute gradients, update weights.
Batch Size in Practice
Memory Constraints
The maximum batch size is often dictated by GPU memory. Each sample in a batch requires memory for:
- Activations: The intermediate values at every layer, stored for backpropagation. For a ResNet-50 processing 224×224 images, each sample needs ~100 MB of activation memory
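- Gradients: The backward pass produces gradients with respect to the activations, so this memory also grows with B (they are freed layer by layer, so not all are held at once). Gradients with respect to the parameters are a fixed cost, the same size as the parameters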
- Model parameters: Fixed cost, independent of batch size (ResNet-50: ~100 MB)
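A rough formula: $B_{\max} \approx (\text{GPU memory} - \text{fixed model memory}) / \text{activation memory per sample}$. On a 16 GB GPU with a model using 2 GB for parameters, you have ~14 GB for activations. At 100 MB per sample, the maximum batch size is ~140. In practice, you need headroom, so a batch size comfortably below that maximum is typical.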
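The same back-of-the-envelope estimate as a small function: GPU memory minus the fixed model cost (and any headroom you want to keep free), divided by the per-sample memory. The name `max_batch_size` and the `headroom_gb` parameter are hypothetical, all inputs are estimates you supply, and 1 GB is treated as 1000 MB.

```python
def max_batch_size(gpu_mem_gb, fixed_gb, per_sample_mb, headroom_gb=0):
    """Rough upper bound: (GPU memory - fixed cost - headroom) / per-sample memory."""
    free_mb = (gpu_mem_gb - fixed_gb - headroom_gb) * 1000
    return max(0, free_mb // per_sample_mb)

print(max_batch_size(16, 2, 100))                  # 140, the example in the text
print(max_batch_size(16, 2, 100, headroom_gb=2))   # 120 with 2 GB kept free
```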
Generalization Effects
Research by Keskar et al. (2017) and Hoffer et al. (2017) showed that large batch sizes can hurt generalization. The intuition: large batches produce smoother gradients and tend to converge to sharp minima, while small batches inject noise that helps find flatter, more generalizable minima.
| Batch Size | Training Loss | Test Accuracy | Character |
|---|---|---|---|
| B = 32-64 | Slightly higher | Best | Noisy, explores broadly |
| B = 256-512 | Lower | Good | Moderate noise |
| B = 4096+ | Lowest | Often worse | Too smooth, sharp minima |
The linear scaling rule (Goyal et al., 2017) provides a practical remedy: when you increase the batch size by a factor of $k$, multiply the learning rate by the same factor $k$. The rationale is that one step on a $k$-times larger batch then roughly matches $k$ consecutive steps on the smaller batch. Combined with learning rate warmup (gradually increasing lr over the first few epochs), this let Goyal et al. train ResNet-50 on ImageNet with batches of 8,192 images without losing accuracy.
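Both ingredients fit in a few lines. This is a sketch with hypothetical helper names: `scaled_lr` applies the linear scaling rule, and `warmup_lr` ramps linearly from a starting rate to the scaled target and then holds it (real schedules usually add decay afterwards). The numbers mirror the ImageNet setting in Goyal et al. (base rate 0.1 at batch size 256, scaled to 8192, warmup starting from the base rate), with `warmup_steps=5` standing in for their five warmup epochs.

```python
def scaled_lr(base_lr, base_batch, batch_size):
    """Linear scaling rule: multiply the learning rate by k = batch_size / base_batch."""
    return base_lr * batch_size / base_batch

def warmup_lr(step, warmup_steps, target_lr, start_lr=0.0):
    """Linearly ramp the learning rate from start_lr to target_lr, then hold it."""
    if step >= warmup_steps:
        return target_lr
    return start_lr + (target_lr - start_lr) * step / warmup_steps

target = scaled_lr(0.1, 256, 8192)            # k = 32, so 0.1 becomes 3.2
print(target)
print([round(warmup_lr(s, 5, target, start_lr=0.1), 2) for s in range(7)])
```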
When to Use Gradient Accumulation
If your desired batch size exceeds GPU memory, you can simulate larger batches by accumulating gradients across multiple forward-backward passes before updating:
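- Run forward + backward for a micro-batch of size $b$, dividing its loss by $K$ so the accumulated gradient is an average rather than a sum
- Do NOT zero gradients between micro-batches (they accumulate via PyTorch's default behavior)
- Repeat for $K$ micro-batches
- Now update weights, then zero the gradients. Effective batch size: $B = K \cdot b$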
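Gradient accumulation, together with the data parallelism described below, is how large language models reach effective batch sizes of millions of tokens (the largest GPT-3 model used about 3.2 million tokens per batch) even though a single GPU can hold only a small fraction of such a batch at a time.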
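A minimal PyTorch sketch of the accumulation loop, assuming a toy `torch.nn.Linear` model, random data, and K = 4 micro-batches of b = 8 samples (so B = 32); `accumulated_step` is a hypothetical helper. The sanity check confirms that the accumulated gradient matches the gradient of one real batch of 32, which is exactly why each micro-batch loss is divided by K.

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(2, 1)
loss_fn = torch.nn.MSELoss()                       # averages over its inputs
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
X, y = torch.randn(32, 2), torch.randn(32, 1)      # one "large" batch of 32
micro_batch, accum_steps = 8, 4                    # effective batch = 8 * 4 = 32

def accumulated_step(X, y):
    """One optimizer update built from accum_steps micro-batches."""
    optimizer.zero_grad()
    for k in range(accum_steps):
        sl = slice(k * micro_batch, (k + 1) * micro_batch)
        loss = loss_fn(model(X[sl]), y[sl]) / accum_steps   # 1/K scaling gives an average
        loss.backward()                    # gradients add up in .grad between calls
    grads = [p.grad.clone() for p in model.parameters()]
    optimizer.step()                       # update only after all K micro-batches
    return grads

# Sanity check: the gradient of a single full batch of 32 at the same weights.
optimizer.zero_grad()
loss_fn(model(X), y).backward()
full_grads = [p.grad.clone() for p in model.parameters()]
acc_grads = accumulated_step(X, y)
print(all(torch.allclose(a, f, atol=1e-6) for a, f in zip(acc_grads, full_grads)))
```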
Connection to Modern Systems
The batching concepts in this section scale directly to the largest models in production:
Data Parallelism
In distributed training, the batch is split across multiple GPUs. Each GPU processes a micro-batch, computes local gradients, then all GPUs synchronize gradients via AllReduce. If you have 8 GPUs each processing 32 samples, the effective batch size is 8 × 32 = 256. PyTorch's DistributedDataParallel (DDP) handles this automatically: you wrap the model, and DDP registers hooks that all-reduce the gradients during every backward pass, overlapping communication with computation.
The mathematical guarantee: AllReduce computes the average gradient across all GPUs. If each of the $G$ GPUs processes an independent mini-batch of size $B$, the averaged gradient has variance $\sigma^2/(G B)$. This is exactly equivalent to a single GPU processing a batch of size $G \cdot B$.
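The equivalence can be simulated in a single process with NumPy. This sketch assumes a linear model with a mean-squared-error loss, random data, and 8 equal shards of 32 samples standing in for 8 GPUs; no real devices or collective communication are involved, and `mse_grad` is a hypothetical helper. Equal shard sizes matter: averaging shard means equals the global mean only when every shard has the same number of samples.

```python
import numpy as np

rng = np.random.default_rng(0)
num_gpus, per_gpu_batch = 8, 32
X = rng.normal(size=(num_gpus * per_gpu_batch, 3))   # 256 samples, 3 features
y = rng.normal(size=num_gpus * per_gpu_batch)
w = rng.normal(size=3)

def mse_grad(Xb, yb, w):
    """Gradient of the mean squared error of a linear model on one batch."""
    return 2.0 * Xb.T @ (Xb @ w - yb) / len(yb)

# Each simulated "GPU" computes a local gradient on its own shard ...
local_grads = [mse_grad(Xs, ys, w)
               for Xs, ys in zip(np.split(X, num_gpus), np.split(y, num_gpus))]
# ... and an AllReduce-style average combines them.
allreduce_avg = np.mean(local_grads, axis=0)

single_device = mse_grad(X, y, w)     # one device processing all 256 samples
print(np.allclose(allreduce_avg, single_device))
```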
Flash Attention and Batching
In transformer models, batching has an additional dimension: sequence length $n$. The attention mechanism computes $\mathrm{softmax}(QK^\top/\sqrt{d_k})\,V$ for each sample in the batch. The score matrix $QK^\top$ has shape $B \times h \times n \times n$, where $h$ is the number of attention heads. For a batch of 32 sequences and 32 heads, this matrix alone takes $32 \cdot 32 \cdot n^2 \cdot 2$ bytes in float16, which for long sequences exceeds the memory of most GPUs.
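The quadratic growth is easy to tabulate. This helper (the name `attention_scores_bytes` is hypothetical) multiplies out the $B \times h \times n \times n$ shape at 2 bytes per float16 element; the sequence lengths in the loop are illustrative choices.

```python
def attention_scores_bytes(batch, heads, seq_len, bytes_per_elem=2):
    """Bytes for one batch x heads x seq_len x seq_len score tensor (float16 = 2 bytes)."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

for n in (1024, 4096, 16384):
    size = attention_scores_bytes(batch=32, heads=32, seq_len=n)
    print(f"n = {n:>5}: {size / 2**30:,.0f} GiB")   # 2, 32 and 512 GiB
```

Quadrupling the sequence length multiplies the memory by 16.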
Flash Attention (Dao et al., 2022) solves this by tiling: instead of materializing the full attention matrix, it processes small tiles (typically 128×128) in SRAM (fast on-chip cache), computes partial softmax results, and accumulates the output tile by tile. The outer loop iterates over blocks of keys/values; the inner loop iterates over blocks of queries. Each tile computes a local block of scores $Q_i K_j^\top$ and uses the online softmax trick (tracking a running max and sum) to merge tile results into the exact global softmax.
The result: extra memory drops from $O(n^2)$ to $O(n)$, and speed improves 2–4× because far fewer reads and writes go to the GPU's main memory (HBM). Standard attention is memory-bandwidth-bound, and on-chip SRAM offers roughly an order of magnitude more bandwidth than HBM. From the training loop's perspective, Flash Attention is a drop-in replacement for standard attention: same input, same output, but faster and with less memory.
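The online softmax trick can be sketched in NumPy for a single head without masking. NumPy gives no control over SRAM, so this shows only the arithmetic (exact results from block-by-block processing), not the memory or speed benefit. It follows the loop order described above (keys/values outside, queries inside); the block size of 16 and the function names are arbitrary, illustrative choices.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference: materializes the full n x n score matrix."""
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    scores -= scores.max(axis=1, keepdims=True)       # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V

def tiled_attention(Q, K, V, block=16):
    """Block-by-block attention with an online softmax (single head, no mask)."""
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    out = np.zeros_like(Q)                 # unnormalized output accumulator
    row_max = np.full(n, -np.inf)          # running max of scores per query row
    row_sum = np.zeros(n)                  # running softmax denominator per row
    for j in range(0, n, block):                        # outer loop: key/value blocks
        Kj, Vj = K[j:j + block], V[j:j + block]
        for i in range(0, n, block):                    # inner loop: query blocks
            rows = slice(i, i + block)
            S = Q[rows] @ Kj.T * scale                  # local score tile
            new_max = np.maximum(row_max[rows], S.max(axis=1))
            P = np.exp(S - new_max[:, None])            # tile weights, unnormalized
            correction = np.exp(row_max[rows] - new_max)   # rescale earlier partial sums
            row_sum[rows] = correction * row_sum[rows] + P.sum(axis=1)
            out[rows] = correction[:, None] * out[rows] + P @ Vj
            row_max[rows] = new_max
    return out / row_sum[:, None]

rng = np.random.default_rng(0)
Q, K, V = rng.standard_normal((3, 50, 8))   # 50 tokens, head dimension 8
print(np.allclose(tiled_attention(Q, K, V), naive_attention(Q, K, V)))
```

Whenever a new tile raises a row's running maximum, the `correction` factor rescales everything accumulated so far, which is what makes the final result match the global softmax exactly.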
KV-Cache and Batched Inference
During autoregressive inference, each new token generation only needs to attend to all previous tokens. The KV-cache stores the key and value projections of all previous tokens, $K_{1:t}$ and $V_{1:t}$, per layer per head. Instead of recomputing attention over the entire sequence, the model only computes the new token's query, appends its key/value to the cache, and runs attention against the full cache.
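A single-layer, single-head NumPy sketch of this idea. The `KVCache` class and `decode_step` function are hypothetical, the weights and token embeddings are random, and there is no batching. The check at the end confirms that decoding token by token with the cache gives the same outputs as recomputing full causal attention over the whole sequence.

```python
import numpy as np

def softmax(scores):
    scores = scores - scores.max()            # stabilize before exponentiating
    weights = np.exp(scores)
    return weights / weights.sum()

class KVCache:
    """Hypothetical cache for one layer and one head: rows are past tokens."""

    def __init__(self, head_dim):
        self.keys = np.empty((0, head_dim))
        self.values = np.empty((0, head_dim))

    def append(self, k, v):
        self.keys = np.vstack([self.keys, k])
        self.values = np.vstack([self.values, v])

def decode_step(x_new, Wq, Wk, Wv, cache):
    """Attend from the newest token to every cached token (itself included)."""
    q, k, v = x_new @ Wq, x_new @ Wk, x_new @ Wv   # project only the new token
    cache.append(k, v)
    weights = softmax(cache.keys @ q / np.sqrt(q.shape[0]))
    return weights @ cache.values

def full_causal_attention(X, Wq, Wk, Wv):
    """Reference without a cache: recompute every projection and score."""
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(Q.shape[1])
    scores[np.triu_indices(len(X), k=1)] = -np.inf   # no attending to future tokens
    scores -= scores.max(axis=1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ V

rng = np.random.default_rng(0)
d_model, head_dim, n_tokens = 8, 4, 6
Wq, Wk, Wv = rng.standard_normal((3, d_model, head_dim))
X = rng.standard_normal((n_tokens, d_model))   # embeddings of 6 tokens, fed one at a time

cache = KVCache(head_dim)
incremental = np.array([decode_step(x, Wq, Wk, Wv, cache) for x in X])
print(np.allclose(incremental, full_causal_attention(X, Wq, Wk, Wv)))
```

Each step costs one projection and one row of scores instead of a full recomputation, at the price of cache memory that grows linearly with the sequence length.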
Batched inference serves multiple requests simultaneously by packing their sequences into a single batch. The challenge: different requests are at different generation stages, with different cache lengths. Systems like vLLM use PagedAttention to store the KV-cache in fixed-size blocks (pages) that need not be contiguous in memory, similar to virtual memory paging in operating systems, allowing efficient batching even when sequences have very different lengths. This relates directly to our DataLoader discussion: just as training batches different samples together, inference systems batch different requests together for GPU utilization.
DataLoader Performance Tips
For production training, DataLoader performance can become a bottleneck. Key settings:
- num_workers > 0: Use parallel processes to load and preprocess data while the GPU trains on the current batch. A common starting point: num_workers = 4 per GPU
- pin_memory = True: Allocates data in page-locked (pinned) memory, enabling faster CPU-to-GPU transfer via DMA. Always use with CUDA GPUs
- prefetch_factor = 2: Each worker prefetches this many batches ahead. Keeps the GPU fed even if occasional samples are slow to load (e.g., large images from disk)
- persistent_workers = True: Keeps worker processes alive between epochs instead of respawning them. Eliminates the overhead of process creation at epoch boundaries
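These settings combine into a single DataLoader call. In this sketch the dataset is random placeholder data and the batch size of 64 is arbitrary; `pin_memory` is tied to CUDA availability because page-locked buffers only help when batches are copied to a CUDA GPU.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1024, 2), torch.randn(1024))   # placeholder data
loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,                         # parallel loading processes
    pin_memory=torch.cuda.is_available(),  # page-locked buffers for faster host-to-GPU copies
    prefetch_factor=2,                     # batches each worker loads ahead
    persistent_workers=True,               # keep workers alive across epochs
)

# In the training loop, non_blocking=True lets a copy from pinned memory
# overlap with GPU computation:
# for X_batch, y_batch in loader:
#     X_batch = X_batch.to("cuda", non_blocking=True)
```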
Summary
In this section, we learned why mini-batch training is essential and how it works:
- Mini-batch gradients are unbiased: $\mathbb{E}[\hat{g}_{\mathcal{B}}] = g$. On average, they point in the same direction as the full gradient
- Variance decreases with batch size: $\mathrm{Var}[\hat{g}_{\mathcal{B}}] = \sigma^2/B$. Larger batches give smoother gradients
- More updates beat fewer perfect updates: Mini-batching gives $N/B$ updates per epoch vs. 1 for full-batch. The frequent noisy updates converge faster in practice
- Shuffling is mandatory: Without it, every epoch repeats the same batch compositions, violating the random-sampling assumption behind these guarantees
- PyTorch Dataset + DataLoader separate data access (what is a sample?) from training logistics (batch size, shuffle, parallel loading)
- Batch size is a hyperparameter that affects memory, convergence speed, and generalization. The linear scaling rule helps adapt learning rate to batch size
In the next section, we will put this batching machinery inside a complete training loop with proper loss tracking, gradient clipping, and checkpoint saving.
References
- Robbins, H. & Monro, S. (1951). A Stochastic Approximation Method. Annals of Mathematical Statistics 22(3), 400–407. Origin of stochastic approximation, the mathematical foundation of SGD.
- Bottou, L. (2010). Large-Scale Machine Learning with Stochastic Gradient Descent. COMPSTAT 2010, 177–186.
- Bottou, L., Curtis, F. E. & Nocedal, J. (2018). Optimization Methods for Large-Scale Machine Learning. SIAM Review 60(2), 223–311. Rigorous treatment of mini-batch convergence and batch-size tradeoffs.
- Keskar, N. S., Mudigere, D., Nocedal, J., Smelyanskiy, M. & Tang, P. T. P. (2017). On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima. ICLR 2017 / arXiv:1609.04836.
- Smith, S. L., Kindermans, P.-J., Ying, C. & Le, Q. V. (2018). Don't Decay the Learning Rate, Increase the Batch Size. ICLR 2018 / arXiv:1711.00489.
- Goyal, P. et al. (2017). Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour. arXiv:1706.02677. Linear scaling rule for batch size and learning rate.
- He, H. & Garcia, E. A. (2009). Learning from Imbalanced Data. IEEE Transactions on Knowledge and Data Engineering 21(9), 1263–1284.
- Paszke, A. et al. (2019). PyTorch: An Imperative Style, High-Performance Deep Learning Library. NeurIPS 2019.
- PyTorch Documentation. torch.utils.data — DataLoader, Dataset, Sampler. https://pytorch.org/docs/stable/data.html