Chapter 9: Training Neural Networks

Weight Initialization

Learning Objectives

By the end of this section, you will be able to:

  1. Understand why initialization matters: Grasp how poor initialization causes vanishing/exploding gradients that prevent deep networks from learning
  2. Derive Xavier initialization: Mathematically derive the variance-preserving initialization for linear and tanh activations
  3. Derive He initialization: Understand why ReLU networks need different initialization and derive the formula
  4. Implement initialization in PyTorch: Apply proper initialization to your networks using PyTorch's built-in functions
  5. Choose the right strategy: Select appropriate initialization based on your network architecture and activation functions
Why This Matters: Before Xavier and He initialization were discovered, training deep networks was extremely difficult. These insights about variance preservation enabled the training of networks with hundreds of layers. Understanding initialization is essential for training modern deep learning models.

The Big Picture

The Training Disaster

Imagine spending weeks designing a sophisticated 50-layer neural network, only to find that training fails completely—the loss doesn't decrease, gradients are either zero or infinity, and no learning occurs. This was a common experience in the early days of deep learning, and the culprit was often something as simple as how the weights were initialized.

Weight initialization might seem like a minor implementation detail, but it's actually one of the most critical factors determining whether a deep network can be trained at all. The values we assign to weights before training begins set the stage for everything that follows.

The Core Insight

The fundamental insight is this: as signals (activations during forward pass, gradients during backward pass) flow through a deep network, their magnitudes must remain approximately constant. If signals shrink at each layer, they vanish by the time they reach deep layers. If they grow, they explode into numerical overflow.

$$\text{Healthy training} \iff \text{Var}(\text{layer } l) \approx \text{Var}(\text{layer } l+1)$$

The key breakthrough was realizing that we can control this variance flow by carefully choosing the variance of the initial weights.

Historical Context

Two seminal papers solved this problem:

| Year | Authors | Method | Key Insight |
|------|---------|--------|-------------|
| 2010 | Xavier Glorot & Yoshua Bengio | Xavier/Glorot Init | Preserve variance for linear/tanh activations |
| 2015 | Kaiming He et al. | He/Kaiming Init | Account for ReLU halving the variance |

These discoveries, along with batch normalization and residual connections, enabled the training of very deep networks that power modern AI.


The Symmetry Breaking Problem

Why Not Initialize All Weights to the Same Value?

Consider initializing all weights to zero (or any constant). What happens?

In the forward pass, every neuron in a layer receives the same weighted sum of inputs (since all weights are identical). They all produce the same output. During backpropagation, they all receive the same gradient. And after the weight update, they all have the same new weight value.

$$w_1 = w_2 = \ldots = w_n \implies h_1 = h_2 = \ldots = h_n$$

This symmetry persists throughout training. The network effectively collapses to having just one neuron per layer, wasting all the extra capacity. No matter how long you train, the neurons never learn different features.

Never Initialize All Weights to the Same Value

This includes zero! Always use random initialization to break symmetry. Each neuron must start with different weights to learn different features.
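This failure is easy to demonstrate. A minimal sketch (names are illustrative): give every weight in a small layer the same constant and observe that backpropagation hands both neurons identical gradients, so they can never diverge.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two-neuron layer with every weight set to the same constant
layer = nn.Linear(4, 2)
nn.init.constant_(layer.weight, 0.5)
nn.init.zeros_(layer.bias)

x = torch.randn(8, 4)
loss = layer(x).sum()
loss.backward()

# Both neurons compute the same output and receive the same gradient,
# so a gradient step keeps them identical: symmetry is never broken
print(torch.allclose(layer.weight.grad[0], layer.weight.grad[1]))
```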

Random Initialization: The First Attempt

The obvious solution is random initialization. But what distribution should we sample from? A naive approach is to use a standard normal: $W_{ij} \sim \mathcal{N}(0, 1)$.

This breaks symmetry, but creates a new problem. Consider a fully-connected layer with $n$ inputs:

$$z = \sum_{i=1}^{n} w_i x_i$$

If $\text{Var}(w_i) = 1$ and $\text{Var}(x_i) = 1$ (with everything independent and zero-mean):

$$\text{Var}(z) = \sum_{i=1}^{n} \text{Var}(w_i x_i) = \sum_{i=1}^{n} \text{Var}(w_i) \cdot \text{Var}(x_i) = n$$

The variance grows by a factor of $n$ at each layer! In a network with 256 neurons per layer, the variance explodes to astronomical values within just a few layers.
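You can watch this blow-up numerically. A small sketch, assuming a 256-wide network and standard-normal weights:

```python
import torch

torch.manual_seed(0)
n = 256  # neurons per layer

# Push a unit-variance batch through 5 linear layers with N(0, 1) weights
x = torch.randn(1024, n)
for depth in range(1, 6):
    W = torch.randn(n, n)  # naive init: Var(w) = 1
    x = x @ W
    # Variance multiplies by roughly n = 256 at every layer
    print(f"layer {depth}: variance = {x.var().item():.3e}")
```

After only five layers the variance is on the order of $256^5 \approx 10^{12}$.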

Quick Check

If you initialize weights from N(0, 1) and have 512 neurons per layer, what happens to the variance after 10 layers?


Variance Flow Through Networks

The Mathematical Framework

Let's carefully analyze how variance propagates through a single layer. Consider a fully-connected layer with:

  • $n_{\text{in}}$ input neurons (fan-in)
  • $n_{\text{out}}$ output neurons (fan-out)
  • Weights $W_{ij}$ drawn i.i.d. with mean 0 and variance $\sigma^2$
  • Inputs $x_i$ with mean 0 and variance $v_{\text{in}}$

For a single output neuron (before activation):

$$z_j = \sum_{i=1}^{n_{\text{in}}} w_{ij} x_i$$

Using the property that the variance of a sum of independent terms is the sum of variances:

$$\text{Var}(z_j) = \sum_{i=1}^{n_{\text{in}}} \text{Var}(w_{ij}) \cdot \text{Var}(x_i) = n_{\text{in}} \cdot \sigma^2 \cdot v_{\text{in}}$$

The Variance Preservation Condition

For stable forward propagation, we want the output variance to equal the input variance:

$$\text{Var}(z_j) = v_{\text{in}}$$

This gives us the condition:

$$n_{\text{in}} \cdot \sigma^2 \cdot v_{\text{in}} = v_{\text{in}}$$

$$\boxed{\sigma^2 = \frac{1}{n_{\text{in}}}}$$

This is the key insight: the variance of the weights should be inversely proportional to the number of input connections.
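A quick numerical check of this condition: scaling standard-normal weights by $1/\sqrt{n_{\text{in}}}$ should keep the output variance near the input variance of 1.

```python
import torch

torch.manual_seed(0)
n_in = 512

x = torch.randn(4096, n_in)                # unit-variance input
W = torch.randn(n_in, n_in) / n_in ** 0.5  # Var(w) = 1/n_in
z = x @ W

# Output variance stays close to the input variance of 1
print(f"Var(z) = {z.var().item():.3f}")
```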

Backward Pass: Gradient Variance

The same analysis applies to gradients flowing backward. During backpropagation, the gradient with respect to an input is:

$$\frac{\partial L}{\partial x_i} = \sum_{j=1}^{n_{\text{out}}} w_{ij} \frac{\partial L}{\partial z_j}$$

For gradients to maintain variance:

$$n_{\text{out}} \cdot \sigma^2 = 1 \implies \sigma^2 = \frac{1}{n_{\text{out}}}$$

We have a conflict! The forward pass wants $\sigma^2 = 1/n_{\text{in}}$, but the backward pass wants $\sigma^2 = 1/n_{\text{out}}$.


Interactive: Variance Propagation

Experiment with different initialization strategies and see how variance changes as signals propagate through a deep network. Observe how naive initialization causes variance to explode or vanish, while proper initialization maintains stability:

Variance Flow Through Layers

See how different initializations affect signal propagation

[Chart: variance (log scale) vs. network depth, layers 1 through 10, comparing Naive (Var=1), Xavier (Var=1/n), and He (Var=2/n) initialization against the ideal flat line]

Naive Initialization

Variance can vanish or explode exponentially with depth. Unusable for deep networks.

Xavier Initialization

Maintains variance for linear/tanh activations. Slightly suboptimal for ReLU.

He Initialization

Accounts for ReLU killing half the neurons. Best choice for ReLU networks.

Key Observations

  1. With naive initialization (Var=1), variance grows exponentially—leading to numerical overflow
  2. Xavier initialization maintains variance for linear activations, but slightly underperforms for ReLU
  3. He initialization is optimal for ReLU, keeping variance stable across all layers
  4. Increasing network depth makes proper initialization even more critical

Xavier (Glorot) Initialization

Resolving the Conflict

Xavier Glorot and Yoshua Bengio proposed a compromise: use the average of the forward and backward requirements:

$$\text{Var}(W) = \frac{2}{n_{\text{in}} + n_{\text{out}}}$$

This ensures that variance is approximately preserved in both directions. The name "Xavier" comes from Xavier Glorot's first name.

Xavier Normal Initialization

Sample weights from a normal distribution with the computed variance:

$$W_{ij} \sim \mathcal{N}\left(0, \sqrt{\frac{2}{n_{\text{in}} + n_{\text{out}}}}\right)$$

Standard Deviation vs Variance

The formula above is written with the standard deviation, the square root of the variance: $\sigma = \sqrt{2/(n_{\text{in}} + n_{\text{out}})}$. Keep this distinction in mind when reading PyTorch code: functions like xavier_normal_ are parameterized via a gain factor on the standard deviation, not the variance.

Xavier Uniform Initialization

Alternatively, sample from a uniform distribution with the same variance:

$$W_{ij} \sim \mathcal{U}(-a, a) \quad \text{where} \quad a = \sqrt{\frac{6}{n_{\text{in}} + n_{\text{out}}}}$$

The limit $a$ is derived from the uniform distribution's variance formula: $\text{Var}(\mathcal{U}(-a, a)) = a^2/3$. Setting this equal to $2/(n_{\text{in}} + n_{\text{out}})$ gives $a = \sqrt{6/(n_{\text{in}} + n_{\text{out}})}$.
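As a sanity check, a sketch verifying that `xavier_uniform_` produces exactly these bounds and the target variance:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_in, n_out = 512, 256

layer = nn.Linear(n_in, n_out)
nn.init.xavier_uniform_(layer.weight)

target_var = 2 / (n_in + n_out)      # 2 / 768
bound = (6 / (n_in + n_out)) ** 0.5  # a = sqrt(6 / 768)

# Samples never exceed the bound a, and their variance matches the target
print(f"max |w| = {layer.weight.abs().max().item():.4f}, a = {bound:.4f}")
print(f"empirical var = {layer.weight.var().item():.6f}, target = {target_var:.6f}")
```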

When to Use Xavier

Xavier initialization was derived assuming linear activations or activations symmetric around zero like tanh. It works well for:

  • Networks with tanh or sigmoid activations
  • Linear layers without activation (like the final output layer)
  • Attention mechanisms and transformers (which often use linear projections)

Not Optimal for ReLU

Xavier initialization underestimates the variance needed for ReLU networks. Since ReLU sets negative values to zero, it effectively halves the variance at each layer. Use He initialization for ReLU instead.

He (Kaiming) Initialization

The ReLU Problem

Consider the ReLU activation: $\text{ReLU}(z) = \max(0, z)$. For a symmetric input distribution centered at zero, ReLU outputs zero for half the inputs. This halves the variance:

$$\text{Var}(\text{ReLU}(z)) = \frac{1}{2} \text{Var}(z)$$

This derivation assumes $z$ is symmetric around zero. For the forward pass through a ReLU layer:

$$v_{\text{out}} = \frac{1}{2} \cdot n_{\text{in}} \cdot \sigma^2 \cdot v_{\text{in}}$$

To preserve variance ($v_{\text{out}} = v_{\text{in}}$):

$$\boxed{\sigma^2 = \frac{2}{n_{\text{in}}}}$$

This is exactly twice the variance of Xavier initialization (considering only fan-in). Kaiming He and colleagues derived this in their 2015 paper on training very deep networks.
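The factor of 1/2 is easy to check by simulation. Strictly speaking, the quantity the derivation tracks is the second moment $\mathbb{E}[a^2]$ rather than the mean-centered variance, since ReLU outputs are no longer zero-mean:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = torch.randn(1_000_000)  # symmetric, zero-mean pre-activations

# ReLU keeps only the positive half, so E[ReLU(z)^2] = E[z^2] / 2
second_moment = (z ** 2).mean().item()
relu_second_moment = (F.relu(z) ** 2).mean().item()
print(f"E[z^2]       = {second_moment:.3f}")
print(f"E[ReLU(z)^2] = {relu_second_moment:.3f}")
```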

He Normal Initialization

$$W_{ij} \sim \mathcal{N}\left(0, \sqrt{\frac{2}{n_{\text{in}}}}\right)$$

He Uniform Initialization

$$W_{ij} \sim \mathcal{U}(-a, a) \quad \text{where} \quad a = \sqrt{\frac{6}{n_{\text{in}}}}$$

Fan-in vs Fan-out Mode

PyTorch allows you to choose whether to use $n_{\text{in}}$ (fan-in) or $n_{\text{out}}$ (fan-out):

| Mode | Formula | Best For |
|------|---------|----------|
| `fan_in` | Var = 2/n_in | Preserving forward-pass variance (default) |
| `fan_out` | Var = 2/n_out | Preserving backward-pass variance |

Rule of Thumb

Use fan_in mode (the default) for most cases. Use fan_out only if you have specific reasons to prioritize gradient stability over activation stability.
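A short sketch of the two modes on a 512-to-256 layer; the target variance switches from $2/n_{\text{in}}$ to $2/n_{\text{out}}$:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(512, 256)  # fan_in = 512, fan_out = 256

nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
var_in = layer.weight.var().item()
print(f"fan_in:  var = {var_in:.6f} (target {2 / 512:.6f})")

nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
var_out = layer.weight.var().item()
print(f"fan_out: var = {var_out:.6f} (target {2 / 256:.6f})")
```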

Quick Check

For a layer with 512 input neurons and ReLU activation, what should be the weight variance using He initialization?


Interactive: Weight Distributions

Explore how different initialization strategies create different weight distributions. Adjust the fan-in and fan-out values to see how they affect the distribution's spread:

Weight Distribution Visualizer

See how different initialization strategies create different weight distributions

[Histogram of sampled Xavier Normal weights, W ~ N(0, 2/(n_in + n_out)): empirical mean 0.0004, variance 0.0039, std 0.0626, closely matching the theoretical std of 0.0625]

Observations

  • Larger layers = smaller weights: As fan-in increases, the weights become more concentrated around zero
  • Xavier vs He: He initialization has slightly wider spread than Xavier to compensate for ReLU
  • Uniform vs Normal: Both achieve the same variance but with different distribution shapes

Interactive: Signal Propagation

Visualize how activations and gradients flow through a deep network with different initialization strategies. Watch how poor initialization causes signals to vanish or explode:

Signal Propagation Through Layers

Visualize how activations flow forward through a deep network

[Interactive: per-layer forward activation statistics (mean, std dev, max abs magnitude), with each layer flagged Healthy, Warning, or Vanishing/Exploding]

He/Kaiming Initialization

Optimal for ReLU: Compensates for the fact that ReLU sets half of activations to zero. Maintains healthy signal flow in deep networks.

What to Look For

  • All Zeros: All neurons output identical values—no learning possible
  • Too Small: Activations shrink toward zero in deep layers (vanishing)
  • Too Large: Activations grow exponentially (exploding)
  • Xavier/He: Healthy signal flow with consistent magnitudes across layers

PyTorch Implementation

Built-in Initialization Functions

PyTorch provides initialization functions in torch.nn.init:

PyTorch Initialization Functions

Notes on the code below:

  • In-place modification: Functions ending in an underscore (`_`) modify tensors in place. They return the tensor for chaining, but the original is modified, e.g. `init.xavier_uniform_(layer.weight)`.
  • Mode parameter: `'fan_in'` preserves forward-pass variance (the default); `'fan_out'` preserves backward-pass variance. Usually stick with `'fan_in'`.
  • Nonlinearity parameter: Tells PyTorch which activation function follows this layer, which affects the gain (scaling factor) applied to the initialization, e.g. `'relu'`, `'leaky_relu'`, `'tanh'`, `'sigmoid'`.
  • Zero biases: Biases are typically initialized to zero. They don't have the symmetry problem, since each neuron has only one bias.

🐍init_functions.py

```python
import torch
import torch.nn as nn
import torch.nn.init as init

# Create a linear layer (fan_in=512, fan_out=256)
layer = nn.Linear(512, 256)

# Xavier initialization (each call overwrites the previous weights)
init.xavier_uniform_(layer.weight)  # Uniform version
init.xavier_normal_(layer.weight)   # Normal version

# He (Kaiming) initialization
init.kaiming_uniform_(layer.weight, mode='fan_in', nonlinearity='relu')
init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')

# Zero-initialize biases (common practice)
init.zeros_(layer.bias)

# Constant initialization is possible but not recommended:
# init.constant_(layer.weight, 0.01)

# Check the resulting variance (weights were last set by kaiming_normal_)
print(f"Weight variance: {layer.weight.var().item():.6f}")
print(f"Expected (He): {2/512:.6f}")
```

Custom Initialization for a Model

Notes on the code below:

  • Initialization method: Define a private method to handle initialization, and call it at the end of `__init__` so all layers are created first.
  • Iterate over modules: `self.modules()` returns all submodules recursively, including nested `Sequential` containers and individual layers.
  • Type checking: Only `Linear` layers are initialized here; you'd add similar blocks for `Conv2d`, `BatchNorm`, etc. with their appropriate initialization.

🐍custom_init.py

```python
import torch.nn as nn
import torch.nn.init as init

class MLP(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

        # Apply custom initialization
        self._init_weights()

    def _init_weights(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                # He initialization for ReLU layers
                init.kaiming_normal_(
                    module.weight,
                    mode='fan_in',
                    nonlinearity='relu'
                )
                if module.bias is not None:
                    init.zeros_(module.bias)

    def forward(self, x):
        return self.layers(x)

# Usage
model = MLP(784, 256, 10)

# Verify initialization
for name, param in model.named_parameters():
    if 'weight' in name:
        print(f"{name}: mean={param.mean():.4f}, var={param.var():.4f}")
```

Initialization with apply()

PyTorch also provides apply() for recursively applying a function to all modules:

🐍apply_init.py

```python
import torch.nn as nn
import torch.nn.init as init

def init_weights(m):
    """Initialize weights for different layer types."""
    if isinstance(m, nn.Linear):
        init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        init.ones_(m.weight)
        init.zeros_(m.bias)

# Apply recursively to every submodule of a model
model = MyNetwork()
model.apply(init_weights)
```

PyTorch Default Initialization

PyTorch's default initialization for nn.Linear is already pretty good—it uses a variant of Kaiming uniform initialization. But explicitly setting initialization gives you control and makes your code clearer.
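To see this concretely (the exact default is an implementation detail that can change between PyTorch versions): current versions initialize `nn.Linear` with `kaiming_uniform_` and `a=sqrt(5)`, which works out to $\text{Var}(W) = 1/(3 \, n_{\text{in}})$, the same $1/n_{\text{in}}$ scaling as Xavier and He but with a smaller gain.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(512, 256)  # default initialization, no explicit init call

# Default kaiming_uniform_(a=sqrt(5)) gives Var(W) = 1 / (3 * fan_in)
expected = 1 / (3 * 512)
print(f"default var = {layer.weight.var().item():.6f}, 1/(3*fan_in) = {expected:.6f}")
```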

Practical Guidelines

Choosing the Right Initialization

| Layer Type | Activation | Recommended Initialization |
|------------|------------|----------------------------|
| Linear/Dense | ReLU, LeakyReLU, PReLU | He (Kaiming) |
| Linear/Dense | Tanh, Sigmoid, Linear | Xavier (Glorot) |
| Conv2d | ReLU, LeakyReLU | He (Kaiming) |
| BatchNorm | – | weight=1, bias=0 |
| LayerNorm | – | weight=1, bias=0 |
| Embedding | – | Normal(0, 1) or Xavier |
| LSTM/GRU | Various | Orthogonal or Xavier |
| Output layer | Softmax/Linear | Xavier or smaller |

Special Cases

Residual Networks

For residual networks with skip connections, it's common to initialize the last layer of each residual block to zero. This makes the network initially behave like a shallower network:

🐍resnet_init.py

```python
import torch.nn as nn
import torch.nn.init as init

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        )
        # Zero-initialize the final layer so the block starts as the identity
        init.zeros_(self.layers[-1].weight)
        init.zeros_(self.layers[-1].bias)

    def forward(self, x):
        return x + self.layers(x)  # Skip connection
```

Transformers

Modern transformers often scale initialization by layer depth:

🐍transformer_init.py

```python
import math
import torch.nn.init as init

# GPT-style initialization (fragment): scale the std of residual-branch
# projections by 1/sqrt(num_layers) so residual contributions don't
# accumulate with depth
scale = 1 / math.sqrt(num_layers)
init.normal_(layer.weight, mean=0.0, std=0.02 * scale)
```

Output Layers

For classification tasks, initializing the output layer with smaller weights often helps:

🐍output_init.py

```python
import torch.nn.init as init

# Smaller initialization for the output layer
init.normal_(output_layer.weight, mean=0.0, std=0.01)
init.zeros_(output_layer.bias)
```

Related Topics

  • Chapter 8 Section 6: Gradient Flow Analysis - Deep dive into vanishing/exploding gradients and why initialization matters for gradient flow
  • Section 5: Normalization Layers - Batch and Layer Normalization also help stabilize activations and gradients

Summary

Weight initialization is crucial for training deep neural networks. Here are the key takeaways:

| Concept | Key Point |
|---------|-----------|
| Symmetry Breaking | Random initialization is essential; identical weights lead to identical neurons |
| Variance Preservation | Signals should maintain consistent magnitude across layers |
| Xavier/Glorot | Var(W) = 2/(n_in + n_out), optimal for linear/tanh activations |
| He/Kaiming | Var(W) = 2/n_in, optimal for ReLU activations (twice the fan-in-only Xavier value of 1/n_in) |
| Fan-in vs Fan-out | fan_in preserves forward variance (default), fan_out preserves backward |
| Biases | Usually initialized to zero (no symmetry issue) |

Quick Reference: Initialization Formulas

📝formulas.txt

```
Xavier Normal:   W ~ N(0, sqrt(2 / (n_in + n_out)))
Xavier Uniform:  W ~ U(-sqrt(6/(n_in+n_out)), sqrt(6/(n_in+n_out)))

He Normal:       W ~ N(0, sqrt(2 / n_in))
He Uniform:      W ~ U(-sqrt(6/n_in), sqrt(6/n_in))
```

Knowledge Check

Test your understanding of weight initialization concepts:


Why can't we initialize all weights to zero?


Exercises

Conceptual Questions

  1. Explain why initializing all weights to small constants (e.g., 0.01) is better than zeros but still problematic for deep networks.
  2. A layer has 1024 input neurons and 512 output neurons. Calculate the variance and uniform initialization bounds for both Xavier and He initialization.
  3. Why might you choose fan_out mode for He initialization in a generative model?
  4. How does batch normalization reduce the dependence on initialization?

Solution Hints

  1. Q1: Small constants break symmetry but can still cause vanishing gradients in deep networks since they don't scale with layer size.
  2. Q2: Xavier: Var = 2/(1024+512) ≈ 0.0013, bounds = ±sqrt(6/1536) ≈ ±0.0625. He: Var = 2/1024 ≈ 0.00195, bounds = ±sqrt(6/1024) ≈ ±0.0765.
  3. Q3: In generative models, gradient flow from loss to input is critical. fan_out optimizes for stable backward propagation.
  4. Q4: BatchNorm normalizes activations to zero mean and unit variance, partially correcting for poor initialization.
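The quantities in exercise 2 can be computed directly:

```python
import math

n_in, n_out = 1024, 512

# Xavier: variance and uniform bound use fan_in + fan_out
xavier_var = 2 / (n_in + n_out)
xavier_bound = math.sqrt(6 / (n_in + n_out))

# He: variance and uniform bound use fan_in only
he_var = 2 / n_in
he_bound = math.sqrt(6 / n_in)

print(f"Xavier: Var = {xavier_var:.5f}, bounds = ±{xavier_bound:.4f}")
print(f"He:     Var = {he_var:.5f}, bounds = ±{he_bound:.4f}")
```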

Coding Exercises

  1. Implement from scratch: Write your own xavier_normal and kaiming_normal functions without using PyTorch's built-in versions. Verify they produce the same variance.
  2. Variance tracking: Create a neural network that records the variance of activations at each layer during a forward pass. Compare different initializations.
  3. Training comparison: Train the same model architecture on MNIST with zeros, random N(0,1), Xavier, and He initialization. Plot loss curves and final accuracy for each.
  4. Gradient analysis: Extend exercise 2 to also track gradient variance during backward pass. Verify that He initialization preserves gradient magnitudes better than Xavier for ReLU networks.

Exercise Code Template

🐍variance_tracking.py

```python
import torch.nn as nn

class VarianceTracker(nn.Module):
    """Wraps a list of layers and records activation variance after each."""

    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)
        self.activation_vars = []

    def forward(self, x):
        # Record the variance of the input and of every layer's output
        self.activation_vars = [x.var().item()]
        for layer in self.layers:
            x = layer(x)
            self.activation_vars.append(x.var().item())
        return x
```

In the next section, we'll explore regularization techniques—methods like dropout, weight decay, and data augmentation that prevent overfitting and improve generalization.