Learning Objectives
By the end of this section, you will:
- Understand health state discretization from continuous RUL
- Compute classification metrics (accuracy, precision, recall, F1)
- Handle class imbalance in health state distribution
- Analyze per-class performance with confusion matrices
- Evaluate the secondary task in dual-task learning
Why This Matters: The AMNL model jointly predicts RUL (regression) and health state (classification). The health classification task provides regularization that improves RUL prediction. Evaluating both tasks gives a complete picture of model performance.
Health State Definition
Health states are derived from continuous RUL values using threshold discretization.
Three-Class Health State Model
| State | RUL Range | Meaning | Action |
|---|---|---|---|
| Healthy (0) | RUL > 50 | Normal operation | Continue monitoring |
| Degrading (1) | 15 < RUL ≤ 50 | Noticeable wear | Schedule maintenance |
| Critical (2) | RUL ≤ 15 | Imminent failure | Immediate intervention |
Discretization Function
```python
import numpy as np

def rul_to_health_state(rul_values: np.ndarray) -> np.ndarray:
    """
    Convert continuous RUL to discrete health states.

    Health states:
        0: Healthy   (RUL > 50)
        1: Degrading (15 < RUL <= 50)
        2: Critical  (RUL <= 15)

    Args:
        rul_values: Array of RUL values

    Returns:
        Array of health state labels (0, 1, or 2)
    """
    # Default everything to Healthy (0)
    health_states = np.zeros_like(rul_values, dtype=np.int64)

    # Degrading: 15 < RUL <= 50
    health_states[(rul_values > 15) & (rul_values <= 50)] = 1

    # Critical: RUL <= 15
    health_states[rul_values <= 15] = 2

    return health_states
```
Class Distribution
The health state distribution varies across datasets:
| Dataset | Healthy | Degrading | Critical |
|---|---|---|---|
| FD001 | ~48% | ~32% | ~20% |
| FD002 | ~52% | ~30% | ~18% |
| FD003 | ~47% | ~33% | ~20% |
| FD004 | ~51% | ~31% | ~18% |
Class Imbalance
The critical class is underrepresented (~18-20%) compared to healthy (~48-52%). This imbalance affects metric interpretation—high overall accuracy may hide poor critical-class performance.
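To see how imbalance distorts accuracy, consider a minimal sketch (the labels are synthetic, drawn to roughly match the FD001 split above): a degenerate classifier that always predicts Healthy scores near 48% accuracy while missing every single critical case.

```python
import numpy as np
from sklearn.metrics import accuracy_score, recall_score

# Synthetic labels roughly matching the FD001 distribution (~48/32/20%)
rng = np.random.default_rng(0)
targets = rng.choice([0, 1, 2], size=10_000, p=[0.48, 0.32, 0.20])

# A degenerate "always Healthy" baseline
preds = np.zeros_like(targets)

acc = accuracy_score(targets, preds)
critical_recall = recall_score(
    targets, preds, labels=[2], average=None, zero_division=0
)[0]

print(f"Accuracy: {acc:.1%}")                     # near 48%, from the majority class alone
print(f"Critical recall: {critical_recall:.1%}")  # 0% — every failure is missed
```

This is why the rest of this section insists on per-class reporting rather than a single accuracy number.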
Classification Metrics
Standard classification metrics for multi-class health prediction.
Accuracy
For multi-class problems, accuracy is the fraction of all predictions that are correct across every class.
Precision, Recall, and F1
Per-class metrics provide deeper insight:
| Metric | Meaning | High Value Indicates |
|---|---|---|
| Precision | Of predicted class c, how many are correct | Few false alarms |
| Recall | Of actual class c, how many were found | Few missed cases |
| F1 Score | Harmonic mean of precision and recall | Balanced performance |
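For a single class c, these metrics reduce to counts of true positives (TP), false positives (FP), and false negatives (FN). A minimal sketch with illustrative counts (the numbers are made up for demonstration):

```python
# Per-class precision/recall/F1 from raw counts (illustrative numbers)
tp, fp, fn = 1620, 590, 380

precision = tp / (tp + fp)  # of samples predicted as class c, fraction correct
recall = tp / (tp + fn)     # of samples actually in class c, fraction found
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean

print(f"precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")
```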
Weighted F1 Score
For imbalanced classes, weighted F1 accounts for class frequency:
F1_weighted = Σ_c (n_c / N) · F1_c

where n_c is the number of samples in class c, and N is the total number of samples.
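The weighted average can be computed directly from per-class F1 scores and supports; the values below are illustrative, not from a trained model:

```python
import numpy as np

f1_per_class = np.array([0.93, 0.75, 0.83])  # illustrative per-class F1 scores
support = np.array([5000, 3000, 2000])       # samples per class (n_c)

# Weighted F1 = sum over classes of (n_c / N) * F1_c
f1_weighted = np.sum(support / support.sum() * f1_per_class)
print(f"Weighted F1: {f1_weighted:.3f}")  # 0.5*0.93 + 0.3*0.75 + 0.2*0.83 = 0.856
```

Note how the dominant Healthy class pulls the weighted score up; the macro average (a plain mean, 0.837 here) weights each class equally and is harsher on minority-class mistakes.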
Handling Class Imbalance
Strategies for evaluating and reporting with imbalanced classes.
Per-Class Analysis
Always report metrics for each class separately:
| Class | Precision | Recall | F1 | Support |
|---|---|---|---|---|
| Healthy | 0.92 | 0.95 | 0.93 | 5,000 |
| Degrading | 0.78 | 0.72 | 0.75 | 3,000 |
| Critical | 0.85 | 0.81 | 0.83 | 2,000 |
Confusion Matrix
The confusion matrix reveals error patterns:
```
                   Predicted
                   Healthy  Degrading  Critical
Actual Healthy        4750        200        50
       Degrading       300       2160       540
       Critical         50        330      1620
```
Key insights:
- Healthy well-detected (95% recall)
- Degrading confused with Critical (18% → Critical)
- Critical sometimes missed (16.5% → Degrading)

Critical Class Focus
For safety-critical applications, prioritize critical-class recall:
Critical Class Recall
Missing a critical health state (predicting healthy when actually critical) is dangerous. Target ≥80% critical recall even if it means lower precision (more false alarms are acceptable).
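Critical recall can be read directly off the confusion matrix: it is the diagonal entry of the Critical row divided by that row's total. A sketch using the example matrix above:

```python
import numpy as np

# Confusion matrix from the example above: rows = actual, columns = predicted
cm = np.array([
    [4750,  200,   50],   # Healthy
    [ 300, 2160,  540],   # Degrading
    [  50,  330, 1620],   # Critical
])

# Per-class recall = diagonal / row sum
recall = np.diag(cm) / cm.sum(axis=1)
critical_recall = recall[2]

print(f"Critical recall: {critical_recall:.1%}")  # 1620 / 2000 = 81.0%
if critical_recall < 0.80:
    print("WARNING: critical recall below the 80% safety target")
```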
Implementation
Complete implementation for health classification evaluation.
Comprehensive Classification Metrics
```python
import numpy as np
from typing import Dict, List
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    classification_report,
)

def evaluate_health_classification(
    predictions: np.ndarray,
    targets: np.ndarray,
    class_names: List[str] = ['Healthy', 'Degrading', 'Critical'],
) -> Dict:
    """
    Comprehensive evaluation of health state classification.

    Args:
        predictions: Predicted health states (0, 1, 2)
        targets: True health states (0, 1, 2)
        class_names: Names for each class

    Returns:
        Dictionary with all classification metrics
    """
    results = {}

    # Overall accuracy
    results['accuracy'] = float(accuracy_score(targets, predictions) * 100)

    # Weighted F1 (accounts for class imbalance)
    results['f1_weighted'] = float(
        f1_score(targets, predictions, average='weighted') * 100
    )

    # Macro F1 (equal weight per class)
    results['f1_macro'] = float(
        f1_score(targets, predictions, average='macro') * 100
    )

    # Per-class metrics (zero_division=0 guards against empty classes)
    precision_per_class = precision_score(
        targets, predictions, average=None, zero_division=0
    )
    recall_per_class = recall_score(
        targets, predictions, average=None, zero_division=0
    )
    f1_per_class = f1_score(
        targets, predictions, average=None, zero_division=0
    )

    for i, name in enumerate(class_names):
        results[f'precision_{name.lower()}'] = float(precision_per_class[i] * 100)
        results[f'recall_{name.lower()}'] = float(recall_per_class[i] * 100)
        results[f'f1_{name.lower()}'] = float(f1_per_class[i] * 100)

    # Confusion matrix
    results['confusion_matrix'] = confusion_matrix(targets, predictions).tolist()

    # Class distribution (support per class)
    for i, name in enumerate(class_names):
        results[f'support_{name.lower()}'] = int(np.sum(targets == i))

    return results
```
Health Evaluation in Model Training
```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Dict
from sklearn.metrics import accuracy_score, f1_score

def evaluate_model_health_task(
    model: nn.Module,
    test_loader: DataLoader,
    device: torch.device,
) -> Dict:
    """
    Evaluate health classification task during training.

    Args:
        model: Dual-task model
        test_loader: Test data loader
        device: Compute device

    Returns:
        Health classification metrics
    """
    model.eval()

    all_health_preds = []
    all_health_targets = []

    with torch.no_grad():
        for sequences, rul_targets in test_loader:
            sequences = sequences.to(device)

            # Get model predictions (dual-task: RUL regression + health logits)
            rul_pred, health_logits = model(sequences)

            # Convert logits to class predictions
            health_preds = health_logits.argmax(dim=1).cpu().numpy()

            # Derive target health states from the RUL targets
            # (rul_to_health_state defined earlier in this section)
            health_targets = rul_to_health_state(rul_targets.numpy())

            all_health_preds.extend(health_preds)
            all_health_targets.extend(health_targets)

    all_health_preds = np.array(all_health_preds)
    all_health_targets = np.array(all_health_targets)

    # Compute metrics
    results = {
        'health_accuracy': float(
            accuracy_score(all_health_targets, all_health_preds) * 100
        ),
        'health_f1': float(
            f1_score(all_health_targets, all_health_preds, average='weighted') * 100
        ),
    }

    # Add critical-class recall (safety-critical)
    critical_mask = all_health_targets == 2
    if np.sum(critical_mask) > 0:
        critical_correct = np.sum((all_health_preds == 2) & critical_mask)
        results['critical_recall'] = float(
            critical_correct / np.sum(critical_mask) * 100
        )

    return results
```
Printing Classification Report
```python
# Generate detailed classification report
report = classification_report(
    all_health_targets,
    all_health_preds,
    target_names=['Healthy', 'Degrading', 'Critical'],
    digits=3,
)

print("Health State Classification Report:")
print(report)

# Example output:
#               precision    recall  f1-score   support
#
#      Healthy      0.921     0.950     0.935      5000
#    Degrading      0.784     0.720     0.751      3000
#     Critical      0.852     0.810     0.830      2000
#
#     accuracy                          0.853     10000
#    macro avg      0.852     0.827     0.839     10000
# weighted avg      0.866     0.853     0.859     10000
```
Summary
In this section, we covered health classification metrics:
- Health states: Healthy (RUL > 50), Degrading (15-50), Critical (≤15)
- Key metrics: Accuracy, precision, recall, F1 score
- Weighted F1: Accounts for class imbalance
- Critical recall: Most important for safety
- Confusion matrix: Reveals misclassification patterns
| Metric | Target | Why Important |
|---|---|---|
| Accuracy | >80% | Overall correctness |
| Weighted F1 | >75% | Class-balanced performance |
| Critical Recall | >80% | Safety (don't miss failures) |
| Critical Precision | >70% | Avoid false alarms |
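One way to apply these targets is an automated check over the results dictionary produced by evaluate_health_classification. The thresholds mirror the table above; the check_targets helper and the example values are hypothetical conveniences for illustration:

```python
# Target thresholds on the percent-scale metrics, mirroring the table above
TARGETS = {
    'accuracy': 80.0,
    'f1_weighted': 75.0,
    'recall_critical': 80.0,
    'precision_critical': 70.0,
}

def check_targets(results: dict) -> dict:
    """Return pass/fail for each target metric (missing metrics fail)."""
    return {
        metric: results.get(metric, 0.0) >= threshold
        for metric, threshold in TARGETS.items()
    }

# Illustrative results from an evaluation run
example = {'accuracy': 85.3, 'f1_weighted': 85.9,
           'recall_critical': 81.0, 'precision_critical': 85.2}
print(check_targets(example))  # every check passes for these values
```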
Looking Ahead: We now have all the individual metrics—RMSE, NASA score, and health classification. The final section brings them together in a comprehensive evaluation pipeline that provides a complete assessment of model performance.