Learning Objectives
By the end of this section, you will:
- Understand multi-head attention in the AMNL architecture
- Analyze the impact of removing attention
- Interpret attention patterns in degradation modeling
- Connect attention to temporal dependencies in RUL prediction
- Implement attention ablation experiments
Key Finding: Removing multi-head attention degrades RMSE by 15-25% depending on the dataset. Attention is particularly important for complex multi-condition scenarios where the model must learn to weight different temporal patterns.
Attention Architecture
AMNL uses multi-head self-attention after the BiLSTM encoder to capture long-range temporal dependencies.
Multi-Head Attention Formulation
Given BiLSTM output where is sequence length and is hidden dimension:
Where for self-attention:
AMNL Attention Configuration
| Parameter | Value | Description |
|---|---|---|
| Number of Heads | 8 | Parallel attention mechanisms |
| Head Dimension | 64 | Hidden / Heads = 512 / 8 |
| Dropout | 0.2 | Regularization on attention weights |
| Residual Connection | Yes | H' = H + Attention(H) |
Why Attention for Degradation Modeling?
- Long-range dependencies: Early degradation patterns affect final RUL prediction
- Variable relevance: Different timesteps have different importance for prediction
- Condition adaptation: Attention can learn to weight condition-relevant features
- Interpretability: Attention weights reveal which timesteps influence predictions
Self-Attention Intuition
Self-attention allows each timestep to "attend to" all other timesteps. For degradation modeling, this means the model can directly connect the current degradation state to earlier patterns, without relying solely on the sequential LSTM hidden state.
Ablation Results
Comparing AMNL with and without the multi-head attention mechanism.
Per-Dataset Impact
| Dataset | With Attention | Without Attention | Degradation |
|---|---|---|---|
| FD001 | 10.43 | 12.15 | +16.5% |
| FD002 | 6.74 | 8.42 | +24.9% |
| FD003 | 9.51 | 11.03 | +16.0% |
| FD004 | 8.16 | 10.21 | +25.1% |
Complexity Correlation
NASA Score Impact
| Dataset | With Attention | Without Attention | Change |
|---|---|---|---|
| FD001 | 434.3 | 521.8 | +20.1% |
| FD002 | 356.0 | 467.2 | +31.2% |
| FD003 | 338.9 | 412.5 | +21.7% |
| FD004 | 537.5 | 723.8 | +34.7% |
Consistent Pattern
NASA Score degradation follows the same pattern as RMSE: multi-condition datasets (FD002, FD004) show larger degradation (~33%) compared to single-condition datasets (~21%). Attention helps reduce dangerous late predictions.
Statistical Significance
| Comparison | t-statistic | p-value | Effect Size (d) |
|---|---|---|---|
| FD001 ± Attention | 3.21 | 0.018 | 1.12 |
| FD002 ± Attention | 4.87 | 0.003 | 1.89 |
| FD003 ± Attention | 2.94 | 0.025 | 0.98 |
| FD004 ± Attention | 5.12 | 0.002 | 2.01 |
All comparisons are statistically significant (p < 0.05), with large effect sizes (d > 0.8) confirming attention provides substantial benefit.
Attention Pattern Analysis
Examining what the attention mechanism learns reveals interpretable patterns.
Attention Weight Visualization
Analyzing attention patterns across the degradation trajectory shows distinct behaviors:
| Degradation Phase | Attention Pattern | Interpretation |
|---|---|---|
| Healthy (RUL > 50) | Diffuse, spread evenly | No critical patterns to focus on |
| Degrading (15 < RUL ≤ 50) | Increasing focus on recent timesteps | Recent changes become more relevant |
| Critical (RUL ≤ 15) | Strong focus on degradation onset | Model identifies when degradation began |
Head Specialization
Different attention heads learn to focus on different aspects of the temporal pattern:
| Head Type | Pattern | Function |
|---|---|---|
| Local heads | Focus on recent 5-10 timesteps | Capture short-term dynamics |
| Global heads | Attend to full sequence | Capture long-range dependencies |
| Periodic heads | Oscillating attention | Capture cyclic patterns |
| Onset heads | Focus on degradation start | Identify transition points |
Emergent Specialization
These head specializations emerge naturally during training—they are not explicitly designed. This suggests the multi-head structure allows the model to decompose the complex degradation modeling task into interpretable sub-problems.
Condition-Specific Attention
On multi-condition datasets, attention patterns differ by operating condition:
Implementation
Code for the attention ablation experiment.
Model Without Attention
1class AMNLWithoutAttention(nn.Module):
2 """
3 AMNL model with attention mechanism removed (for ablation).
4
5 Uses only the BiLSTM output without self-attention.
6 """
7
8 def __init__(
9 self,
10 input_size: int = 17,
11 sequence_length: int = 30,
12 hidden_size: int = 256,
13 num_health_states: int = 3,
14 dropout: float = 0.2
15 ):
16 super().__init__()
17
18 # CNN Feature Extractor
19 self.cnn = nn.Sequential(
20 nn.Conv1d(input_size, 64, kernel_size=3, padding=1),
21 nn.BatchNorm1d(64),
22 nn.ReLU(),
23 nn.Conv1d(64, 128, kernel_size=3, padding=1),
24 nn.BatchNorm1d(128),
25 nn.ReLU(),
26 )
27
28 # BiLSTM Temporal Encoder
29 self.lstm = nn.LSTM(
30 input_size=128,
31 hidden_size=hidden_size,
32 num_layers=2,
33 batch_first=True,
34 bidirectional=True,
35 dropout=dropout
36 )
37
38 # NOTE: No attention mechanism!
39 # self.attention = nn.MultiheadAttention(...) # REMOVED
40
41 # Simple pooling instead of attention
42 # Option 1: Last timestep (used in ablation)
43 # Option 2: Mean pooling
44 # Option 3: Max pooling
45 self.pooling_type = 'last'
46
47 # Prediction heads
48 lstm_output_size = hidden_size * 2 # Bidirectional
49
50 self.rul_head = nn.Sequential(
51 nn.Linear(lstm_output_size, 128),
52 nn.ReLU(),
53 nn.Dropout(dropout),
54 nn.Linear(128, 1)
55 )
56
57 self.health_head = nn.Sequential(
58 nn.Linear(lstm_output_size, 64),
59 nn.ReLU(),
60 nn.Dropout(dropout),
61 nn.Linear(64, num_health_states)
62 )
63
64 def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
65 # CNN: [batch, seq, features] -> [batch, features, seq]
66 x = x.transpose(1, 2)
67 x = self.cnn(x)
68 x = x.transpose(1, 2) # Back to [batch, seq, features]
69
70 # BiLSTM
71 lstm_out, _ = self.lstm(x)
72
73 # Simple pooling (no attention)
74 if self.pooling_type == 'last':
75 pooled = lstm_out[:, -1, :]
76 elif self.pooling_type == 'mean':
77 pooled = lstm_out.mean(dim=1)
78 elif self.pooling_type == 'max':
79 pooled = lstm_out.max(dim=1)[0]
80
81 # Predictions
82 rul_pred = self.rul_head(pooled)
83 health_pred = self.health_head(pooled)
84
85 return rul_pred, health_predFull AMNL with Attention (Reference)
1class AMNLWithAttention(nn.Module):
2 """
3 Full AMNL model with multi-head self-attention.
4 """
5
6 def __init__(self, ...):
7 # ... same as above ...
8
9 # Multi-Head Self-Attention
10 self.attention = nn.MultiheadAttention(
11 embed_dim=hidden_size * 2,
12 num_heads=8,
13 dropout=dropout,
14 batch_first=True
15 )
16
17 # Layer norm for residual connection
18 self.attention_norm = nn.LayerNorm(hidden_size * 2)
19
20 def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
21 # ... CNN and BiLSTM same as above ...
22
23 # Multi-Head Self-Attention
24 attn_out, attn_weights = self.attention(
25 query=lstm_out,
26 key=lstm_out,
27 value=lstm_out
28 )
29
30 # Residual connection + layer norm
31 lstm_out = self.attention_norm(lstm_out + attn_out)
32
33 # Weighted sum using attention to last timestep
34 # Or: pooled = lstm_out[:, -1, :]
35 pooled = lstm_out[:, -1, :]
36
37 # ... predictions same as above ...
38
39 return rul_pred, health_predAttention Visualization
1def visualize_attention_patterns(
2 model: AMNLWithAttention,
3 test_loader: DataLoader,
4 device: torch.device,
5 num_samples: int = 10
6) -> Dict:
7 """
8 Extract and visualize attention patterns.
9
10 Returns attention weights for analysis.
11 """
12 model.eval()
13 attention_data = []
14
15 with torch.no_grad():
16 for batch_x, batch_y in test_loader:
17 batch_x = batch_x.to(device)
18
19 # Get attention weights
20 # Need to modify forward pass to return weights
21 rul_pred, health_pred, attn_weights = model.forward_with_attention(
22 batch_x
23 )
24
25 for i in range(min(len(batch_x), num_samples)):
26 attention_data.append({
27 'rul_true': batch_y[i].item(),
28 'rul_pred': rul_pred[i].item(),
29 'attention_weights': attn_weights[i].cpu().numpy(),
30 'health_state': get_health_state(batch_y[i].item())
31 })
32
33 if len(attention_data) >= num_samples:
34 break
35
36 # Analyze patterns
37 healthy_attn = np.stack([
38 d['attention_weights']
39 for d in attention_data
40 if d['health_state'] == 'healthy'
41 ])
42 critical_attn = np.stack([
43 d['attention_weights']
44 for d in attention_data
45 if d['health_state'] == 'critical'
46 ])
47
48 print("Healthy samples: attention entropy =", entropy(healthy_attn.mean(0)))
49 print("Critical samples: attention entropy =", entropy(critical_attn.mean(0)))
50
51 return {
52 'healthy_mean': healthy_attn.mean(0),
53 'critical_mean': critical_attn.mean(0),
54 'all_data': attention_data
55 }Ablation Runner
1def run_attention_ablation(
2 datasets: List[str] = ['FD002', 'FD004'],
3 seeds: List[int] = [42, 123, 456],
4 epochs: int = 300
5) -> pd.DataFrame:
6 """
7 Compare models with and without attention.
8 """
9 results = []
10
11 for dataset in datasets:
12 for seed in seeds:
13 # With attention
14 model_attn = AMNLWithAttention(use_attention=True)
15 result_attn = train_model(model_attn, dataset, seed, epochs)
16
17 results.append({
18 'dataset': dataset,
19 'seed': seed,
20 'attention': True,
21 'rmse': result_attn['rmse'],
22 'nasa_score': result_attn['nasa_score'],
23 'r2': result_attn['r2']
24 })
25
26 # Without attention
27 model_no_attn = AMNLWithoutAttention()
28 result_no_attn = train_model(model_no_attn, dataset, seed, epochs)
29
30 results.append({
31 'dataset': dataset,
32 'seed': seed,
33 'attention': False,
34 'rmse': result_no_attn['rmse'],
35 'nasa_score': result_no_attn['nasa_score'],
36 'r2': result_no_attn['r2']
37 })
38
39 df = pd.DataFrame(results)
40
41 # Statistical test
42 for dataset in datasets:
43 attn_rmse = df[(df['dataset'] == dataset) & (df['attention'])]['rmse']
44 no_attn_rmse = df[(df['dataset'] == dataset) & (~df['attention'])]['rmse']
45
46 t_stat, p_val = stats.ttest_ind(attn_rmse, no_attn_rmse)
47 effect_size = (no_attn_rmse.mean() - attn_rmse.mean()) / np.sqrt(
48 (attn_rmse.var() + no_attn_rmse.var()) / 2
49 )
50
51 print(f"{dataset}: t={t_stat:.2f}, p={p_val:.4f}, d={effect_size:.2f}")
52
53 return dfSummary
Attention Mechanism Impact Summary:
- Significant impact: 15-25% RMSE degradation without attention
- Complexity-dependent: Multi-condition datasets show ~25% degradation vs ~16% for single-condition
- Head specialization: Different heads learn local, global, periodic, and onset patterns
- Condition adaptation: Attention weights adapt to operating conditions
- Statistically significant: All comparisons show p < 0.05 with large effect sizes
| Key Metric | Value |
|---|---|
| Average RMSE degradation | 20.6% |
| Max degradation (FD004) | 25.1% |
| Min degradation (FD003) | 16.0% |
| Complexity ratio | 1.53× (complex vs simple) |
Key Insight: Multi-head attention is particularly valuable for complex multi-condition datasets. The attention mechanism learns to weight different temporal patterns and sensor combinations based on operating conditions, enabling condition-invariant degradation modeling. While less critical than the dual-task structure, attention provides a significant and consistent improvement.
With attention impact analyzed, we next examine individual architecture components.