Chapter 13
Section 64 of 104

Early Stopping with Best Weights

Training Enhancements

Learning Objectives

By the end of this section, you will:

  1. Understand why early stopping prevents overfitting
  2. Configure patience and minimum delta parameters
  3. Implement best weight restoration for optimal checkpoints
  4. Choose appropriate stopping criteria for RUL prediction
  5. Integrate early stopping with other training enhancements
Why This Matters: Neural networks can memorize training data if trained too long, leading to poor generalization. Early stopping monitors validation performance and halts training when the model stops improving, automatically finding the sweet spot between underfitting and overfitting.

Why Early Stopping?

Early stopping is one of the most effective regularization techniques for deep learning.

The Overfitting Problem

Training and validation loss typically follow different trajectories:

πŸ“text
1Training vs. Validation Loss Over Time:
2
3Loss
4  β”‚
5  β”‚   β•²
6  β”‚    β•²   Training Loss
7  β”‚     ╲─────────────────────────────→
8  β”‚      β•²
9  β”‚       β•²
10  β”‚        β•²___
11  β”‚            β•²__
12  β”‚               β•²___
13  β”‚         ╭──────────────────────────
14  β”‚        β•±     Validation Loss
15  β”‚       β•±        (starts rising!)
16  β”‚      β•±
17  β”‚     β”‚   ⬆️ Early stopping point
18  β”‚     β”‚
19  └─────┴────────────────────────────→ Epochs
20        β”‚
21    Optimalβ”‚
22     Stop β”‚
23
24Before this point: underfitting (both losses high)
25At this point: optimal generalization
26After this point: overfitting (val loss rises)

When to Stop

The goal is to stop training when validation performance stops improving. This requires:

  1. Monitoring validation loss (or another metric like RMSE)
  2. Detecting when improvement stops for a sustained period
  3. Saving the best weights encountered during training
  4. Restoring best weights when training ends
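The four steps above can be sketched as a minimal monitoring loop. This is a framework-free illustration with made-up loss values, not the AMNL implementation (that follows later in this section); `monitor` is a hypothetical helper:

```python
def monitor(val_losses, patience=3, min_delta=0.0):
    """Return (best_epoch, stop_epoch) for a sequence of validation losses."""
    best_loss = float('inf')
    best_epoch = -1
    counter = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss - min_delta:   # improvement detected
            best_loss, best_epoch = loss, epoch
            counter = 0                    # reset the patience counter
        else:
            counter += 1                   # no improvement this epoch
        if counter >= patience:            # sustained stagnation: stop
            return best_epoch, epoch
    return best_epoch, len(val_losses) - 1

# Loss improves for three epochs, then rises; training stops
# `patience` epochs after the best epoch.
print(monitor([1.0, 0.8, 0.6, 0.65, 0.7, 0.75, 0.9], patience=3))  # (2, 5)
```

Note that the stop epoch trails the best epoch by exactly the patience window, which is why restoring the best weights (step 4) matters.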

Early Stopping vs. Other Regularization

| Technique | How It Works | Computational Cost |
|---|---|---|
| Early stopping | Stop before overfitting | Saves time (shorter training) |
| Weight decay | Penalize large weights | Minimal overhead |
| Dropout | Random neuron masking | Slight overhead |
| Data augmentation | Increase data diversity | Data processing cost |

Early Stopping Saves Time

Unlike other regularization techniques that add cost, early stopping actually reduces training time by ending training early. It is the only regularization technique that speeds up training.


Patience and Minimum Delta

Two key parameters control early stopping behavior.

Patience

Patience is the number of epochs to wait for improvement before stopping:

| Patience | Behavior | Use Case |
|---|---|---|
| 10 | Aggressive stopping | Quick experiments, small datasets |
| 20-30 | Moderate | Standard training |
| 50-80 | Conservative | Long training, complex models |
| 100+ | Very conservative | Large models, slow convergence |

Minimum Delta (min_delta)

Minimum delta defines the threshold for what counts as "improvement":

$$\text{improvement} = \mathcal{L}_{\text{best}} - \mathcal{L}_{\text{current}} > \delta_{\min}$$

Setting min_delta > 0 prevents stopping on tiny, meaningless improvements:

| min_delta | Effect | Recommendation |
|---|---|---|
| 0 | Any improvement counts | May stop prematurely on noise |
| 0.0001 | Ignore tiny improvements | Standard choice |
| 0.001 | Only significant improvements | Conservative |
| 0.01 | Only major improvements | Very conservative |
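Plugging illustrative numbers into the improvement condition shows how the thresholds in the table behave (the loss values here are chosen purely for demonstration):

```python
def is_improvement(best_loss, current_loss, min_delta):
    # Improvement counts only if the drop exceeds the min_delta threshold
    return best_loss - current_loss > min_delta

best = 0.5
# A drop of 0.00005 counts as improvement only with min_delta = 0
print(is_improvement(best, 0.49995, min_delta=0.0))     # True
print(is_improvement(best, 0.49995, min_delta=0.0001))  # False
# A drop of 0.002 clears the standard threshold but not 0.01
print(is_improvement(best, 0.498, min_delta=0.0001))    # True
print(is_improvement(best, 0.498, min_delta=0.01))      # False
```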

AMNL Settings

For RUL prediction with AMNL, we use patience = 80 and min_delta = 0.0001. This conservative setting allows the model time to escape plateaus while ignoring noise.


Restoring Best Weights

The key insight: training often continues past the optimal point before stopping.

Why Restore Best Weights?

πŸ“text
1Validation Loss During Training:
2
3Loss
4  β”‚
5  β”‚         Best weights
6  β”‚              ↓
7  β”‚   ───────────●───────────────────
8  β”‚              β”‚    ╭──────────────
9  β”‚              β”‚   β•±   Validation loss
10  β”‚              β”‚  β•±    starts rising
11  β”‚              β”‚ β•±
12  β”‚              β”‚β•±
13  β”‚              β•±  ← Patience period starts
14  β”‚             β•±
15  β”‚            β•±     (20 epochs of no improvement)
16  β”‚           β•±
17  β”‚          β•±   Training stops here
18  β”‚         β•±    but this is NOT the best!
19  β”‚        ●─────────────────────────────→
20  └─────────────────────────────────────→ Epochs
21            β”‚                β”‚
22        Best epoch      Stop epoch
23
24Without restore: Use weights at stop epoch (suboptimal)
25With restore: Use weights at best epoch (optimal)

Implementation Strategy

At each epoch, if validation performance improves:

  1. Save a deep copy of model.state_dict()
  2. Update best_loss to current loss
  3. Reset patience counter to 0

When training stops:

  1. Load the saved best weights back into the model
  2. Use these restored weights for evaluation and deployment

Deep Copy Required

You must use copy.deepcopy(model.state_dict()) when saving best weights. state_dict() returns references to the live parameter tensors, so a shallow copy would keep pointing at objects that training continues to modify.
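The aliasing pitfall can be demonstrated without PyTorch. Below, a plain dict of lists stands in for the state dict; real state_dict() entries are tensors, but they alias the live parameters in the same way:

```python
import copy

state = {'layer.weight': [0.5, -0.3]}   # stand-in for model.state_dict()

shallow = dict(state)                    # copies the dict, not the lists inside
deep = copy.deepcopy(state)              # independent copy of everything

state['layer.weight'][0] = 9.9           # "training" keeps updating the weights

print(shallow['layer.weight'])  # [9.9, -0.3]: silently followed the update
print(deep['layer.weight'])     # [0.5, -0.3]: best weights preserved
```

In PyTorch, a lighter-weight alternative to deepcopy is to clone each tensor, e.g. `{k: v.detach().clone() for k, v in model.state_dict().items()}`.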


Implementation

Our research implementation provides a clean, efficient early stopping class with best weight restoration.

AMNL Research Implementation

EarlyStopping Class
`enhanced_train_nasa_cmapss_sota_v7.py`

Code walkthrough:

- **Class definition:** Enhanced early stopping that monitors validation loss and restores the best weights when training ends.
- **Patience parameter:** Number of epochs to wait without improvement before stopping. We use `patience=80` for AMNL to allow recovery from plateaus.
- **Minimum delta:** Threshold for what counts as "improvement"; prevents stopping on tiny, noise-driven improvements. For example, `min_delta=0.001` ignores improvements smaller than 0.001.
- **Restore best weights:** When `True`, loads the saved best weights when training stops. Critical for an optimal final model.
- **Call method:** Called at the end of each epoch with the validation loss. Returns `True` when training should stop.
- **Improvement check:** `val_loss` must be at least `min_delta` better than `best_loss` to count as an improvement.
- **Save best weights:** A deep copy of the model weights is taken when an improvement is detected. Essential for restoration at the end.
- **Patience check:** When the counter reaches `patience` without improvement, training should stop.
- **Weight restoration:** The saved best weights are loaded before returning `True`, so the final model uses the optimal checkpoint.
```python
import copy

class EarlyStopping:
    """Enhanced early stopping with multiple criteria"""

    def __init__(self, patience=20, min_delta=0.001, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = float('inf')
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.restore_best_weights:
                self.best_weights = copy.deepcopy(model.state_dict())
        else:
            self.counter += 1

        if self.counter >= self.patience:
            if self.restore_best_weights and self.best_weights is not None:
                model.load_state_dict(self.best_weights)
            return True
        return False
```

Integration with Training Loop

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def train_with_early_stopping(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    epochs: int,
    patience: int = 80,
    min_delta: float = 0.0001,
    device: torch.device = torch.device('cuda')
) -> dict:
    """
    Training loop with early stopping.

    Returns:
        Dictionary with training history and best epoch
    """
    early_stopping = EarlyStopping(
        patience=patience,
        min_delta=min_delta,
        restore_best_weights=True
    )

    history = {
        'train_loss': [],
        'val_loss': [],
        'best_epoch': -1
    }

    for epoch in range(epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        for batch in train_loader:
            x, y = batch
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)

        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                x, y = batch
                x, y = x.to(device), y.to(device)
                pred = model(x)
                val_loss += criterion(pred, y).item()

        avg_val_loss = val_loss / len(val_loader)

        # Record history
        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)

        # Track best epoch with the same min_delta threshold as the stopper
        # (checked before __call__ updates early_stopping.best_loss)
        if avg_val_loss < early_stopping.best_loss - min_delta:
            history['best_epoch'] = epoch

        # Early stopping check
        if early_stopping(avg_val_loss, model):
            print(f"Early stopping triggered at epoch {epoch + 1}")
            print(f"Best epoch was {history['best_epoch'] + 1} "
                  f"with val_loss: {early_stopping.best_loss:.4f}")
            break

        # Progress logging
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch + 1}/{epochs} | "
                  f"Train: {avg_train_loss:.4f} | "
                  f"Val: {avg_val_loss:.4f} | "
                  f"Patience: {early_stopping.counter}/{patience}")

    return history
```

Using RMSE Instead of Loss

For RUL prediction, we typically monitor RMSE rather than raw loss:

```python
import copy

import torch.nn as nn

class EarlyStoppingRMSE:
    """
    Early stopping based on validation RMSE.

    For RUL prediction, RMSE is more interpretable than raw loss.
    """

    def __init__(
        self,
        patience: int = 80,
        min_delta: float = 0.01,  # RMSE units (e.g., 0.01 cycles)
        restore_best_weights: bool = True
    ):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights

        self.best_rmse = float('inf')
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_rmse: float, model: nn.Module) -> bool:
        """
        Check if training should stop based on RMSE.

        Args:
            val_rmse: Current validation RMSE
            model: Model to potentially save/restore weights

        Returns:
            True if training should stop, False otherwise
        """
        if val_rmse < self.best_rmse - self.min_delta:
            self.best_rmse = val_rmse
            self.counter = 0

            if self.restore_best_weights:
                self.best_weights = copy.deepcopy(model.state_dict())
        else:
            self.counter += 1

        if self.counter >= self.patience:
            if self.restore_best_weights and self.best_weights is not None:
                model.load_state_dict(self.best_weights)
            return True

        return False
```

Summary

In this section, we covered early stopping with best weights:

  1. Purpose: Prevent overfitting by stopping when validation performance degrades
  2. Patience: Number of epochs to wait before stopping (80 for AMNL)
  3. min_delta: Threshold for meaningful improvement (0.0001)
  4. Best weights: Deep copy of weights at best validation performance
  5. Restoration: Load best weights when training ends
| Parameter | Value |
|---|---|
| Patience | 80 epochs |
| min_delta | 0.0001 |
| Restore best weights | Yes |
| Monitoring metric | Validation RMSE |
Looking Ahead: Early stopping tells us when to stop. The next section covers mixed precision training, a technique that uses 16-bit floats to speed up training by 2-3× while maintaining accuracy.

With early stopping configured, we explore mixed precision training.