Learning Objectives
By the end of this section, you will:
- Assemble the complete AMNL model from all components
- Implement the full model class in PyTorch
- Trace the complete forward pass from input to outputs
- Understand training vs inference mode differences
- Configure model initialization properly
Why This Matters: This section brings together everything we have builtβCNN, BiLSTM, Attention, and dual prediction headsβinto a single cohesive model. Understanding how these components interact is essential for debugging, optimization, and extending the architecture.
Complete Model Overview
The AMNL model consists of an encoder (shared) and two task-specific heads.
Architecture Diagram
1Input: X β β^(B Γ 30 Γ 17)
2 β
3 β
4βββββββββββββββββββββββββββββββββββββββββββββββββββ
5β ENCODER β
6β βββββββββββββββββββββββββββββββββββββββββββ β
7β β CNN Feature Extractor β β
8β β Conv1d layers: 17 β 64 β 128 β 64 β β
9β β BatchNorm + ReLU + Dropout β β
10β βββββββββββββββββββββββββββββββββββββββββββ β
11β β β
12β βββββββββββββββββββββββββββββββββββββββββββ β
13β β BiLSTM Encoder β β
14β β 2 layers, hidden=128, bidirectional β β
15β β Output: 256 (128 Γ 2) β β
16β βββββββββββββββββββββββββββββββββββββββββββ β
17β β β
18β βββββββββββββββββββββββββββββββββββββββββββ β
19β β Multi-Head Attention β β
20β β 8 heads, residual + LayerNorm β β
21β β Mean pooling over sequence β β
22β βββββββββββββββββββββββββββββββββββββββββββ β
23βββββββββββββββββββββββββββββββββββββββββββββββββββ
24 β
25 β
26 z β β^(B Γ 256)
27 β
28 βββββββββββ΄ββββββββββ
29 β β
30βββββββββββ βββββββββββ
31β RUL Headβ βHealth β
32β256β128β1β βHead β
33β β β256β64β3 β
34βββββββββββ βββββββββββ
35 β β
36Ε·_RUL β β^B logits β β^(B Γ 3)Component Summary
| Component | Input Shape | Output Shape | Key Feature |
|---|---|---|---|
| CNN | (B, 30, 17) | (B, 30, 64) | Local pattern extraction |
| BiLSTM | (B, 30, 64) | (B, 30, 256) | Temporal encoding |
| Attention | (B, 30, 256) | (B, 256) | Weighted aggregation |
| RUL Head | (B, 256) | (B, 1) | Regression output |
| Health Head | (B, 256) | (B, 3) | Classification logits |
PyTorch Implementation
The complete AMNL model combines all components into a single nn.Module.
Model Class Definition
1class AMNL(nn.Module):
2 """
3 AMNL: Adaptive Multi-task Normalized Loss Model
4
5 Complete architecture for RUL prediction with health classification.
6 Achieves state-of-the-art on all NASA C-MAPSS datasets.
7
8 Architecture:
9 Encoder: CNN β BiLSTM β Attention
10 Heads: RUL (regression) + Health (classification)
11 """
12
13 def __init__(
14 self,
15 input_dim: int = 17,
16 seq_len: int = 30,
17 cnn_channels: list = [64, 128, 64],
18 lstm_hidden: int = 128,
19 lstm_layers: int = 2,
20 attention_heads: int = 8,
21 rul_hidden: int = 128,
22 health_hidden: int = 64,
23 num_classes: int = 3,
24 dropout: float = 0.1
25 ):
26 super().__init__()
27
28 # Store config
29 self.input_dim = input_dim
30 self.seq_len = seq_len
31
32 # ============ ENCODER ============
33 # CNN Feature Extractor
34 self.cnn = nn.Sequential(
35 # Conv Block 1: 17 β 64
36 nn.Conv1d(input_dim, cnn_channels[0], kernel_size=3, padding=1),
37 nn.BatchNorm1d(cnn_channels[0]),
38 nn.ReLU(),
39 nn.Dropout(0.2),
40
41 # Conv Block 2: 64 β 128
42 nn.Conv1d(cnn_channels[0], cnn_channels[1], kernel_size=3, padding=1),
43 nn.BatchNorm1d(cnn_channels[1]),
44 nn.ReLU(),
45 nn.Dropout(0.2),
46
47 # Conv Block 3: 128 β 64
48 nn.Conv1d(cnn_channels[1], cnn_channels[2], kernel_size=3, padding=1),
49 nn.BatchNorm1d(cnn_channels[2]),
50 nn.ReLU(),
51 nn.Dropout(0.2),
52 )
53
54 # BiLSTM Encoder
55 self.lstm = nn.LSTM(
56 input_size=cnn_channels[-1], # 64
57 hidden_size=lstm_hidden, # 128
58 num_layers=lstm_layers, # 2
59 batch_first=True,
60 bidirectional=True,
61 dropout=0.3 if lstm_layers > 1 else 0
62 )
63 self.lstm_norm = nn.LayerNorm(lstm_hidden * 2) # 256
64
65 # Attention Layer
66 embed_dim = lstm_hidden * 2 # 256
67 self.attention = nn.MultiheadAttention(
68 embed_dim=embed_dim,
69 num_heads=attention_heads,
70 dropout=dropout,
71 batch_first=True
72 )
73 self.attn_norm = nn.LayerNorm(embed_dim)
74 self.attn_dropout = nn.Dropout(dropout)
75
76 # ============ PREDICTION HEADS ============
77 # RUL Prediction Head
78 self.rul_head = nn.Sequential(
79 nn.Linear(embed_dim, rul_hidden),
80 nn.ReLU(),
81 nn.Dropout(0.3),
82 nn.Linear(rul_hidden, 1)
83 )
84
85 # Health Classification Head
86 self.health_head = nn.Sequential(
87 nn.Linear(embed_dim, health_hidden),
88 nn.ReLU(),
89 nn.Dropout(0.3),
90 nn.Linear(health_hidden, num_classes)
91 )
92
93 # Initialize weights
94 self._init_weights()Weight Initialization
1def _init_weights(self):
2 """
3 Initialize model weights for stable training.
4
5 - Linear layers: Xavier uniform
6 - LSTM: Orthogonal initialization
7 - BatchNorm: Standard (weight=1, bias=0)
8 """
9 for name, param in self.named_parameters():
10 if 'weight' in name:
11 if 'lstm' in name:
12 # Orthogonal for LSTM weights
13 if len(param.shape) >= 2:
14 nn.init.orthogonal_(param)
15 elif 'bn' in name or 'norm' in name:
16 # BatchNorm/LayerNorm: weight=1
17 nn.init.ones_(param)
18 elif len(param.shape) >= 2:
19 # Linear layers: Xavier
20 nn.init.xavier_uniform_(param)
21 elif 'bias' in name:
22 nn.init.zeros_(param)Why These Initializations?
Xavier/Glorot: Maintains variance across layers, preventing vanishing/exploding gradients in deep networks.
Orthogonal (LSTM): Preserves gradient magnitude through time steps, especially important for long sequences.
Complete Forward Pass
The forward method implements the complete computation graph.
Forward Method
Dimension Trace
Training vs Inference Mode
The model behaves differently during training and inference.
Behavior Differences
| Component | Training | Inference |
|---|---|---|
| Dropout | Active (random masking) | Disabled (scaled pass-through) |
| BatchNorm | Uses batch statistics | Uses running statistics |
| Output | Both heads active | Often just RUL needed |
Mode Switching
1# Training mode
2model.train()
3for batch in train_loader:
4 x, y_rul, y_health = batch
5 rul_pred, health_logits = model(x)
6 loss = compute_loss(rul_pred, y_rul, health_logits, y_health)
7 loss.backward()
8 optimizer.step()
9
10# Inference mode
11model.eval()
12with torch.no_grad():
13 for batch in test_loader:
14 x, y_rul, y_health = batch
15 rul_pred, health_logits = model(x)
16 # Predictions are deterministic
17 # No gradient computationInference-Only Forward
For production, we often only need RUL prediction:
1def predict_rul(self, x: torch.Tensor) -> torch.Tensor:
2 """
3 Inference-only RUL prediction.
4
5 Args:
6 x: Input sensor data (batch, seq_len, input_dim)
7
8 Returns:
9 RUL predictions (batch,)
10 """
11 self.eval()
12 with torch.no_grad():
13 rul_pred, _ = self.forward(x)
14 # Clamp to valid range
15 rul_pred = rul_pred.squeeze(-1).clamp(min=0, max=125)
16 return rul_predGradient-Free Inference
Using torch.no_grad() disables gradient computation, reducing memory usage by ~50% and speeding up inference. Always use it when you don't need backpropagation.
Summary
In this section, we assembled the complete AMNL model:
- Model structure: Encoder (CNN β BiLSTM β Attention) + dual heads
- Forward pass: (B, 30, 17) β (B, 256) β (B, 1) + (B, 3)
- Weight initialization: Xavier for linear, orthogonal for LSTM
- Mode switching: train() for training, eval() for inference
| Component | Output Shape | Key Operation |
|---|---|---|
| Input | (B, 30, 17) | Raw sensor readings |
| CNN | (B, 30, 64) | Local feature extraction |
| BiLSTM | (B, 30, 256) | Temporal encoding |
| Attention | (B, 256) | Weighted pooling |
| RUL Head | (B, 1) | Regression prediction |
| Health Head | (B, 3) | Classification logits |
Looking Ahead: The model is complete. The final section provides a detailed parameter analysis, breaking down where the 3.5M parameters are distributed and what each component contributes to the total.
With the complete model assembled, we now analyze its parameter distribution.