Chapter 8

Complete Model Assembly

Dual Task Prediction Heads

Learning Objectives

By the end of this section, you will:

  1. Assemble the complete AMNL model from all components
  2. Implement the full model class in PyTorch
  3. Trace the complete forward pass from input to outputs
  4. Understand training vs inference mode differences
  5. Configure model initialization properly
Why This Matters: This section brings together everything we have built (CNN, BiLSTM, attention, and the dual prediction heads) into a single cohesive model. Understanding how these components interact is essential for debugging, optimization, and extending the architecture.

Complete Model Overview

The AMNL model consists of an encoder (shared) and two task-specific heads.

Architecture Diagram

πŸ“text
1Input: X ∈ ℝ^(B Γ— 30 Γ— 17)
2              β”‚
3              ↓
4β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
5β”‚                    ENCODER                       β”‚
6β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”    β”‚
7β”‚  β”‚  CNN Feature Extractor                   β”‚    β”‚
8β”‚  β”‚  Conv1d layers: 17 β†’ 64 β†’ 128 β†’ 64      β”‚    β”‚
9β”‚  β”‚  BatchNorm + ReLU + Dropout             β”‚    β”‚
10β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜    β”‚
11β”‚                      ↓                           β”‚
12β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”    β”‚
13β”‚  β”‚  BiLSTM Encoder                         β”‚    β”‚
14β”‚  β”‚  2 layers, hidden=128, bidirectional    β”‚    β”‚
15β”‚  β”‚  Output: 256 (128 Γ— 2)                  β”‚    β”‚
16β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜    β”‚
17β”‚                      ↓                           β”‚
18β”‚  β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”    β”‚
19β”‚  β”‚  Multi-Head Attention                   β”‚    β”‚
20β”‚  β”‚  8 heads, residual + LayerNorm          β”‚    β”‚
21β”‚  β”‚  Mean pooling over sequence             β”‚    β”‚
22β”‚  β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜    β”‚
23β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
24              β”‚
25              ↓
26         z ∈ ℝ^(B Γ— 256)
27              β”‚
28    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
29    ↓                   ↓
30β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”       β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”
31β”‚ RUL Headβ”‚       β”‚Health   β”‚
32β”‚256β†’128β†’1β”‚       β”‚Head     β”‚
33β”‚         β”‚       β”‚256β†’64β†’3 β”‚
34β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜       β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
35    ↓                   ↓
36Ε·_RUL ∈ ℝ^B      logits ∈ ℝ^(B Γ— 3)

Component Summary

| Component | Input Shape | Output Shape | Key Feature |
|-------------|--------------|--------------|--------------------------|
| CNN | (B, 30, 17) | (B, 30, 64) | Local pattern extraction |
| BiLSTM | (B, 30, 64) | (B, 30, 256) | Temporal encoding |
| Attention | (B, 30, 256) | (B, 256) | Weighted aggregation |
| RUL Head | (B, 256) | (B, 1) | Regression output |
| Health Head | (B, 256) | (B, 3) | Classification logits |

PyTorch Implementation

The complete AMNL model combines all components into a single nn.Module.

Model Class Definition

🐍python
import torch
import torch.nn as nn


class AMNL(nn.Module):
    """
    AMNL: Adaptive Multi-task Normalized Loss Model

    Complete architecture for RUL prediction with health classification.
    Achieves state-of-the-art results on all NASA C-MAPSS datasets.

    Architecture:
        Encoder: CNN → BiLSTM → Attention
        Heads: RUL (regression) + Health (classification)
    """

    def __init__(
        self,
        input_dim: int = 17,
        seq_len: int = 30,
        cnn_channels: tuple = (64, 128, 64),  # tuple avoids a mutable default
        lstm_hidden: int = 128,
        lstm_layers: int = 2,
        attention_heads: int = 8,
        rul_hidden: int = 128,
        health_hidden: int = 64,
        num_classes: int = 3,
        dropout: float = 0.1
    ):
        super().__init__()

        # Store config
        self.input_dim = input_dim
        self.seq_len = seq_len

        # ============ ENCODER ============
        # CNN Feature Extractor
        self.cnn = nn.Sequential(
            # Conv Block 1: 17 → 64
            nn.Conv1d(input_dim, cnn_channels[0], kernel_size=3, padding=1),
            nn.BatchNorm1d(cnn_channels[0]),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Conv Block 2: 64 → 128
            nn.Conv1d(cnn_channels[0], cnn_channels[1], kernel_size=3, padding=1),
            nn.BatchNorm1d(cnn_channels[1]),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Conv Block 3: 128 → 64
            nn.Conv1d(cnn_channels[1], cnn_channels[2], kernel_size=3, padding=1),
            nn.BatchNorm1d(cnn_channels[2]),
            nn.ReLU(),
            nn.Dropout(0.2),
        )

        # BiLSTM Encoder
        self.lstm = nn.LSTM(
            input_size=cnn_channels[-1],   # 64
            hidden_size=lstm_hidden,       # 128
            num_layers=lstm_layers,        # 2
            batch_first=True,
            bidirectional=True,
            dropout=0.3 if lstm_layers > 1 else 0
        )
        self.lstm_norm = nn.LayerNorm(lstm_hidden * 2)  # 256

        # Attention Layer
        embed_dim = lstm_hidden * 2  # 256
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=attention_heads,
            dropout=dropout,
            batch_first=True
        )
        self.attn_norm = nn.LayerNorm(embed_dim)
        self.attn_dropout = nn.Dropout(dropout)

        # ============ PREDICTION HEADS ============
        # RUL Prediction Head
        self.rul_head = nn.Sequential(
            nn.Linear(embed_dim, rul_hidden),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(rul_hidden, 1)
        )

        # Health Classification Head
        self.health_head = nn.Sequential(
            nn.Linear(embed_dim, health_hidden),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(health_hidden, num_classes)
        )

        # Initialize weights
        self._init_weights()
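As a quick sanity check on the head sizes above, the parameter counts of the two heads follow directly from the Linear layer shapes. The sketch below re-declares just the heads locally with the same dimensions (it is not part of the model class):

```python
import torch.nn as nn

# Re-declare the two heads with the dimensions used above
rul_head = nn.Sequential(
    nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.3), nn.Linear(128, 1)
)
health_head = nn.Sequential(
    nn.Linear(256, 64), nn.ReLU(), nn.Dropout(0.3), nn.Linear(64, 3)
)

n_rul = sum(p.numel() for p in rul_head.parameters())
n_health = sum(p.numel() for p in health_head.parameters())

# Linear(in, out) contributes in*out weights + out biases
assert n_rul == 256 * 128 + 128 + 128 * 1 + 1    # 33,025 parameters
assert n_health == 256 * 64 + 64 + 64 * 3 + 3    # 16,643 parameters
```

Together the heads hold under 50k parameters, a small fraction of the roughly 3.5M total discussed in the next section; most of the budget sits in the encoder.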

Weight Initialization

🐍python
def _init_weights(self):
    """
    Initialize model weights for stable training.

    - LSTM weight matrices: orthogonal initialization
    - Linear/Conv weights: Xavier uniform
    - BatchNorm/LayerNorm weights (1-D): left at their default of 1
    - All biases: zero
    """
    for name, param in self.named_parameters():
        if 'bias' in name:
            nn.init.zeros_(param)
        elif 'lstm' in name and param.dim() >= 2:
            # Orthogonal for the input-hidden and hidden-hidden matrices;
            # preserves gradient magnitude through time steps
            nn.init.orthogonal_(param)
        elif param.dim() >= 2:
            # Linear and Conv1d weights: Xavier uniform
            nn.init.xavier_uniform_(param)
        # 1-D weights belong to BatchNorm/LayerNorm and keep their
        # default initialization of 1

Why These Initializations?

Xavier/Glorot: Maintains variance across layers, preventing vanishing/exploding gradients in deep networks.

Orthogonal (LSTM): Preserves gradient magnitude through time steps, especially important for long sequences.
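The norm-preserving property of orthogonal initialization is easy to verify directly. This is a standalone illustration, not part of the model code:

```python
import torch
import torch.nn as nn

# Build an orthogonal matrix the same way _init_weights does
w = torch.empty(128, 128)
nn.init.orthogonal_(w)

# An orthogonal matrix satisfies W @ W.T = I ...
assert torch.allclose(w @ w.t(), torch.eye(128), atol=1e-5)

# ... so repeated multiplication neither shrinks nor blows up a vector,
# which is exactly the behavior a recurrent weight matrix needs
v = torch.randn(128)
n0 = v.norm()
for _ in range(100):
    v = w @ v
assert torch.allclose(v.norm(), n0, rtol=1e-3)
```

A Gaussian-initialized matrix applied 100 times would typically drive the norm toward zero or infinity; the orthogonal one keeps it constant up to float error.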


Complete Forward Pass

The forward method implements the complete computation graph.

Forward Method

The forward method is the model's main computation, invoked when you call model(x). It returns both the RUL prediction and the health classification logits. The annotated walkthrough below steps through each stage; the complete listing follows at the end.

Input Shape

Input is raw sensor data: batch of samples, each with 30 timesteps, each timestep has 17 sensor readings.

EXAMPLE
# Input tensor structure:
x.shape = (32, 30, 17)
# 32 samples in batch
# 30 timesteps per sample
# 17 sensor readings per timestep

# Example x[0, 0, :] = one timestep, all 17 sensors:
# tensor([0.5, -0.2, 0.8, ..., 0.1])  # 17 values

Permute for Conv1d

Conv1d expects channels dimension second. Swap dims 1 and 2 to put sensors (17) before timesteps (30).

EXAMPLE
# BEFORE: x.shape = (32, 30, 17)
#         [batch, timesteps, sensors]

x = x.permute(0, 2, 1)

# AFTER: x.shape = (32, 17, 30)
#        [batch, sensors, timesteps]

# Now Conv1d sees each sensor as a 'channel'
# and convolves along the time dimension

CNN Processing

Apply 3 conv blocks: 17 → 64 → 128 → 64 channels. Each block has Conv1d + BatchNorm + ReLU + Dropout.

EXAMPLE
# x.shape = (32, 17, 30)  # Input

# Conv Block 1: 17 β†’ 64 channels
# Conv Block 2: 64 β†’ 128 channels
# Conv Block 3: 128 β†’ 64 channels

# AFTER: x.shape = (32, 64, 30)
# 64 learned features per timestep

Permute Back

Restore original dimension order for LSTM. Put timesteps before features.

EXAMPLE
# BEFORE: x.shape = (32, 64, 30)
#         [batch, features, timesteps]

x = x.permute(0, 2, 1)

# AFTER: x.shape = (32, 30, 64)
#        [batch, timesteps, features]

# Now LSTM can process sequence of 30 timesteps
# Each timestep is a 64-dim feature vector

BiLSTM Processing

Bidirectional LSTM processes sequence forward and backward. Output is concatenation of both directions.

EXAMPLE
# BEFORE: x.shape = (32, 30, 64)

x, _ = self.lstm(x)

# AFTER: x.shape = (32, 30, 256)
# 256 = 128 (forward) + 128 (backward)

# Each timestep now has context from
# BOTH past and future timesteps

LSTM LayerNorm

Normalize each 256-dim vector. Stabilizes activations before attention.

EXAMPLE
# x.shape = (32, 30, 256)
x = self.lstm_norm(x)
# x.shape = (32, 30, 256)  # Same shape

# Each of 32×30 = 960 vectors normalized
# mean ≈ 0, std ≈ 1 for each 256-dim vector

Self-Attention

Each timestep attends to all timesteps. Q=K=V=x. Output has same shape as input.

EXAMPLE
# x.shape = (32, 30, 256)

attn_out, _ = self.attention(x, x, x)

# attn_out.shape = (32, 30, 256)

# Each of the 30 timesteps is now a weighted
# combination of all 30 timesteps based on
# learned attention scores

Residual + LayerNorm

Add original input to attention output (residual), then normalize. Helps gradient flow.

EXAMPLE
# x.shape = (32, 30, 256)  # Original
# attn_out.shape = (32, 30, 256)

dropped = self.attn_dropout(attn_out)  # Dropout
added = x + dropped  # Residual connection
x = self.attn_norm(added)  # LayerNorm

# x.shape = (32, 30, 256)

Mean Pooling

Average across all 30 timesteps to get single 256-dim vector per sample. This is the encoded representation.

EXAMPLE
# BEFORE: x.shape = (32, 30, 256)
# 30 timestep vectors per sample

z = x.mean(dim=1)  # Average over timesteps

# AFTER: z.shape = (32, 256)
# One 256-dim vector per sample

# This 'z' is the encoded representation
# fed to both prediction heads

RUL Prediction

Pass the encoded representation through the RUL head: 256 → 128 → 1. Output is the predicted remaining useful life.

EXAMPLE
# z.shape = (32, 256)

rul_pred = self.rul_head(z)

# rul_pred.shape = (32, 1)
# Each value is predicted RUL in cycles

# Example: rul_pred[0] = tensor([87.3])
# → Predicted 87.3 cycles until failure

Health Classification

Pass the same encoded representation through the health head: 256 → 64 → 3. Output is class logits.

EXAMPLE
# z.shape = (32, 256)

health_logits = self.health_head(z)

# health_logits.shape = (32, 3)
# Raw scores for each class

# Example: health_logits[0] = tensor([0.2, 2.1, -0.5])
# Highest score (2.1) → Class 1 (Degrading)

Return Outputs

Return both predictions. During training both are used; during inference often only RUL is needed.

EXAMPLE
# Returns tuple:
# rul_pred: (32, 1) - RUL predictions
# health_logits: (32, 3) - Class scores

# Usage:
rul, health = model(x)
# rul[0] = 87.3 cycles
# health[0] → Class 1 (Degrading)
🐍python
def forward(
    self,
    x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Forward pass through the complete AMNL model.

    Args:
        x: Input sensor data (batch, seq_len, input_dim)
           Shape: (B, 30, 17)

    Returns:
        rul_pred: RUL predictions (batch, 1)
        health_logits: Health class logits (batch, 3)
    """
    # ============ CNN ============
    # Permute for Conv1d: (B, 30, 17) → (B, 17, 30)
    x = x.permute(0, 2, 1)
    x = self.cnn(x)
    # Permute back: (B, 64, 30) → (B, 30, 64)
    x = x.permute(0, 2, 1)

    # ============ BiLSTM ============
    # LSTM expects (B, seq, features)
    x, _ = self.lstm(x)  # (B, 30, 256)
    x = self.lstm_norm(x)

    # ============ Attention ============
    # Self-attention
    attn_out, _ = self.attention(x, x, x)
    # Residual + LayerNorm
    x = self.attn_norm(x + self.attn_dropout(attn_out))
    # Mean pooling: (B, 30, 256) → (B, 256)
    z = x.mean(dim=1)

    # ============ Prediction Heads ============
    rul_pred = self.rul_head(z)           # (B, 1)
    health_logits = self.health_head(z)   # (B, 3)

    return rul_pred, health_logits

Dimension Trace


Training vs Inference Mode

The model behaves differently during training and inference.

Behavior Differences

| Component | Training | Inference |
|-----------|---------------------------------|-------------------------|
| Dropout | Active (random masking + scaling) | Disabled (identity) |
| BatchNorm | Uses batch statistics | Uses running statistics |
| Output | Both heads active | Often just RUL needed |
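The dropout behavior is easy to demonstrate in isolation. PyTorch uses inverted dropout: surviving activations are scaled by 1/(1-p) during training, and eval mode is a plain identity (a standalone snippet, not part of the model):

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()                  # training mode: random masking + scaling
y_train = drop(x)
assert (y_train == 0).any()   # roughly half the activations are zeroed
assert y_train.max() == 2.0   # survivors scaled by 1/(1-0.5) = 2

drop.eval()                   # inference mode: identity pass-through
y_eval = drop(x)
assert torch.equal(y_eval, x)
```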

Mode Switching

🐍python
# Training mode
model.train()
for batch in train_loader:
    x, y_rul, y_health = batch
    optimizer.zero_grad()  # clear gradients from the previous step
    rul_pred, health_logits = model(x)
    loss = compute_loss(rul_pred, y_rul, health_logits, y_health)
    loss.backward()
    optimizer.step()

# Inference mode
model.eval()
with torch.no_grad():
    for batch in test_loader:
        x, y_rul, y_health = batch
        rul_pred, health_logits = model(x)
        # Predictions are deterministic
        # No gradient computation

Inference-Only Forward

For production, we often only need RUL prediction:

🐍python
def predict_rul(self, x: torch.Tensor) -> torch.Tensor:
    """
    Inference-only RUL prediction.

    Args:
        x: Input sensor data (batch, seq_len, input_dim)

    Returns:
        RUL predictions (batch,)
    """
    self.eval()  # note: leaves the model in eval mode afterwards
    with torch.no_grad():
        rul_pred, _ = self.forward(x)
        # Clamp to the valid RUL range [0, 125]
        rul_pred = rul_pred.squeeze(-1).clamp(min=0, max=125)
    return rul_pred
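The clamp maps any out-of-range raw head output onto the capped RUL range [0, 125] used here (a quick standalone check):

```python
import torch

# Hypothetical raw head outputs, shape (3, 1): one negative, one in range, one too large
raw = torch.tensor([[-5.0], [60.0], [300.0]])

clamped = raw.squeeze(-1).clamp(min=0, max=125)   # shape (3,)
assert clamped.tolist() == [0.0, 60.0, 125.0]
```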

Gradient-Free Inference

Using torch.no_grad() disables gradient tracking, so the intermediate activations needed for backpropagation are never stored. This typically cuts inference memory use roughly in half and speeds up the forward pass. Always use it when you don't need backpropagation.
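The effect is visible on any tensor that requires gradients (a minimal illustration):

```python
import torch

x = torch.ones(3, requires_grad=True)

y = x * 2          # outside no_grad: the operation is recorded in the graph
assert y.requires_grad

with torch.no_grad():
    z = x * 2      # inside no_grad: no graph, no stored activations
assert not z.requires_grad
```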


Summary

In this section, we assembled the complete AMNL model:

  1. Model structure: Encoder (CNN β†’ BiLSTM β†’ Attention) + dual heads
  2. Forward pass: (B, 30, 17) → (B, 256) → (B, 1) + (B, 3)
  3. Weight initialization: Xavier for linear, orthogonal for LSTM
  4. Mode switching: train() for training, eval() for inference
| Component | Output Shape | Key Operation |
|-------------|--------------|--------------------------|
| Input | (B, 30, 17) | Raw sensor readings |
| CNN | (B, 30, 64) | Local feature extraction |
| BiLSTM | (B, 30, 256) | Temporal encoding |
| Attention | (B, 256) | Weighted pooling |
| RUL Head | (B, 1) | Regression prediction |
| Health Head | (B, 3) | Classification logits |
Looking Ahead: The model is complete. The final section provides a detailed parameter analysis, breaking down where the 3.5M parameters are distributed and what each component contributes to the total.

With the complete model assembled, we now analyze its parameter distribution.