Chapter 8

BatchNorm and Dropout for Stability

CNN Feature Extractor

Two Stability Knobs

Stacking three Conv1D layers without regularisation is a recipe for unstable training: activations explode or vanish, and the model overfits to noise patterns in individual channels. Two techniques applied at every conv block do most of the work.

| Technique | What it fixes | Cost |
|---|---|---|
| BatchNorm1d | Activation magnitude drift; slow training | +2 params per channel; tiny FLOPs |
| Dropout | Co-adaptation; brittle channels | Zero extra params; ~5% throughput hit |
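The parameter costs in the table can be checked directly in PyTorch. This is a small sketch rather than part of the chapter's code; the 64-channel width is simply the one used in the examples of this section.

```python
import torch.nn as nn

bn = nn.BatchNorm1d(64)
drop = nn.Dropout(0.15)

# Count learnable parameters only (buffers such as running_mean are excluded).
bn_params = sum(p.numel() for p in bn.parameters())
drop_params = sum(p.numel() for p in drop.parameters())

print(bn_params)    # 128 = 2 per channel (gamma and beta) x 64 channels
print(drop_params)  # 0
```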

BatchNorm: Per-Channel Whitening

BatchNorm1d normalises each channel's activations to mean 0 and unit variance, then applies a learnable per-channel scale and shift:

$$\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}}, \qquad y = \gamma\,\hat{x} + \beta.$$

Two regimes:

| Mode | Statistics used | Updates |
|---|---|---|
| Training | Current batch's mean / var | Updates running averages used at eval |
| Evaluation | Running mean / var | No updates; deterministic per input |
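The two regimes can be made concrete with a minimal NumPy sketch. The class name `RunningBatchNorm1d`, the momentum of 0.1 (borrowed from PyTorch's default), and the 8-channel toy data are assumptions for illustration, not the chapter's implementation.

```python
import numpy as np


class RunningBatchNorm1d:
    """Hypothetical minimal BatchNorm1d for inputs shaped (B, C, T)."""

    def __init__(self, num_channels, momentum=0.1, eps=1e-5):
        # momentum=0.1 is an assumption, chosen to match PyTorch's default.
        self.momentum = momentum
        self.eps = eps
        self.gamma = np.ones(num_channels, dtype=np.float32)
        self.beta = np.zeros(num_channels, dtype=np.float32)
        self.running_mean = np.zeros(num_channels, dtype=np.float32)
        self.running_var = np.ones(num_channels, dtype=np.float32)

    def __call__(self, x, training):
        if training:
            # Training: normalise with the current batch's statistics ...
            mu = x.mean(axis=(0, 2))
            var = x.var(axis=(0, 2))  # biased estimate, used for normalising
            # ... and fold them into the running averages used at eval.
            n = x.shape[0] * x.shape[2]
            unbiased = var * n / (n - 1)  # PyTorch also stores the unbiased estimate
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mu
            self.running_var = (1 - m) * self.running_var + m * unbiased
        else:
            # Evaluation: running statistics only, no updates.
            mu, var = self.running_mean, self.running_var
        x_hat = (x - mu[None, :, None]) / np.sqrt(var[None, :, None] + self.eps)
        return x_hat * self.gamma[None, :, None] + self.beta[None, :, None]


rng = np.random.default_rng(0)
bn = RunningBatchNorm1d(num_channels=8)

# Many training batches let the running averages settle near mean 50, var 100**2.
for _ in range(200):
    batch = rng.normal(loc=50.0, scale=100.0, size=(4, 8, 30)).astype(np.float32)
    bn(batch, training=True)

x = rng.normal(loc=50.0, scale=100.0, size=(4, 8, 30)).astype(np.float32)
y_eval = bn(x, training=False)
print(f"eval mean / std : {y_eval.mean():.2f} / {y_eval.std():.2f}")
```

After enough training batches the eval-mode output sits close to mean 0 and std 1, and calling the layer in eval mode twice on the same input gives the same result because nothing is updated.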

Dropout: Randomly Forgetting

Dropout zeroes each activation with probability p during training, then scales the survivors by 1/(1-p) so the expected output magnitude is unchanged. We use p = 0.15 in the conv blocks - small because the input is a 30-cycle window. Heavier dropout (p = 0.3) lives in the FC stack at the end of the backbone (Chapter 11).
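A short derivation shows why the 1/(1-p) scaling preserves the expected value but not the variance. It assumes the keep mask m is drawn independently of the activation x; the symbols m and y are introduced here for illustration.

```latex
% m ~ Bernoulli(1 - p) is the keep mask, x the incoming activation.
y = \frac{m\,x}{1-p}, \qquad
\mathbb{E}[y \mid x] = \frac{(1-p)\,x}{1-p} = x, \qquad
\mathbb{E}[y^2 \mid x] = \frac{(1-p)\,x^2}{(1-p)^2} = \frac{x^2}{1-p}.

% For zero-mean, unit-variance x and p = 0.15:
\operatorname{Var}(y) = \frac{1}{1-p} = \frac{1}{0.85} \approx 1.176, \qquad
\operatorname{std}(y) \approx 1.085.
```

So dropout applied to a standardised input leaves the mean at 0 and lifts the standard deviation by roughly 8% at p = 0.15.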

BatchNorm and Dropout do DIFFERENT things. BN stabilises the activation distribution across batches. Dropout breaks co-adaptation between channels. Removing either typically degrades training.

Python: BN and Dropout in 10 Lines

BatchNorm + inverted dropout from scratch
bn_dropout_numpy.py

import numpy as np
    Standard alias.

def batch_norm_1d(x, gamma, beta, eps=1e-5):
    Pure-NumPy BatchNorm1d in training mode.

mu = x.mean(axis=(0, 2), keepdims=True)
    Per-channel mean over BATCH and TIME. Shape (1, C, 1).

var = x.var(axis=(0, 2), keepdims=True)
    Per-channel variance (NumPy's default biased estimate, ddof=0).

x_hat = (x - mu) / np.sqrt(var + eps)
    Whiten: zero mean, unit variance per channel.

return x_hat * gamma.reshape(1, -1, 1) + beta.reshape(1, -1, 1)
    Re-scale and re-shift via learnable per-channel gamma and beta.

def dropout(x, p, training=True):
    Inverted dropout - the variant that PyTorch's nn.Dropout implements.

if not training or p == 0: return x
    Eval mode = identity. Zero p = no-op.

mask = (np.random.random_sample(x.shape) > p).astype(np.float32)
    Bernoulli mask: 1 with prob (1-p), 0 with prob p.

return x * mask / (1 - p)
    Apply mask AND scale up by 1/(1-p) - keeps expected magnitude unchanged.

np.random.seed(0)
    Determinism.

x = np.random.randn(4, 64, 30).astype(np.float32) * 100 + 50
    Fake activation tensor with mean 50, std 100.

bn = batch_norm_1d(x, gamma, beta)
    Apply BN.

drop = dropout(bn, p=0.15, training=True)
    Apply dropout on the BN output.

print(f"x mean / std : {x.mean():.3f} / {x.std():.3f}")
    Pre-BN: mean ~50, std ~100.
    Output = x mean / std : 49.987 / 99.954

print(f"bn mean / std : {bn.mean():.3f} / {bn.std():.3f}")
    Post-BN: mean ~0, std ~1.
    Output = bn mean / std : -0.000 / 1.000

print(f"drop mean / std : ...")
    Post-dropout: mean still ~0; std rises towards 1/sqrt(1-p), about 1.08, because survivors are scaled up.
    Output (representative) = drop mean / std : 0.005 / 1.062

print(f"drop non-zero % : {100 * (drop != 0).mean():.1f}%")
    About 85% of cells survive.
    Output = drop non-zero % : 85.0%
import numpy as np

# ----- BatchNorm1d (training mode) -----
def batch_norm_1d(x: np.ndarray, gamma, beta, eps=1e-5):
    """x: (B, C, T). Normalise per channel over (B, T)."""
    mu = x.mean(axis=(0, 2), keepdims=True)
    var = x.var(axis=(0, 2), keepdims=True)
    x_hat = (x - mu) / np.sqrt(var + eps)
    return x_hat * gamma.reshape(1, -1, 1) + beta.reshape(1, -1, 1)


# ----- Inverted dropout (training mode) -----
def dropout(x: np.ndarray, p: float, training: bool = True):
    if not training or p == 0:
        return x
    mask = (np.random.random_sample(x.shape) > p).astype(np.float32)
    return x * mask / (1 - p)


# ----- Run on a tiny batch -----
np.random.seed(0)
x = np.random.randn(4, 64, 30).astype(np.float32) * 100 + 50
gamma = np.ones(64, dtype=np.float32)
beta = np.zeros(64, dtype=np.float32)

bn = batch_norm_1d(x, gamma, beta)
drop = dropout(bn, p=0.15, training=True)

print(f"x      mean / std : {x.mean():.3f} / {x.std():.3f}")
print(f"bn     mean / std : {bn.mean():.3f} / {bn.std():.3f}")
print(f"drop   mean / std : {drop.mean():.3f} / {drop.std():.3f}")
print(f"drop   non-zero %  : {100 * (drop != 0).mean():.1f}%")

PyTorch: nn.BatchNorm1d and nn.Dropout

Train / eval mode behaviour
bn_dropout_torch.py

import torch, torch.nn as nn
    Tensors + layers.

torch.manual_seed(0)
    Determinism.

bn = nn.BatchNorm1d(64)
    BN1d for 64 channels. Learnable parameters: weight (gamma) and bias (beta). Buffers: running_mean, running_var, and the num_batches_tracked counter.

drop = nn.Dropout(0.15)
    Drops 15% of activations during training; identity at eval.

x = torch.randn(4, 64, 30) * 100 + 50
    Fake activations.

bn.train(); drop.train()
    Set both to training mode.

y_train = drop(bn(x))
    Train-mode forward. This call also updates BN's running statistics.

print("train:", y_train.mean().item(), y_train.std().item())
    Approximately 0 mean; std slightly above 1 because dropout scales the survivors.
    Output (representative) = train: 0.0021 1.082

bn.eval(); drop.eval()
    Switch to eval - BN uses running stats, dropout off.

y_eval = drop(bn(x))
    Eval-mode forward.

print("eval :", y_eval.mean().item(), y_eval.std().item())
    Different mean/std because BN now uses running stats and dropout is off. The running stats start at 0 / 1 and, with the default momentum of 0.1, a single training batch moves them only 10% of the way towards the batch statistics: running_mean is about 5 and running_var about 1000 per channel. The output is therefore under-normalised, with mean near 45 / sqrt(1000) and std near 100 / sqrt(1000). After many training batches the running stats stabilise and eval output approaches mean 0, std 1.
    Output (approximate, derived from the update rule) = eval : ~1.4 ~3.2

print("dropout zeros (train):", int((y_train == 0).sum()))
    About 15% of cells zeroed.
    Output = dropout zeros (train): ~1,150

print("dropout zeros (eval) :", int((y_eval == 0).sum()))
    Eval mode = identity = no zeros.
    Output = dropout zeros (eval) : 0
import torch, torch.nn as nn
torch.manual_seed(0)

bn = nn.BatchNorm1d(64)
drop = nn.Dropout(0.15)

x = torch.randn(4, 64, 30) * 100 + 50

# Training mode
bn.train(); drop.train()
y_train = drop(bn(x))
print("train:", y_train.mean().item(), y_train.std().item())

# Evaluation mode
bn.eval(); drop.eval()
y_eval = drop(bn(x))
print("eval :", y_eval.mean().item(), y_eval.std().item())

# Same shape, different behaviour
print("dropout zeros (train):", int((y_train == 0).sum()))
print("dropout zeros (eval) :", int((y_eval == 0).sum()))

The .train() / .eval() Distinction

BatchNorm and Dropout are the two layers in the book where this distinction matters. Forgetting to call model.eval() before validation is among the most common bugs in PyTorch code.

| State | BatchNorm uses | Dropout |
|---|---|---|
| .train() | Current batch stats | Active (drops + scales) |
| .eval() | Running stats | Identity (passes through) |
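
Production idiom. Set model.eval() before scoring AND wrap the forward pass in a with torch.no_grad(): block. eval() switches BN and dropout behaviour; no_grad() only disables gradient tracking, so neither replaces the other.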
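One way to package the idiom is a small scoring helper. This is a sketch under assumptions: the `score` function, the toy `model`, and the 14 input channels are hypothetical stand-ins, not the backbone built in this book.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model: one conv block plus a tiny head.
model = nn.Sequential(
    nn.Conv1d(14, 64, kernel_size=3, padding=1),  # 14 input channels is an assumption
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Dropout(0.15),
    nn.AdaptiveAvgPool1d(1),
    nn.Flatten(),
    nn.Linear(64, 1),
)


def score(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run a forward pass in eval mode without gradient tracking."""
    was_training = model.training
    model.eval()                  # BN uses running stats, dropout becomes identity
    with torch.no_grad():         # no autograd graph is built
        out = model(x)
    model.train(was_training)     # restore whatever mode the caller was in
    return out


x = torch.randn(4, 14, 30)
model.train()
a = score(model, x)
b = score(model, x)
print(torch.equal(a, b))  # True: eval-mode scoring is deterministic
```

Restoring the previous mode means the helper can be called in the middle of a training loop without silently leaving the model in eval mode.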

Three BN/Dropout Pitfalls

Pitfall 1: Forgetting model.eval(). The model keeps using current-batch BN stats and keeps dropping activations during validation - reported metrics are noisier than reality.
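Pitfall 2: BN with batch_size = 1. BN computes statistics ACROSS the batch axis. With a single sample there is nothing to average over: for (B, C) inputs the variance is zero (PyTorch raises an error in training mode), and for (B, C, T) inputs the statistics come from the T time steps of that one sample, which is noisy. Switch to GroupNorm or LayerNorm for tiny-batch settings.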
Pitfall 3: Saving / loading BN buffers. running_mean and running_var are NOT parameters - they are buffers. state_dict handles them; manual parameter iteration loses them and eval-mode breaks.
The point. BN keeps activations on a well-conditioned scale; dropout keeps the network from over-relying on any single channel. Both are mode-dependent.
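Pitfall 3 can be seen by comparing what parameter iteration and state_dict each return. The sketch below is illustrative only; the variable names `restored` and `params_only` are hypothetical.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(64)
bn.train()
bn(torch.randn(4, 64, 30) * 100 + 50)  # one train-mode forward updates the running stats

param_names = [name for name, _ in bn.named_parameters()]
buffer_names = [name for name, _ in bn.named_buffers()]
print(param_names)   # ['weight', 'bias']
print(buffer_names)  # ['running_mean', 'running_var', 'num_batches_tracked']

# state_dict carries parameters and buffers together.
restored = nn.BatchNorm1d(64)
restored.load_state_dict(bn.state_dict())
print(torch.equal(restored.running_mean, bn.running_mean))  # True

# Copying parameters by hand leaves the running stats at their initial 0 / 1.
params_only = nn.BatchNorm1d(64)
with torch.no_grad():
    for dst, src in zip(params_only.parameters(), bn.parameters()):
        dst.copy_(src)
print(params_only.running_mean.abs().max().item())  # 0.0
```

A model rebuilt from parameters alone would normalise with mean 0 and variance 1 in eval mode, regardless of what it saw during training.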

Takeaway

  • BN: per-channel whitening with learnable gamma / beta. Train uses batch stats; eval uses running stats.
  • Dropout: zero out 15% of activations during training. Inverted variant scales survivors so eval is identity.
  • Always toggle .train() / .eval() correctly. Otherwise BN drifts and dropout corrupts validation.
  • BN buffers travel with state_dict. Manual serialisation can lose them.