Chapter 11
Section 57 of 104

EMA-Based Adaptive Scaling

Advanced Loss Components

Learning Objectives

By the end of this section, you will:

  1. Master EMA mathematics for loss tracking
  2. Understand bias correction for early training
  3. Apply EMA normalization to multi-task losses
  4. Choose the smoothing parameter β appropriately
  5. Implement production-ready EMA normalizers
Why This Matters: Exponential Moving Average (EMA) is the foundation of AMNL's loss normalization. By tracking running averages of loss magnitudes, EMA enables scale-invariant multi-task learning without expensive per-batch computations. Understanding EMA deeply is essential for implementing and extending AMNL.

EMA Fundamentals

EMA provides smooth, adaptive estimates of running statistics.

Definition

$$\mu^{(t)} = \beta \cdot \mu^{(t-1)} + (1 - \beta) \cdot x^{(t)}$$

Where:

  • $\mu^{(t)}$: EMA at step t
  • $x^{(t)}$: Current observation (loss value)
  • $\beta \in [0, 1)$: Smoothing factor (typically 0.99)

Equivalent Formulations

EMA can be viewed in several equivalent ways:

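| View | Formula | Interpretation |
|---|---|---|
| Recursive | $\mu^{(t)} = \beta\mu^{(t-1)} + (1-\beta)x^{(t)}$ | Blend old estimate with new observation |
| Weighted sum | $\mu^{(t)} = \sum_{i=0}^{t-1} (1-\beta)\beta^i x^{(t-i)}$ (with $\mu^{(0)} = 0$) | Exponentially decaying weights |
| Effective window | $\approx 1/(1-\beta)$ steps | Memory length |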
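The first two views in the table can be checked against each other numerically. The sketch below assumes zero initialization ($\mu^{(0)} = 0$), as in the weighted-sum formula; the helper names and loss values are illustrative, not part of AMNL.

```python
def ema_recursive(xs, beta):
    """Recursive form with zero initialization: mu = beta * mu + (1 - beta) * x."""
    mu = 0.0
    for x in xs:
        mu = beta * mu + (1 - beta) * x
    return mu


def ema_weighted_sum(xs, beta):
    """Closed form: sum over i of (1 - beta) * beta**i * x^(t - i), newest first."""
    t = len(xs)
    return sum((1 - beta) * beta**i * xs[t - 1 - i] for i in range(t))


losses = [1200.0, 950.0, 1010.0, 880.0, 1005.0]  # illustrative loss values
print(ema_recursive(losses, 0.9))
print(ema_weighted_sum(losses, 0.9))  # same value up to float rounding
```

Both functions produce the same number; the recursive form is what an implementation uses, because it needs only the previous estimate rather than the whole history.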

Effective Window Size
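The window $1/(1-\beta)$ is a rule of thumb. Two related measures help when choosing β: the half-life $\ln(0.5)/\ln(\beta)$ (steps until an observation's weight halves), and the total weight carried by the most recent $1/(1-\beta)$ steps, which is $1 - \beta^{1/(1-\beta)}$ and approaches $1 - 1/e \approx 63\%$ as β approaches 1. A minimal sketch; the function names are illustrative:

```python
import math


def effective_window(beta: float) -> float:
    """Rule-of-thumb memory length of an EMA: 1 / (1 - beta) steps."""
    return 1.0 / (1.0 - beta)


def half_life(beta: float) -> float:
    """Steps after which an observation's weight has decayed to half."""
    return math.log(0.5) / math.log(beta)


def window_weight(beta: float) -> float:
    """Total weight carried by the most recent effective_window(beta) steps."""
    n = round(effective_window(beta))
    return 1.0 - beta ** n


for beta in (0.9, 0.99, 0.999):
    print(
        f"beta={beta}: window ~{effective_window(beta):.0f} steps, "
        f"half-life ~{half_life(beta):.1f} steps, "
        f"window carries {window_weight(beta):.0%} of the weight"
    )
```

For β = 0.99 this gives a window of about 100 steps and a half-life of about 69 steps: a larger β smooths more but reacts more slowly when loss magnitudes drift.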


Bias Correction

EMA is biased toward initialization in early steps. Bias correction fixes this.

The Bias Problem

With zero initialization (μ⁽⁰⁾ = 0), β = 0.99, and a constant loss of 1000 at every step:

```text
Step 1:   μ = 0.99 × 0 + 0.01 × 1000 = 10       (true avg ~1000)
Step 2:   μ = 0.99 × 10 + 0.01 × 1000 = 19.9    (true avg ~1000)
Step 10:  μ ≈ 95.6                               (still far from 1000)
Step 100: μ ≈ 634                                (approaching 1000)

Problem: EMA severely underestimates in early training
```
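The numbers above can be reproduced with a few lines of plain Python. This sketch assumes the same setup (zero initialization, β = 0.99, a constant loss of 1000); for a constant signal the uncorrected EMA has the closed form $x(1 - \beta^t)$, which is exactly the shortfall bias correction removes.

```python
def uncorrected_ema_trace(x: float, beta: float, steps: int) -> list:
    """EMA from zero initialization on a constant signal x, one value per step."""
    mu, trace = 0.0, []
    for _ in range(steps):
        mu = beta * mu + (1 - beta) * x
        trace.append(mu)
    return trace


trace = uncorrected_ema_trace(x=1000.0, beta=0.99, steps=100)
for step in (1, 2, 10, 100):
    # Compare with the closed form x * (1 - beta**t)
    closed_form = 1000.0 * (1 - 0.99 ** step)
    print(f"Step {step}: mu = {trace[step - 1]:.1f} (closed form {closed_form:.1f})")
```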

Bias Correction Formula

$$\hat{\mu}^{(t)} = \frac{\mu^{(t)}}{1 - \beta^t}$$

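This corrects for the missing weight. With μ⁽⁰⁾ = 0, the weights (1 − β)βⁱ on the t observations seen so far sum to 1 − βᵗ rather than 1. Dividing by 1 − βᵗ renormalizes them, so a constant loss x is estimated as exactly x from step 1 onward; as t grows, βᵗ → 0 and the correction fades away.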
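A minimal sketch of the corrected estimator, assuming zero initialization; the helper name is illustrative. On a constant signal it returns the true value at every step, and on varying inputs it returns a weighted average whose weights sum to 1.

```python
def corrected_ema_trace(xs, beta):
    """Zero-initialized EMA with bias correction, one estimate per step."""
    mu, out = 0.0, []
    for t, x in enumerate(xs, start=1):
        mu = beta * mu + (1 - beta) * x
        out.append(mu / (1 - beta ** t))  # divide by total weight 1 - beta^t
    return out


constant = corrected_ema_trace([1000.0] * 100, beta=0.99)
print(constant[0], constant[9], constant[99])  # each equals 1000 up to float rounding

# Varying inputs: weights 0.5 (newest) and 0.25 (older), renormalized by 0.75
print(corrected_ema_trace([2.0, 4.0], beta=0.5))  # [2.0, (0.5*4 + 0.25*2) / 0.75]
```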


Multi-Task Normalization

EMA enables scale-invariant loss combination in AMNL.

The Scale Problem

```text
Typical raw loss magnitudes:
  L_RUL    ≈ 100 - 2000  (MSE on cycles)
  L_health ≈ 0.5 - 2.0   (cross-entropy)

Ratio: ~100-1000× difference

With λ = 0.5 for both (unnormalized):
  Total = 0.5 × 1000 + 0.5 × 1.0 = 500.5
  RUL accounts for ~99.9% of the total loss and typically dominates the gradient
```

EMA Normalization Solution

$$\mathcal{L}_{\text{normalized}} = \frac{\mathcal{L}}{\hat{\mu}}$$

Each loss is divided by its EMA, bringing all losses to approximately 1:

```text
After EMA normalization:
  L̃_RUL    = 1000 / 1000 = 1.0
  L̃_health = 1.0 / 1.0   = 1.0

With λ = 0.5 for both (normalized):
  Total = 0.5 × 1.0 + 0.5 × 1.0 = 1.0
  Each task contributes 50% of the total loss
```
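The same effect can be simulated over many steps. The sketch below assumes two synthetic loss streams drawn uniformly around 1000 (standing in for the RUL MSE) and around 1.0 (standing in for the health cross-entropy); these distributions are hypothetical, chosen only to mimic the scale gap described above.

```python
import random


def make_bias_corrected_ema(beta: float):
    """Return an update function holding zero-initialized, bias-corrected EMA state."""
    state = {"mu": 0.0, "t": 0}

    def update(x: float) -> float:
        state["t"] += 1
        state["mu"] = beta * state["mu"] + (1 - beta) * x
        return state["mu"] / (1 - beta ** state["t"])

    return update


random.seed(0)
ema_rul = make_bias_corrected_ema(0.99)
ema_health = make_bias_corrected_ema(0.99)

for step in range(500):
    l_rul = random.uniform(900.0, 1100.0)  # hypothetical MSE-scale loss
    l_health = random.uniform(0.9, 1.1)    # hypothetical cross-entropy-scale loss
    norm_rul = l_rul / ema_rul(l_rul)
    norm_health = l_health / ema_health(l_health)

# Share of the last step's total contributed by RUL, with lambda = 0.5 for both
raw_share = (0.5 * l_rul) / (0.5 * l_rul + 0.5 * l_health)
norm_share = (0.5 * norm_rul) / (0.5 * norm_rul + 0.5 * norm_health)
print(f"RUL share of total, raw:        {raw_share:.4f}")
print(f"RUL share of total, normalized: {norm_share:.4f}")
```

Unnormalized, RUL takes more than 99% of the total; after EMA normalization both normalized losses hover around 1, so the RUL share stays near 50%.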

Gradient Scaling Effect

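$$\nabla_\theta \frac{\mathcal{L}}{\hat{\mu}} = \frac{1}{\hat{\mu}} \nabla_\theta \mathcal{L}$$

The gradient is scaled inversely to the EMA. This identity holds because μ̂ is treated as a constant: it is computed from detached loss values (`loss.item()`), so no gradient flows through it. Tasks with large typical losses (high μ̂) have their gradients scaled down; tasks with small typical losses (low μ̂) have them scaled up. This rebalances gradient contributions automatically.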
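The role of detaching can be checked with finite differences in plain Python. The sketch uses a hypothetical one-parameter loss $\mathcal{L}(\theta) = 100\,\theta^2$; holding μ fixed reproduces the $\frac{1}{\mu}\nabla_\theta\mathcal{L}$ scaling, while the extreme case of letting the normalizer track the current loss inside the differentiated expression ($\mathcal{L}/\mathcal{L}$) cancels the gradient entirely.

```python
def loss(theta: float) -> float:
    """Hypothetical task loss with a large scale."""
    return 100.0 * theta ** 2


def numerical_grad(f, theta: float, eps: float = 1e-6) -> float:
    """Central finite-difference derivative of f at theta."""
    return (f(theta + eps) - f(theta - eps)) / (2 * eps)


theta = 2.0
mu = loss(theta)  # normalizer taken as a plain number, like loss.item()

grad_raw = numerical_grad(loss, theta)                              # dL/dtheta
grad_norm = numerical_grad(lambda th: loss(th) / mu, theta)         # mu held fixed
grad_self = numerical_grad(lambda th: loss(th) / loss(th), theta)   # mu not held fixed

print(grad_raw, grad_norm, grad_raw / mu, grad_self)
```

`grad_norm` matches `grad_raw / mu` (about 1.0 here), while `grad_self` is zero: the task would stop learning. This is why the implementation feeds the normalizer a Python float.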

Key Insight

EMA normalization is the mechanism that makes equal weights (0.5/0.5) actually mean equal contribution. Without it, "equal weights" would be meaningless due to scale differences.


Complete Implementation

A production-ready EMA normalizer with bias correction, a minimum-value floor, and reset support.

EMA Normalizer Class

```python
class EMANormalizer:
    """
    Exponential Moving Average normalizer for loss scaling.

    Tracks running average of loss magnitudes to enable
    scale-invariant loss normalization. Includes bias
    correction for accurate early-training estimates.

    Attributes:
        beta: Smoothing factor (0.99 recommended)
        ema: Current (uncorrected) EMA value, starting at 0
        steps: Number of update steps
        min_value: Minimum normalizer value (prevents division issues)
    """

    def __init__(
        self,
        beta: float = 0.99,
        min_value: float = 1e-8
    ):
        """
        Initialize EMA normalizer.

        Args:
            beta: Smoothing factor. Higher = more smoothing.
            min_value: Floor value to prevent division by zero.
        """
        if not 0 <= beta < 1:
            raise ValueError(f"beta must be in [0, 1), got {beta}")

        self.beta = beta
        self.min_value = min_value
        # Zero initialization: the bias-correction factor 1 / (1 - beta**t)
        # is derived for an EMA that starts at 0. Seeding with the first
        # observation and then correcting would overshoot by 1 / (1 - beta).
        self.ema = 0.0
        self.steps = 0

    def update(self, value: float) -> float:
        """
        Update EMA with new value and return bias-corrected estimate.

        Args:
            value: Current observation (loss value). Should be a detached
                   scalar (use .item() on tensor).

        Returns:
            Bias-corrected EMA for use as normalizer.
        """
        self.steps += 1

        # Standard EMA update (applied on the first step too, starting from 0)
        self.ema = self.beta * self.ema + (1 - self.beta) * value

        return self.get_value()

    def get_value(self) -> float:
        """
        Get current bias-corrected EMA without updating.

        Returns:
            Current bias-corrected EMA, or 1.0 if no update has happened yet.
        """
        if self.steps == 0:
            return 1.0

        # Bias correction: divide by the total weight 1 - beta^t
        corrected = self.ema / (1 - self.beta ** self.steps)

        # Apply floor
        return max(corrected, self.min_value)

    def reset(self) -> None:
        """Reset normalizer to initial state."""
        self.ema = 0.0
        self.steps = 0

    @property
    def effective_window(self) -> float:
        """Approximate number of steps in effective memory."""
        return 1.0 / (1.0 - self.beta)

    def __repr__(self) -> str:
        return (
            f"EMANormalizer(beta={self.beta}, "
            f"steps={self.steps}, "
            f"value={self.get_value():.4f})"
        )
```
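Bias correction is only valid when the EMA starts at zero. If the EMA is instead seeded with the first observation and the same correction is applied, the first estimate is $x/(1-\beta)$, 100× too large for β = 0.99, and it is still roughly 58% too high after 100 steps. The comparison below is a standalone sketch on a constant loss of 1000; `corrected_trace` is an illustrative helper, not part of the class above.

```python
def corrected_trace(xs, beta, seed_with_first):
    """Bias-corrected EMA per step, with zero or first-value initialization."""
    mu = None if seed_with_first else 0.0
    out = []
    for t, x in enumerate(xs, start=1):
        if mu is None:
            mu = x  # seed with the first observation
        else:
            mu = beta * mu + (1 - beta) * x
        out.append(mu / (1 - beta ** t))  # same correction in both cases
    return out


xs = [1000.0] * 100  # constant loss, so the true average is 1000
zero_init = corrected_trace(xs, beta=0.99, seed_with_first=False)
seeded = corrected_trace(xs, beta=0.99, seed_with_first=True)

for step in (1, 10, 100):
    print(f"step {step}: zero-init {zero_init[step - 1]:.1f}, seeded {seeded[step - 1]:.1f}")
```

Seeding with the first value is a reasonable alternative only if the bias correction is dropped, since a seeded EMA already has weights that sum to 1.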

Multi-Task Loss with EMA

```python
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

# EMANormalizer is the class defined above.


class EMANormalizedMultiTaskLoss(nn.Module):
    """
    Multi-task loss with EMA-based normalization.

    Automatically normalizes losses from different tasks
    to have similar magnitudes, enabling meaningful
    weighted combination.

    Args:
        task_names: Names of tasks for logging
        task_weights: Weight for each task (default: equal)
        beta: EMA smoothing factor
    """

    def __init__(
        self,
        task_names: List[str],
        task_weights: Optional[List[float]] = None,
        beta: float = 0.99
    ):
        super().__init__()
        self.task_names = task_names
        self.n_tasks = len(task_names)

        # Default to equal weights
        if task_weights is None:
            task_weights = [1.0 / self.n_tasks] * self.n_tasks
        if len(task_weights) != self.n_tasks:
            raise ValueError("task_weights must have one entry per task")
        self.task_weights = task_weights

        # One EMA normalizer per task. Stored in a plain dict, so this state
        # is not part of state_dict(); save it separately if training must
        # resume with the same normalizer values.
        self.normalizers = {
            name: EMANormalizer(beta=beta)
            for name in task_names
        }

    def forward(
        self,
        losses: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute normalized multi-task loss.

        Args:
            losses: Dictionary mapping task name to loss tensor

        Returns:
            Tuple of:
                - Total normalized loss (scalar tensor)
                - Metrics dictionary for logging
        """
        total_loss = 0.0
        metrics = {}

        for i, name in enumerate(self.task_names):
            raw_loss = losses[name]

            # Update EMA with a detached Python float, so the normalizer
            # acts as a constant during backpropagation
            normalizer = self.normalizers[name].update(raw_loss.item())

            # Normalize loss
            normalized_loss = raw_loss / normalizer

            # Weight and accumulate
            weighted_loss = self.task_weights[i] * normalized_loss
            total_loss = total_loss + weighted_loss

            # Record metrics
            metrics[f"loss/{name}_raw"] = raw_loss.item()
            metrics[f"loss/{name}_normalized"] = normalized_loss.item()
            metrics[f"ema/{name}"] = normalizer

        metrics["loss/total"] = total_loss.item() if isinstance(total_loss, torch.Tensor) else total_loss

        return total_loss, metrics

    def reset(self) -> None:
        """Reset all EMA normalizers."""
        for normalizer in self.normalizers.values():
            normalizer.reset()
```

Usage in AMNL

```python
# Initialize for AMNL
amnl = EMANormalizedMultiTaskLoss(
    task_names=["rul", "health"],
    task_weights=[0.5, 0.5],  # Equal weights
    beta=0.99
)

# Training loop. model, dataloader, rul_criterion, health_criterion,
# optimizer, and logger are assumed to be defined elsewhere; the batch
# layout below is illustrative.
for inputs, target_rul, target_health in dataloader:
    pred_rul, pred_health = model(inputs)

    # Compute individual losses
    rul_loss = rul_criterion(pred_rul, target_rul)
    health_loss = health_criterion(pred_health, target_health)

    # Combine with EMA normalization
    total_loss, metrics = amnl({
        "rul": rul_loss,
        "health": health_loss
    })

    # Backprop and optimize
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # Log metrics (logger is a placeholder for your experiment tracker)
    for key, value in metrics.items():
        logger.log(key, value)
```

Summary

In this section, we covered EMA-based adaptive scaling:

  1. EMA formula: $\mu = \beta\mu + (1-\beta)x$
  2. Bias correction: $\hat{\mu} = \mu/(1-\beta^t)$
  3. Normalization: $\tilde{\mathcal{L}} = \mathcal{L}/\hat{\mu}$
  4. Effect: Scale-invariant loss combination
  5. Recommended β: 0.99 (~100-step window)
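
| Parameter | Value | Effect |
|---|---|---|
| β | 0.99 | ~100 steps of memory |
| Bias correction | Yes | Accurate from step 1 |
| Min value | 1e-8 | Numerical stability |
| Effective window | ~100 steps | Roughly 1-2 epochs (depends on batches per epoch) |
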
Chapter Complete: You now have a complete toolkit of advanced loss components: weighted MSE, asymmetric penalties, focal loss, and EMA normalization. The next chapter covers optimization strategy—how to configure optimizers, learning rate schedules, and regularization for stable training.

With all loss components mastered, we proceed to optimization techniques.