Learning Objectives
By the end of this section, you will:
- Master EMA mathematics for loss tracking
- Understand bias correction for early training
- Apply EMA normalization to multi-task losses
- Choose the smoothing parameter β appropriately
- Implement production-ready EMA normalizers
Why This Matters: Exponential Moving Average (EMA) is the foundation of AMNL's loss normalization. By tracking running averages of loss magnitudes, EMA enables scale-invariant multi-task learning without expensive per-batch computations. Understanding EMA deeply is essential for implementing and extending AMNL.
EMA Fundamentals
EMA provides smooth, adaptive estimates of running statistics.
Definition
The EMA update rule is:

μ⁽ᵗ⁾ = β · μ⁽ᵗ⁻¹⁾ + (1 − β) · x⁽ᵗ⁾

Where:
- μ⁽ᵗ⁾: EMA at step t
- x⁽ᵗ⁾: Current observation (loss value)
- β: Smoothing factor (typically 0.99)
Equivalent Formulations
EMA can be viewed in several equivalent ways:
| View | Formula | Interpretation |
|---|---|---|
| Recursive | μ = βμ + (1-β)x | Blend old with new |
| Weighted sum | μ⁽ᵗ⁾ = (1−β) Σᵢ βⁱ x⁽ᵗ⁻ⁱ⁾ | Exponentially decaying weights |
| Effective window | ~1/(1-β) steps | Memory length |
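The first two views in the table can be checked numerically. This is a minimal sketch with arbitrary toy values for β and the observation sequence:

```python
# Sketch: verify the recursive and weighted-sum views of EMA agree.
beta = 0.9
xs = [3.0, 1.0, 4.0, 1.0, 5.0]

# Recursive view: mu = beta*mu + (1-beta)*x, starting from 0
mu = 0.0
for x in xs:
    mu = beta * mu + (1 - beta) * x

# Weighted-sum view: (1-beta) * sum over beta^i * x_{t-i}, newest first
weighted = (1 - beta) * sum(beta**i * x for i, x in enumerate(reversed(xs)))

assert abs(mu - weighted) < 1e-12
```

The exponentially decaying weights βⁱ are what give EMA its "soft window": recent observations dominate, old ones fade geometrically.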
Effective Window Size

The EMA's memory length is approximately 1/(1 − β) steps: β = 0.9 gives a ~10-step window, β = 0.99 a ~100-step window, and β = 0.999 a ~1000-step window.
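The window sizes for common β values can be computed directly (a minimal sketch; `effective_window` is an illustrative helper name, not part of AMNL):

```python
# Sketch: effective window ~ 1/(1-beta) for common smoothing factors.
def effective_window(beta: float) -> float:
    """Approximate number of steps contributing meaningfully to the EMA."""
    return 1.0 / (1.0 - beta)

for beta in (0.9, 0.99, 0.999):
    print(f"beta={beta}: ~{effective_window(beta):.0f} steps")
```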
Bias Correction
EMA is biased toward initialization in early steps. Bias correction fixes this.
The Bias Problem
With zero initialization (μ⁽⁰⁾ = 0):
```
Step 1:   μ = 0.99 × 0  + 0.01 × 1000 = 10     (true avg ~1000)
Step 2:   μ = 0.99 × 10 + 0.01 × 1000 = 19.9   (true avg ~1000)
Step 10:  μ ≈ 95.6   (still far from 1000)
Step 100: μ ≈ 634    (approaching 1000)

Problem: EMA severely underestimates in early training.
```

Bias Correction Formula

Dividing by the accumulated weight yields an unbiased estimate:

μ̂⁽ᵗ⁾ = μ⁽ᵗ⁾ / (1 − βᵗ)
This corrects for the geometric sum of missing contributions.
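A minimal sketch of the correction on the example above (every observation equals 1000, so the corrected estimate recovers the true scale exactly):

```python
# Sketch: bias correction recovers the true scale from step 1.
beta = 0.99
true_loss = 1000.0

mu = 0.0  # zero-initialized EMA
for t in range(1, 11):
    mu = beta * mu + (1 - beta) * true_loss
    corrected = mu / (1 - beta**t)  # divide by accumulated weight

# After 10 steps the raw EMA is still far from 1000;
# the corrected estimate is exact for a constant input.
print(f"raw EMA: {mu:.1f}, corrected: {corrected:.1f}")
```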
Multi-Task Normalization
EMA enables scale-invariant loss combination in AMNL.
The Scale Problem
```
Typical raw loss magnitudes:
  L_RUL    ≈ 100 – 2000   (MSE on cycles)
  L_health ≈ 0.5 – 2.0    (cross-entropy)

Ratio: ~100–1000× difference

With λ = 0.5 for both (unnormalized):
  Total = 0.5 × 1000 + 0.5 × 1.0 = 500.5
  RUL contributes 99.9% of gradient!
```

EMA Normalization Solution
Each loss is divided by its EMA, bringing all losses to approximately 1:
```
After EMA normalization:
  L̃_RUL    = 1000 / 1000 = 1.0
  L̃_health = 1.0  / 1.0  = 1.0

With λ = 0.5 for both (normalized):
  Total = 0.5 × 1.0 + 0.5 × 1.0 = 1.0
  Each task contributes exactly 50%
```

Gradient Scaling Effect
Because the normalizer is computed from detached values, it acts as a constant during backpropagation, so the gradient is scaled inversely to the EMA: ∇(L/μ̂) = (1/μ̂)∇L. Large losses (high μ̂) produce smaller gradients; small losses (low μ̂) produce larger gradients. This automatically balances gradient contributions.
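This 1/μ scaling can be illustrated without any framework. The sketch below uses a toy quadratic loss and a central finite difference; all names here are illustrative:

```python
# Sketch: normalizing by a detached EMA scales gradients by 1/mu.
# mu is treated as a constant (as when computed from loss.item()), so
# d/dtheta [L(theta) / mu] = L'(theta) / mu.

def loss(theta: float) -> float:
    return theta ** 2  # toy loss, L'(theta) = 2*theta

def grad(f, theta: float, eps: float = 1e-6) -> float:
    return (f(theta + eps) - f(theta - eps)) / (2 * eps)  # central difference

theta = 3.0
mu = 1000.0  # EMA of a large-magnitude loss

g_raw = grad(loss, theta)                     # gradient of the raw loss
g_norm = grad(lambda t: loss(t) / mu, theta)  # gradient of the normalized loss
assert abs(g_norm - g_raw / mu) < 1e-6
```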
Key Insight
EMA normalization is the mechanism that makes equal weights (0.5/0.5) actually mean equal contribution. Without it, "equal weights" would be meaningless due to scale differences.
Complete Implementation
Production-ready EMA normalizer with all features.
EMA Normalizer Class
```python
class EMANormalizer:
    """
    Exponential Moving Average normalizer for loss scaling.

    Tracks running average of loss magnitudes to enable
    scale-invariant loss normalization. Includes bias
    correction for accurate early-training estimates.

    Attributes:
        beta: Smoothing factor (0.99 recommended)
        ema: Current (uncorrected) EMA value
        steps: Number of update steps
        min_value: Minimum normalizer value (prevents division issues)
    """

    def __init__(
        self,
        beta: float = 0.99,
        min_value: float = 1e-8
    ):
        """
        Initialize EMA normalizer.

        Args:
            beta: Smoothing factor. Higher = more smoothing.
            min_value: Floor value to prevent division by zero.
        """
        if not 0 <= beta < 1:
            raise ValueError(f"beta must be in [0, 1), got {beta}")

        self.beta = beta
        self.min_value = min_value
        self.ema = 0.0  # zero-initialized; bias correction accounts for this
        self.steps = 0

    def update(self, value: float) -> float:
        """
        Update EMA with new value and return bias-corrected estimate.

        Args:
            value: Current observation (loss value). Should be a detached
                scalar (use .item() on a tensor).

        Returns:
            Bias-corrected EMA for use as normalizer.
        """
        self.steps += 1

        # Standard EMA update from zero initialization
        self.ema = self.beta * self.ema + (1 - self.beta) * value

        # Bias correction: divide by the accumulated weight (1 - beta^t)
        corrected = self.ema / (1 - self.beta ** self.steps)

        # Apply floor to avoid division-by-zero downstream
        return max(corrected, self.min_value)

    def get_value(self) -> float:
        """
        Get current bias-corrected EMA without updating.

        Returns:
            Current bias-corrected EMA, or 1.0 if not initialized.
        """
        if self.steps == 0:
            return 1.0

        corrected = self.ema / (1 - self.beta ** self.steps)
        return max(corrected, self.min_value)

    def reset(self) -> None:
        """Reset normalizer to initial state."""
        self.ema = 0.0
        self.steps = 0

    @property
    def effective_window(self) -> float:
        """Approximate number of steps in effective memory."""
        return 1.0 / (1.0 - self.beta)

    def __repr__(self) -> str:
        return (
            f"EMANormalizer(beta={self.beta}, "
            f"steps={self.steps}, "
            f"value={self.get_value():.4f})"
        )
```

Multi-Task Loss with EMA
```python
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn


class EMANormalizedMultiTaskLoss(nn.Module):
    """
    Multi-task loss with EMA-based normalization.

    Automatically normalizes losses from different tasks
    to have similar magnitudes, enabling meaningful
    weighted combination.

    Args:
        task_names: Names of tasks for logging
        task_weights: Weight for each task (default: equal)
        beta: EMA smoothing factor
    """

    def __init__(
        self,
        task_names: List[str],
        task_weights: Optional[List[float]] = None,
        beta: float = 0.99
    ):
        super().__init__()
        self.task_names = task_names
        self.n_tasks = len(task_names)

        # Default to equal weights
        if task_weights is None:
            task_weights = [1.0 / self.n_tasks] * self.n_tasks
        self.task_weights = task_weights

        # Create EMA normalizer for each task
        self.normalizers = {
            name: EMANormalizer(beta=beta)
            for name in task_names
        }

    def forward(
        self,
        losses: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute normalized multi-task loss.

        Args:
            losses: Dictionary mapping task name to loss tensor

        Returns:
            Tuple of:
                - Total normalized loss (scalar tensor)
                - Metrics dictionary for logging
        """
        total_loss = 0.0
        metrics = {}

        for i, name in enumerate(self.task_names):
            raw_loss = losses[name]

            # Update EMA with the detached scalar and get the normalizer
            normalizer = self.normalizers[name].update(raw_loss.item())

            # Normalize loss (normalizer is a constant w.r.t. autograd)
            normalized_loss = raw_loss / normalizer

            # Weight and accumulate
            weighted_loss = self.task_weights[i] * normalized_loss
            total_loss = total_loss + weighted_loss

            # Record metrics
            metrics[f"loss/{name}_raw"] = raw_loss.item()
            metrics[f"loss/{name}_normalized"] = normalized_loss.item()
            metrics[f"ema/{name}"] = normalizer

        metrics["loss/total"] = total_loss.item() if isinstance(total_loss, torch.Tensor) else total_loss

        return total_loss, metrics

    def reset(self) -> None:
        """Reset all EMA normalizers."""
        for normalizer in self.normalizers.values():
            normalizer.reset()
```

Usage in AMNL
```python
# Initialize for AMNL
amnl = EMANormalizedMultiTaskLoss(
    task_names=["rul", "health"],
    task_weights=[0.5, 0.5],  # Equal weights
    beta=0.99
)

# Training loop
for batch in dataloader:
    # Compute individual losses
    rul_loss = rul_criterion(pred_rul, target_rul)
    health_loss = health_criterion(pred_health, target_health)

    # Combine with EMA normalization
    total_loss, metrics = amnl({
        "rul": rul_loss,
        "health": health_loss
    })

    # Backprop and optimize
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Log metrics
    for key, value in metrics.items():
        logger.log(key, value)
```

Summary
In this section, we covered EMA-based adaptive scaling:
- EMA formula: μ⁽ᵗ⁾ = β · μ⁽ᵗ⁻¹⁾ + (1 − β) · x⁽ᵗ⁾
- Bias correction: μ̂⁽ᵗ⁾ = μ⁽ᵗ⁾ / (1 − βᵗ)
- Normalization: L̃ = L / μ̂
- Effect: Scale-invariant loss combination
- Recommended β: 0.99 (~100-step window)
| Parameter | Value | Effect |
|---|---|---|
| β | 0.99 | ~100 steps memory |
| Bias correction | Yes | Accurate from step 1 |
| Min value | 1e-8 | Numerical stability |
| Effective window | ~100 steps | 1-2 epochs |
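The recommendation of β = 0.99 reflects a trade-off: larger β smooths more but adapts more slowly when loss magnitudes shift during training. A minimal sketch (the scenario and the `run_ema` helper are illustrative, not part of AMNL):

```python
# Sketch: smoothing vs. adaptation speed for different beta values.
# A loss that drops from 1000 to 100 at step 50; higher beta lags longer.

def run_ema(beta: float, n_steps: int = 100) -> float:
    ema = 0.0
    for t in range(1, n_steps + 1):
        x = 1000.0 if t <= 50 else 100.0
        ema = beta * ema + (1 - beta) * x
    return ema / (1 - beta**n_steps)  # bias-corrected final estimate

fast = run_ema(beta=0.9)   # ~10-step memory: close to the new level (100)
slow = run_ema(beta=0.99)  # ~100-step memory: still pulled toward 1000
assert fast < slow
```

With β = 0.9 the estimate has nearly converged to the new scale after 50 steps, while β = 0.99 is still partway between the old and new levels; β = 0.99 trades this lag for much lower noise on fluctuating per-batch losses.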
Chapter Complete: You now have a complete toolkit of advanced loss components: weighted MSE, asymmetric penalties, focal loss, and EMA normalization. The next chapter covers optimization strategy—how to configure optimizers, learning rate schedules, and regularization for stable training.
With all loss components mastered, we proceed to optimization techniques.