Chapter 13

Exponential Moving Average (EMA)

Training Enhancements

Learning Objectives

By the end of this section, you will:

  1. Understand exponential moving average of model weights
  2. Derive the EMA update formula and its properties
  3. Explain why EMA improves generalization
  4. Configure the decay parameter for different scenarios
  5. Implement EMA with shadow weights and restore functionality
Why This Matters: Neural network training is inherently noisy—each mini-batch update pushes weights in a slightly different direction. Exponential Moving Average (EMA) maintains a smoothed version of the weights that often generalizes better than the final checkpoint. This simple technique can reduce RMSE by 2-5% with almost no additional cost.

What is Exponential Moving Average?

EMA maintains a "shadow" copy of model weights that is updated as a weighted average of current and past weights.

The Core Idea

During training, weights oscillate around optimal values due to stochastic gradient noise. EMA smooths these oscillations by keeping a running average:

$$\theta_{\text{EMA}}^{(t)} = \beta \cdot \theta_{\text{EMA}}^{(t-1)} + (1 - \beta) \cdot \theta^{(t)}$$

Where:

  • $\theta^{(t)}$: current model weights at step $t$
  • $\theta_{\text{EMA}}^{(t)}$: EMA weights at step $t$
  • $\beta$: decay parameter (typically 0.999)
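
To make the update concrete, here is a single EMA step with β = 0.999 and hypothetical weight values (plain scalars, not the research code):

```python
beta = 0.999
theta_ema = 10.0   # hypothetical current shadow value
theta = 12.0       # hypothetical new raw weight after an optimizer step

# One EMA update: the shadow moves only (1 - beta) = 0.1% of the way
theta_ema = beta * theta_ema + (1 - beta) * theta
print(theta_ema)   # ≈ 10.002
```

Even a large jump in the raw weights barely moves the shadow, which is exactly the smoothing effect described above.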

Intuition: Averaging Over a Window

The decay parameter β controls how many recent updates contribute to the EMA. A higher β means a longer effective window:

| β | Effective window | Behavior |
|---|---|---|
| 0.9 | ~10 steps | Fast adaptation, less smoothing |
| 0.99 | ~100 steps | Moderate smoothing |
| 0.999 | ~1000 steps | Strong smoothing (recommended) |
| 0.9999 | ~10000 steps | Very slow adaptation |

AMNL Uses β = 0.999

For our RUL prediction model, we use β = 0.999, which averages over approximately 1000 steps. This provides strong noise reduction while still adapting to significant weight changes.


EMA Mathematics

Understanding the mathematical properties of EMA reveals why it works.

Recursive Expansion
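
Unrolling the update rule shows that the EMA is a geometrically weighted average of the entire weight history:

```latex
\theta_{\text{EMA}}^{(t)}
  = \beta\,\theta_{\text{EMA}}^{(t-1)} + (1-\beta)\,\theta^{(t)}
  = (1-\beta)\sum_{k=0}^{t-1} \beta^{k}\,\theta^{(t-k)} + \beta^{t}\,\theta_{\text{EMA}}^{(0)}
```

Each past weight $\theta^{(t-k)}$ contributes with coefficient $(1-\beta)\beta^{k}$, so contributions decay geometrically with age, and the initialization term $\beta^{t}\,\theta_{\text{EMA}}^{(0)}$ vanishes as training proceeds.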

Effective Window Size

The "effective window" is approximately how many steps contribute significantly:

$$N_{\text{eff}} \approx \frac{1}{1 - \beta}$$
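
As a quick check, the effective windows quoted earlier follow directly from this approximation:

```python
# Effective averaging window: N_eff ≈ 1 / (1 - beta)
for beta in (0.9, 0.99, 0.999, 0.9999):
    n_eff = 1.0 / (1.0 - beta)
    print(f"beta={beta}: ~{n_eff:.0f} steps")
```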

Why EMA Works

EMA improves generalization through several mechanisms.

Noise Reduction

Mini-batch training introduces variance in gradient estimates. Each update moves weights in a slightly wrong direction:

$$\theta^{(t+1)} = \theta^{(t)} - \eta \left(\nabla \mathcal{L}_{\text{true}} + \epsilon_t\right)$$

Where $\epsilon_t$ is zero-mean noise from mini-batch sampling. EMA averages out this noise:

$$\mathbb{E}[\theta_{\text{EMA}}] \approx \theta^{*} \quad \text{(converges to the optimum)}$$
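
This variance reduction is easy to demonstrate on a toy problem. The sketch below (pure Python, not the research code) treats the "weight" as a noisy scalar observation of a fixed optimum and compares the variance of the raw sequence with that of its EMA:

```python
import random
import statistics

random.seed(0)
theta_star = 5.0      # true optimum the noisy "weights" oscillate around
beta = 0.99           # shorter window than 0.999 so the demo warms up quickly

# Raw "weights": optimum plus zero-mean Gaussian noise
samples = [theta_star + random.gauss(0.0, 1.0) for _ in range(5000)]

ema = samples[0]
trace = []
for x in samples:
    ema = beta * ema + (1.0 - beta) * x
    trace.append(ema)

# After warm-up, the EMA hovers near theta_star with far lower variance
raw_var = statistics.pvariance(samples[1000:])
ema_var = statistics.pvariance(trace[1000:])
print(f"raw variance ≈ {raw_var:.3f}, EMA variance ≈ {ema_var:.4f}")
```

For this first-order recursion, theory predicts $\mathrm{Var}(\theta_{\text{EMA}}) \approx \sigma^2 (1-\beta)/(1+\beta)$, roughly a 200-fold reduction at β = 0.99, and the empirical numbers land in that ballpark.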

Implicit Regularization

EMA has an implicit regularization effect. It prefers solutions that are stable across many training steps—solutions that generalize well tend to be stable.

| Weight type | EMA effect | Result |
|---|---|---|
| Noisy/unstable | Heavily smoothed | Regularized |
| Stable/converged | Minimally affected | Preserved |
| Oscillating | Averaged out | Centered |

Loss Landscape Smoothing

EMA weights tend to sit in flatter regions of the loss landscape, which correlate with better generalization:

📝text
Loss Landscape View:

Loss
  │      ╭──╮
  │     ╱    ╲
  │    ╱  ×   ╲      ← Raw weights (sharp minimum)
  │   ╱        ╲
  │──╱    ⊙     ╲──  ← EMA weights (flat region)
  │ ╱            ╲
  └─────────────────→ Weight space

× = oscillating raw weights
⊙ = smoothed EMA weights in flat basin

Better Generalization

Flat minima generalize better because small weight perturbations (from distribution shift) cause smaller loss increases. EMA naturally finds these flat regions by averaging.


Implementation

Our research implementation provides a clean, efficient EMA class with shadow weights for evaluation.

AMNL Research Implementation

ExponentialMovingAverage Class
🐍enhanced_train_nasa_cmapss_sota_v7.py
  • Class definition — lightweight EMA implementation that tracks shadow copies of all trainable parameters.
  • Decay parameter — default decay of 0.999 averages over ~1000 steps. Higher values give more smoothing but slower adaptation.
  • Shadow weights — dictionary storing the EMA version of each parameter. Keys are parameter names from model.named_parameters().
  • Backup storage — stores the original weights when apply_shadow is called, enabling restoration for continued training.
  • Shadow initialization — clones the initial weights into the shadow. Only parameters with requires_grad=True are tracked.
  • update() — called after each optimizer.step(); applies the EMA formula shadow = decay * shadow + (1 - decay) * current. For example, 0.999 * old + 0.001 * new means slow adaptation to new weights.
  • apply_shadow() — temporarily replaces model weights with EMA weights for evaluation, backing up the originals first.
  • restore() — restores the original training weights after evaluation. Essential for continuing training.

class ExponentialMovingAverage:
    """EMA weights for training stability."""

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {}   # EMA copy of each parameter
        self.backup = {}   # originals, saved while shadow weights are applied

        # Track only trainable parameters
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, model):
        """Fold current weights into the shadow: decay*shadow + (1-decay)*current."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]

    def apply_shadow(self, model):
        """Swap in EMA weights for evaluation, backing up the originals."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

    def restore(self, model):
        """Restore the original training weights."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data.copy_(self.backup[name])

Integration with Training Loop

🐍python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train_with_ema(
    model: nn.Module,
    train_loader: DataLoader,
    test_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    epochs: int,
    ema_decay: float = 0.999,
    device: torch.device = torch.device('cuda')
) -> dict:
    """
    Training loop with EMA weight tracking.

    Returns:
        Dictionary with training history and final metrics
    """
    # Initialize EMA
    ema = ExponentialMovingAverage(model, decay=ema_decay)

    history = {'train_loss': [], 'val_loss': [], 'val_loss_ema': []}

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()

            # Update EMA after each optimizer step
            ema.update(model)

            train_loss += loss.item()

        # Validation phase - evaluate both raw and EMA weights
        model.eval()

        # Evaluate raw weights
        val_loss_raw = evaluate(model, test_loader, criterion, device)

        # Evaluate EMA weights
        ema.apply_shadow(model)  # Temporarily use EMA weights
        val_loss_ema = evaluate(model, test_loader, criterion, device)
        ema.restore(model)  # Restore original weights for training

        history['train_loss'].append(train_loss / len(train_loader))
        history['val_loss'].append(val_loss_raw)
        history['val_loss_ema'].append(val_loss_ema)

        print(f"Epoch {epoch+1}: "
              f"Val Loss (raw): {val_loss_raw:.4f}, "
              f"Val Loss (EMA): {val_loss_ema:.4f}")

    # Final model uses EMA weights
    ema.apply_shadow(model)

    return history


def evaluate(model, dataloader, criterion, device):
    """Evaluate model on dataloader."""
    total_loss = 0.0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            total_loss += criterion(pred, y).item()
    return total_loss / len(dataloader)

Key Usage Patterns

| Operation | When to use | Effect |
|---|---|---|
| `ema.update(model)` | After each `optimizer.step()` | Updates shadow weights |
| `ema.apply_shadow(model)` | Before evaluation | Switches to EMA weights |
| `ema.restore(model)` | After evaluation | Returns to training weights |

Never Train with EMA Weights

Always train with the original weights, not EMA weights. EMA is for evaluation and final deployment only. Training with EMA weights would break the averaging mechanism.


Summary

In this section, we covered Exponential Moving Average:

  1. EMA formula: $\theta_{\text{EMA}} \leftarrow \beta \cdot \theta_{\text{EMA}} + (1-\beta) \cdot \theta$
  2. Decay parameter: β = 0.999 averages over ~1000 steps
  3. Benefits: Noise reduction, implicit regularization, flatter minima
  4. Shadow weights: Separate copy for evaluation
  5. Typical improvement: 2-5% reduction in RMSE

| Parameter | Value |
|---|---|
| Decay (β) | 0.999 |
| Effective window | ~1000 steps |
| Update frequency | After each `optimizer.step()` |
| Usage | Evaluation and final deployment |

Looking Ahead: EMA helps us find better weights, but we also need to know when to stop training. The next section covers early stopping with best weights—a technique for preventing overfitting by monitoring validation performance.

With EMA understood, we explore early stopping strategies.