Chapter 15

Health Classification Metrics

Evaluation Metrics

Learning Objectives

By the end of this section, you will:

  1. Understand health state discretization from continuous RUL
  2. Compute classification metrics (accuracy, precision, recall, F1)
  3. Handle class imbalance in health state distribution
  4. Analyze per-class performance with confusion matrices
  5. Evaluate the secondary task in dual-task learning
Why This Matters: The AMNL model jointly predicts RUL (regression) and health state (classification). The health classification task provides regularization that improves RUL prediction. Evaluating both tasks gives a complete picture of model performance.

Health State Definition

Health states are derived from continuous RUL values using threshold discretization.

Three-Class Health State Model

| State | RUL Range | Meaning | Action |
| --- | --- | --- | --- |
| Healthy (0) | RUL > 50 | Normal operation | Continue monitoring |
| Degrading (1) | 15 < RUL ≤ 50 | Noticeable wear | Schedule maintenance |
| Critical (2) | RUL ≤ 15 | Imminent failure | Immediate intervention |

Discretization Function

```python
import numpy as np

def rul_to_health_state(rul_values: np.ndarray) -> np.ndarray:
    """
    Convert continuous RUL to discrete health states.

    Health states:
        0: Healthy    (RUL > 50)
        1: Degrading  (15 < RUL <= 50)
        2: Critical   (RUL <= 15)

    Args:
        rul_values: Array of RUL values

    Returns:
        Array of health state labels (0, 1, or 2)
    """
    health_states = np.zeros_like(rul_values, dtype=np.int64)

    # Degrading: 15 < RUL <= 50
    health_states[(rul_values > 15) & (rul_values <= 50)] = 1

    # Critical: RUL <= 15
    health_states[rul_values <= 15] = 2

    return health_states
```
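The same mapping can also be written as a vectorized one-liner with `np.digitize`. This is a sketch equivalent to the function above (the name `rul_to_health_state_binned` is ours), with bin edges matching the 15 and 50 thresholds:

```python
import numpy as np

def rul_to_health_state_binned(rul_values: np.ndarray) -> np.ndarray:
    """Equivalent discretization via np.digitize."""
    # With right=True, digitize returns 0 for RUL <= 15, 1 for 15 < RUL <= 50,
    # and 2 for RUL > 50; subtracting from 2 reverses the order so that
    # 0 = Healthy, 1 = Degrading, 2 = Critical.
    return 2 - np.digitize(rul_values, bins=[15, 50], right=True)

print(rul_to_health_state_binned(np.array([120.0, 50.0, 15.0, 3.0])))  # → [0 1 2 2]
```

Note that `right=True` is what makes the boundary cases (RUL exactly 50 or 15) land in the lower state, matching the table above.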

Class Distribution

The health state distribution varies across datasets:

| Dataset | Healthy | Degrading | Critical |
| --- | --- | --- | --- |
| FD001 | ~48% | ~32% | ~20% |
| FD002 | ~52% | ~30% | ~18% |
| FD003 | ~47% | ~33% | ~20% |
| FD004 | ~51% | ~31% | ~18% |

Class Imbalance

The critical class is underrepresented (~18-20%) compared to healthy (~48-52%). This imbalance affects metric interpretation—high overall accuracy may hide poor critical-class performance.
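One common mitigation during training is to weight each class inversely to its frequency, as in scikit-learn's "balanced" heuristic: weight N / (C · n_c) for class c. A sketch with illustrative counts roughly matching the FD001 proportions above:

```python
import numpy as np

# Illustrative class counts (~48% / ~32% / ~20% of 10,000 samples)
counts = np.array([4800, 3200, 2000])  # Healthy, Degrading, Critical
n_classes = len(counts)

# "Balanced" weights: N / (C * n_c); rarer classes get larger weights
weights = counts.sum() / (n_classes * counts)
print(weights.round(3))  # Critical gets the largest weight
```

Weights like these can be passed to a class-weighted loss so that critical-class errors contribute more to the gradient.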


Classification Metrics

Standard classification metrics for multi-class health prediction.

Accuracy

$$\text{Accuracy} = \frac{\text{Correct Predictions}}{\text{Total Predictions}} = \frac{TP + TN}{TP + TN + FP + FN}$$

For multi-class, this is the fraction of all predictions that are correct.

Precision, Recall, and F1

Per-class metrics provide deeper insight:

$$\text{Precision}_c = \frac{TP_c}{TP_c + FP_c}$$

$$\text{Recall}_c = \frac{TP_c}{TP_c + FN_c}$$

$$\text{F1}_c = 2 \cdot \frac{\text{Precision}_c \cdot \text{Recall}_c}{\text{Precision}_c + \text{Recall}_c}$$

| Metric | Meaning | High Value Indicates |
| --- | --- | --- |
| Precision | Of predicted class c, how many are correct | Few false alarms |
| Recall | Of actual class c, how many were found | Few missed cases |
| F1 Score | Harmonic mean of precision and recall | Balanced performance |

Weighted F1 Score

For imbalanced classes, weighted F1 accounts for class frequency:

$$\text{F1}_{\text{weighted}} = \sum_{c=1}^{C} \frac{n_c}{N} \cdot \text{F1}_c$$

where $n_c$ is the number of samples in class $c$ and $N$ is the total number of samples.
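As a sanity check, the formula can be evaluated by hand with a few lines of NumPy. The per-class F1 scores and supports here are illustrative:

```python
import numpy as np

# Illustrative per-class F1 scores and supports (Healthy, Degrading, Critical)
f1_per_class = np.array([0.93, 0.75, 0.83])
supports = np.array([5000, 3000, 2000])  # n_c

# F1_weighted = sum_c (n_c / N) * F1_c
f1_weighted = np.sum(supports / supports.sum() * f1_per_class)
print(round(float(f1_weighted), 3))  # → 0.856
```

With these numbers the weighted F1 (0.856) sits closer to the healthy-class F1 than the macro average would, because the healthy class holds half the samples.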


Handling Class Imbalance

Strategies for evaluating and reporting with imbalanced classes.

Per-Class Analysis

Always report metrics for each class separately:

| Class | Precision | Recall | F1 | Support |
| --- | --- | --- | --- | --- |
| Healthy | 0.92 | 0.95 | 0.93 | 5,000 |
| Degrading | 0.78 | 0.72 | 0.75 | 3,000 |
| Critical | 0.85 | 0.81 | 0.83 | 2,000 |

Confusion Matrix

The confusion matrix reveals error patterns:

```text
                  Predicted
                 Healthy  Degrading  Critical
Actual Healthy     4750      200        50
       Degrading    300     2160       540
       Critical      50      330      1620
```

Key insights:

- Healthy is well detected (95% recall)
- Degrading is confused with Critical (18% of Degrading samples predicted Critical)
- Critical is sometimes missed (16.5% of Critical samples predicted Degrading)
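Per-class recall and precision fall directly out of the matrix rows and columns. A short sketch using the numbers above:

```python
import numpy as np

# Rows = actual class, columns = predicted class (Healthy, Degrading, Critical)
cm = np.array([
    [4750,  200,   50],
    [ 300, 2160,  540],
    [  50,  330, 1620],
])

recall = np.diag(cm) / cm.sum(axis=1)     # correct / actual count per class
precision = np.diag(cm) / cm.sum(axis=0)  # correct / predicted count per class

print(recall.round(3))  # → [0.95 0.72 0.81]
```

Reading along a row shows where a class's samples leak to; reading down a column shows what a predicted label is contaminated by.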

Critical Class Focus

For safety-critical applications, prioritize critical-class recall:

Critical Class Recall

Missing a critical health state (predicting healthy when actually critical) is dangerous. Target ≥80% critical recall even if it means lower precision (more false alarms are acceptable).
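One post-hoc way to make that precision/recall trade (a hypothetical sketch, not part of the AMNL model itself) is to override the argmax decision whenever the predicted critical probability clears a lowered threshold:

```python
import numpy as np

def predict_with_critical_bias(probs: np.ndarray, critical_threshold: float = 0.25) -> np.ndarray:
    """Argmax prediction, except predict Critical (class 2) whenever
    its probability exceeds the threshold, even if it is not the argmax."""
    preds = probs.argmax(axis=1)
    preds[probs[:, 2] > critical_threshold] = 2
    return preds

# Two samples where plain argmax would say Healthy / Degrading,
# but the Critical probability is non-trivial
probs = np.array([
    [0.50, 0.20, 0.30],
    [0.10, 0.55, 0.35],
])
print(predict_with_critical_bias(probs))        # → [2 2]  (both flagged Critical)
print(predict_with_critical_bias(probs, 0.40))  # → [0 1]  (higher threshold: plain argmax)
```

Lowering the threshold raises critical recall at the cost of more false critical alarms; the threshold itself is a tunable safety parameter, best chosen on a validation set.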


Implementation

Complete implementation for health classification evaluation.

Comprehensive Classification Metrics

```python
from typing import Dict, List

import numpy as np
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    classification_report
)

def evaluate_health_classification(
    predictions: np.ndarray,
    targets: np.ndarray,
    class_names: List[str] = ['Healthy', 'Degrading', 'Critical']
) -> Dict:
    """
    Comprehensive evaluation of health state classification.

    Args:
        predictions: Predicted health states (0, 1, 2)
        targets: True health states (0, 1, 2)
        class_names: Names for each class

    Returns:
        Dictionary with all classification metrics
    """
    results = {}

    # Overall accuracy
    results['accuracy'] = float(accuracy_score(targets, predictions) * 100)

    # Weighted F1 (accounts for class imbalance)
    results['f1_weighted'] = float(
        f1_score(targets, predictions, average='weighted') * 100
    )

    # Macro F1 (equal weight per class)
    results['f1_macro'] = float(
        f1_score(targets, predictions, average='macro') * 100
    )

    # Per-class metrics
    precision_per_class = precision_score(
        targets, predictions, average=None, zero_division=0
    )
    recall_per_class = recall_score(
        targets, predictions, average=None, zero_division=0
    )
    f1_per_class = f1_score(
        targets, predictions, average=None, zero_division=0
    )

    for i, name in enumerate(class_names):
        results[f'precision_{name.lower()}'] = float(precision_per_class[i] * 100)
        results[f'recall_{name.lower()}'] = float(recall_per_class[i] * 100)
        results[f'f1_{name.lower()}'] = float(f1_per_class[i] * 100)

    # Confusion matrix
    results['confusion_matrix'] = confusion_matrix(targets, predictions).tolist()

    # Class distribution
    for i, name in enumerate(class_names):
        results[f'support_{name.lower()}'] = int(np.sum(targets == i))

    return results
```

Health Evaluation in Model Training

```python
from typing import Dict

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score

def evaluate_model_health_task(
    model: nn.Module,
    test_loader: DataLoader,
    device: torch.device
) -> Dict:
    """
    Evaluate health classification task during training.

    Args:
        model: Dual-task model
        test_loader: Test data loader
        device: Compute device

    Returns:
        Health classification metrics
    """
    model.eval()

    all_health_preds = []
    all_health_targets = []

    with torch.no_grad():
        for sequences, rul_targets in test_loader:
            sequences = sequences.to(device)

            # Get model predictions
            rul_pred, health_logits = model(sequences)

            # Convert logits to class predictions
            health_preds = health_logits.argmax(dim=1).cpu().numpy()

            # Convert RUL to health states for targets
            # (uses rul_to_health_state defined earlier)
            health_targets = rul_to_health_state(rul_targets.numpy())

            all_health_preds.extend(health_preds)
            all_health_targets.extend(health_targets)

    all_health_preds = np.array(all_health_preds)
    all_health_targets = np.array(all_health_targets)

    # Compute metrics
    results = {
        'health_accuracy': float(
            accuracy_score(all_health_targets, all_health_preds) * 100
        ),
        'health_f1': float(
            f1_score(all_health_targets, all_health_preds, average='weighted') * 100
        ),
    }

    # Add critical-class recall (safety-critical)
    critical_mask = all_health_targets == 2
    if np.sum(critical_mask) > 0:
        critical_correct = np.sum(
            (all_health_preds == 2) & critical_mask
        )
        results['critical_recall'] = float(
            critical_correct / np.sum(critical_mask) * 100
        )

    return results
```

Printing Classification Report

```python
# Generate detailed classification report
report = classification_report(
    all_health_targets,
    all_health_preds,
    target_names=['Healthy', 'Degrading', 'Critical'],
    digits=3
)

print("Health State Classification Report:")
print(report)

# Example output:
#               precision    recall  f1-score   support
#
#      Healthy      0.921     0.950     0.935      5000
#    Degrading      0.784     0.720     0.751      3000
#     Critical      0.852     0.810     0.830      2000
#
#     accuracy                          0.862     10000
#    macro avg      0.852     0.827     0.839     10000
# weighted avg      0.860     0.862     0.860     10000
```

Summary

In this section, we covered health classification metrics:

  1. Health states: Healthy (RUL > 50), Degrading (15 < RUL ≤ 50), Critical (RUL ≤ 15)
  2. Key metrics: Accuracy, precision, recall, F1 score
  3. Weighted F1: Accounts for class imbalance
  4. Critical recall: Most important for safety
  5. Confusion matrix: Reveals misclassification patterns
| Metric | Target | Why Important |
| --- | --- | --- |
| Accuracy | >80% | Overall correctness |
| Weighted F1 | >75% | Class-balanced performance |
| Critical Recall | >80% | Safety (don't miss failures) |
| Critical Precision | >70% | Avoid false alarms |
Looking Ahead: We now have all the individual metrics—RMSE, NASA score, and health classification. The final section brings them together in a comprehensive evaluation pipeline that provides a complete assessment of model performance.

With classification metrics covered, we build the complete evaluation pipeline.