Chapter 14

Validation and Model Selection

Complete Training Script

Learning Objectives

By the end of this section, you will:

  1. Understand validation strategies for RUL prediction
  2. Choose appropriate selection metrics (RMSE vs. NASA score)
  3. Implement comprehensive evaluation during training
  4. Balance RUL and health classification metrics
  5. Use EMA weights for evaluation
Why This Matters: The model checkpoint you select for deployment determines real-world performance. Choosing the right validation metric and evaluation strategy ensures you deploy the model that generalizes best, not just the one that happened to perform well on one epoch.

Validation Strategy

C-MAPSS datasets have a specific train/test split that we must respect.

Dataset Structure

The NASA C-MAPSS benchmark provides fixed train/test splits:

| Dataset | Train Engines | Test Engines | Train Samples | Test Samples |
|---------|---------------|--------------|---------------|--------------|
| FD001   | 100           | 100          | ~17,700       | ~100         |
| FD002   | 260           | 259          | ~48,800       | ~259         |
| FD003   | 100           | 100          | ~21,800       | ~100         |
| FD004   | 249           | 248          | ~57,500       | ~248         |

No Validation Split from Training

Unlike typical machine learning, we do not create a validation split from training data because:

  1. The benchmark requires using all training data for training
  2. Test set represents the "real" validation (last-cycle evaluation)
  3. Published results use the full train set

Evaluation Protocol

We evaluate on the test set after each epoch to monitor progress and select the best model. This is standard practice for the C-MAPSS benchmark and allows comparison with published results.

Last-Cycle vs. All-Cycles Evaluation

Test evaluation can use different subsets of predictions:

| Mode       | What It Measures                 | Use Case                 |
|------------|----------------------------------|--------------------------|
| Last-cycle | Final prediction per engine      | Primary benchmark metric |
| All-cycles | All predictions during operation | Model consistency        |

$$\text{RMSE}_{\text{last}} = \sqrt{\frac{1}{N} \sum_{i=1}^{N} \left(y_i^{\text{last}} - \hat{y}_i^{\text{last}}\right)^2}$$

Where $N$ is the number of test engines, and "last" refers to the final operating cycle of each engine in the test set.


Model Selection Criteria

Which metric should drive model selection?

Primary Metric: RMSE (Last-Cycle)

We select models based on last-cycle RMSE for several reasons:

  • Interpretable: Units are cycles, directly meaningful
  • Comparable: Standard metric across all published papers
  • Stable: Less sensitive to outliers than NASA score
  • Differentiable: Continuous improvements are visible

Secondary Metric: NASA Score

The NASA asymmetric scoring function penalizes late predictions more heavily:

$$S = \sum_{i=1}^{N} \begin{cases} e^{-d_i/13} - 1 & \text{if } d_i < 0 \text{ (early)} \\ e^{d_i/10} - 1 & \text{if } d_i \geq 0 \text{ (late)} \end{cases}$$

Where $d_i = \hat{y}_i - y_i$ is the prediction error.
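The evaluation code later in this section calls a `compute_nasa_score` helper. A minimal sketch of that function, assuming `predictions` and `targets` are NumPy arrays of RUL values in cycles:

```python
import numpy as np

def compute_nasa_score(predictions: np.ndarray, targets: np.ndarray) -> float:
    """NASA asymmetric scoring function for C-MAPSS.

    Early predictions (d < 0) are penalized as exp(-d/13) - 1; late
    predictions (d >= 0) as exp(d/10) - 1, so lateness costs more.
    """
    d = predictions - targets  # d_i = y_hat_i - y_i
    score = np.where(d < 0, np.exp(-d / 13.0) - 1.0, np.exp(d / 10.0) - 1.0)
    return float(np.sum(score))
```

Note the asymmetry: a prediction 10 cycles late scores about `exp(1) - 1 ≈ 1.72`, while one 10 cycles early scores only about `exp(10/13) - 1 ≈ 1.16`. The `compute_nasa_score_last_cycle` variant used later would apply the same formula restricted to each engine's final prediction.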

NASA Score Volatility

NASA score can swing dramatically due to a single bad prediction (exponential penalty). While important for final evaluation, it is too unstable for model selection during training.

Health Classification Accuracy

For the dual-task model, we also track health state classification:

| Metric           | Purpose                   | Target             |
|------------------|---------------------------|--------------------|
| Accuracy         | Overall correctness       | >80%               |
| F1 Score         | Class-balanced performance | >75%              |
| Per-class recall | Critical state detection  | >70% for critical  |
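The evaluation code below derives health-state labels from RUL via a `rul_to_health_state` helper. A minimal sketch, assuming a three-state encoding (0 = healthy, 1 = degraded, 2 = critical) with hypothetical thresholds of 80 and 30 cycles; adjust both to match your labeling scheme:

```python
import numpy as np

# Hypothetical thresholds (cycles): degraded below 80, critical below 30.
HEALTH_THRESHOLDS = (30, 80)

def rul_to_health_state(rul: np.ndarray) -> np.ndarray:
    """Map continuous RUL values to discrete health states.

    0 = healthy, 1 = degraded, 2 = critical (assumed encoding).
    """
    states = np.zeros_like(rul, dtype=np.int64)
    states[rul < HEALTH_THRESHOLDS[1]] = 1  # degraded
    states[rul < HEALTH_THRESHOLDS[0]] = 2  # critical
    return states
```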

Comprehensive Evaluation

A complete evaluation covers multiple aspects of model performance.

Metrics Computed Each Epoch

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score


def evaluate_model_comprehensive(
    model: nn.Module,
    test_dataset,
    device: torch.device
) -> dict:
    """
    Comprehensive model evaluation.

    Computes all metrics needed for model selection and analysis.

    Returns:
        Dictionary with all evaluation metrics
    """
    model.eval()

    all_preds = []
    all_targets = []
    all_health_preds = []
    all_health_targets = []

    with torch.no_grad():
        for sequences, targets in DataLoader(test_dataset, batch_size=256):
            sequences = sequences.to(device)

            # Dual-task model returns RUL predictions and health logits
            rul_pred, health_pred = model(sequences)

            all_preds.extend(rul_pred.cpu().numpy().flatten())
            all_targets.extend(targets.numpy().flatten())

            # Health predictions: argmax over class logits
            health_preds = health_pred.argmax(dim=1).cpu().numpy()
            health_targets = rul_to_health_state(targets.numpy())

            all_health_preds.extend(health_preds)
            all_health_targets.extend(health_targets)

    preds = np.array(all_preds)
    targets = np.array(all_targets)

    # RUL metrics
    results = {
        # All-cycles metrics
        'RMSE_all_cycles': np.sqrt(np.mean((preds - targets) ** 2)),
        'MAE_all_cycles': np.mean(np.abs(preds - targets)),
        'R2_all_cycles': 1 - np.sum((targets - preds) ** 2) / np.sum((targets - targets.mean()) ** 2),

        # Last-cycle metrics (one per engine)
        'RMSE_last_cycle': compute_last_cycle_rmse(preds, targets, test_dataset),
        'nasa_score_raw': compute_nasa_score(preds, targets),
        'nasa_score_paper': compute_nasa_score_last_cycle(preds, targets, test_dataset),

        # Health classification
        'health_accuracy': accuracy_score(all_health_targets, all_health_preds) * 100,
        'health_f1': f1_score(all_health_targets, all_health_preds, average='weighted') * 100,

        # Prediction analysis
        'early_predictions': int(np.sum(preds < targets)),
        'late_predictions': int(np.sum(preds > targets)),
        'n_total_predictions': len(preds)
    }

    return results
```

Last-Cycle RMSE Computation

```python
import numpy as np

def compute_last_cycle_rmse(
    predictions: np.ndarray,
    targets: np.ndarray,
    dataset
) -> float:
    """
    Compute RMSE using only the last cycle of each engine.

    This is the primary benchmark metric for C-MAPSS.
    """
    # Get engine IDs and find the last cycle for each engine
    engine_ids = dataset.engine_ids

    last_cycle_preds = []
    last_cycle_targets = []

    unique_engines = np.unique(engine_ids)
    for engine in unique_engines:
        mask = engine_ids == engine
        engine_preds = predictions[mask]
        engine_targets = targets[mask]

        # Last prediction for this engine
        last_cycle_preds.append(engine_preds[-1])
        last_cycle_targets.append(engine_targets[-1])

    last_cycle_preds = np.array(last_cycle_preds)
    last_cycle_targets = np.array(last_cycle_targets)

    return np.sqrt(np.mean((last_cycle_preds - last_cycle_targets) ** 2))
```

Implementation

Complete evaluation integration in the training loop.

Evaluation with EMA

```python
import torch
import torch.nn as nn

def evaluate_with_ema(
    model: nn.Module,
    ema: "ExponentialMovingAverage",
    test_dataset,
    device: torch.device
) -> dict:
    """
    Evaluate model using EMA weights.

    Temporarily replaces model weights with EMA weights,
    evaluates, then restores original weights.
    """
    model.eval()

    # Apply EMA weights
    ema.apply_shadow(model)

    # Comprehensive evaluation
    results = evaluate_model_comprehensive(model, test_dataset, device)

    # Restore training weights
    ema.restore(model)

    return results
```
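The `ExponentialMovingAverage` helper above is assumed to expose `apply_shadow` and `restore`. One possible minimal implementation, a sketch in which the decay value and the `update` method name are assumptions:

```python
import copy
import torch
import torch.nn as nn

class ExponentialMovingAverage:
    """Minimal EMA helper with the apply_shadow/restore interface."""

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        # Shadow copy of all parameters and buffers
        self.shadow = {k: v.detach().clone() for k, v in model.state_dict().items()}
        self.backup = {}

    def update(self, model: nn.Module):
        """Call after each optimizer step: shadow = decay*shadow + (1-decay)*current."""
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1 - self.decay)
            else:
                self.shadow[k] = v.detach().clone()

    def apply_shadow(self, model: nn.Module):
        """Swap EMA weights in, keeping a backup of the training weights."""
        self.backup = copy.deepcopy(model.state_dict())
        model.load_state_dict(self.shadow)

    def restore(self, model: nn.Module):
        """Swap the training weights back in."""
        model.load_state_dict(self.backup)
        self.backup = {}
```

The swap-evaluate-restore pattern lets training continue from the live weights while evaluation always sees the smoothed ones.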

Model Selection Logic

```python
import copy
import torch.nn as nn

class ModelSelector:
    """
    Tracks the best model based on a validation metric.
    """

    def __init__(self, metric: str = 'RMSE_last_cycle', mode: str = 'min'):
        self.metric = metric
        self.mode = mode
        self.best_value = float('inf') if mode == 'min' else float('-inf')
        self.best_epoch = -1
        self.best_state = None

    def update(self, epoch: int, results: dict, model: nn.Module) -> bool:
        """
        Check if the current model is the best so far and save it if so.

        Returns True if this is a new best model.
        """
        current_value = results[self.metric]

        is_better = (
            (self.mode == 'min' and current_value < self.best_value) or
            (self.mode == 'max' and current_value > self.best_value)
        )

        if is_better:
            self.best_value = current_value
            self.best_epoch = epoch
            self.best_state = copy.deepcopy(model.state_dict())
            return True

        return False

    def restore_best(self, model: nn.Module):
        """Load the best model weights."""
        if self.best_state is not None:
            model.load_state_dict(self.best_state)
```

Integration in Training Loop

```python
# In the main training loop
model_selector = ModelSelector(metric='RMSE_last_cycle', mode='min')

for epoch in range(epochs):
    # Training phase
    train_results = train_epoch(model, train_loader, ...)

    # Evaluation phase
    if ema is not None:
        eval_results = evaluate_with_ema(model, ema, test_dataset, device)
    else:
        eval_results = evaluate_model_comprehensive(model, test_dataset, device)

    # Model selection
    is_best = model_selector.update(epoch, eval_results, model)

    if is_best:
        print(f"New best model! RMSE: {eval_results['RMSE_last_cycle']:.2f}")

    # Early stopping check
    if early_stopping(eval_results['RMSE_last_cycle'], model):
        break

# Restore the best model at the end
model_selector.restore_best(model)
```
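The `early_stopping` callable in the loop above is assumed to track a patience counter on the monitored metric. A minimal sketch, where the `patience` and `min_delta` defaults are assumptions:

```python
class EarlyStopping:
    """Signal a stop when the monitored metric stops improving."""

    def __init__(self, patience: int = 15, min_delta: float = 0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum change that counts as improvement
        self.best = float('inf')
        self.counter = 0

    def __call__(self, value: float, model=None) -> bool:
        """Return True when training should stop (metric is minimized)."""
        if value < self.best - self.min_delta:
            self.best = value
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience
```

Because `ModelSelector` keeps the best checkpoint separately, early stopping only decides *when* to stop; it never affects *which* weights are restored.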

Summary

In this section, we covered validation and model selection:

  1. C-MAPSS protocol: Fixed train/test splits, evaluate on test each epoch
  2. Primary metric: RMSE (last-cycle) for model selection
  3. Secondary metrics: NASA score, health accuracy, F1
  4. EMA evaluation: Use smoothed weights for evaluation
  5. Model selector: Track and restore best checkpoint

| Metric            | Purpose            | Selection Criterion        |
|-------------------|--------------------|----------------------------|
| RMSE (last-cycle) | Primary benchmark  | Minimize (model selection) |
| NASA Score        | Asymmetric penalty | Report for comparison      |
| Health Accuracy   | Classification task | Secondary check           |
| R² Score          | Variance explained | Diagnostic                 |
Looking Ahead: With validation and model selection configured, we need to track training progress effectively. The next section covers training monitoring and logging.
