Chapter 7

PyTorch Implementation

Multi-Head Self-Attention

Learning Objectives

By the end of this section, you will:

  1. Implement the complete attention module using PyTorch's MultiheadAttention
  2. Integrate attention with the BiLSTM encoder
  3. Trace the forward pass through all components
  4. Verify parameter counts match our calculations
  5. Understand the complete feature extraction pipeline
Why This Matters: This section synthesizes all attention concepts into working code. PyTorch's MultiheadAttention encapsulates the Q, K, V projections, scaled dot-product attention, and output projection in a single optimized module. Understanding how to use it correctly is essential for building Transformer-based architectures.

Complete Attention Module

We wrap PyTorch's MultiheadAttention with layer normalization, residual connection, and mean pooling for RUL prediction.

Module Definition

🐍python
class AttentionModule(nn.Module):
    """
    Multi-head self-attention with residual connection and layer norm.

    Architecture:
        Input → MHA → Dropout → Add(residual) → LayerNorm → MeanPool → Output

    Args:
        embed_dim: Input/output dimension (256 from BiLSTM)
        num_heads: Number of attention heads (8)
        dropout: Dropout rate for attention and residual (0.1)
    """

    def __init__(
        self,
        embed_dim: int = 256,
        num_heads: int = 8,
        dropout: float = 0.1
    ):
        super().__init__()

        # Multi-head attention
        # embed_dim must be divisible by num_heads
        # Each head dimension: 256 / 8 = 32
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True  # Input: (batch, seq, embed)
        )

        # Layer normalization (post-norm)
        self.layer_norm = nn.LayerNorm(embed_dim)

        # Dropout for residual path
        self.dropout = nn.Dropout(dropout)

Key Parameters Explained

| Parameter | Value | Purpose |
|---|---|---|
| embed_dim | 256 | Matches BiLSTM output dimension |
| num_heads | 8 | Parallel attention mechanisms |
| dropout | 0.1 | Regularization during training |
| batch_first | True | Input shape (B, T, D) not (T, B, D) |

batch_first Parameter

PyTorch's MultiheadAttention defaults to (seq, batch, embed) ordering. Setting batch_first=True aligns it with our CNN and LSTM outputs, which use the (batch, seq, embed) format.
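A minimal standalone sketch (batch size 4 is an arbitrary choice for illustration) confirms that with batch_first=True the module consumes and returns (batch, seq, embed) tensors:

```python
import torch
import torch.nn as nn

# With batch_first=True, inputs are (batch, seq, embed),
# matching the (B, T, D) layout of the CNN and BiLSTM outputs
mha = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)

x = torch.randn(4, 30, 256)   # (batch=4, seq=30, embed=256)
out, _ = mha(x, x, x)
print(out.shape)              # torch.Size([4, 30, 256])
```
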

Forward Method

Attention Forward Pass
Step 1: Forward Method

Called automatically when you do model(x). Takes BiLSTM output and returns a pooled representation.

Step 2: Input Shape

Input tensor from BiLSTM has shape (batch, sequence_length, embedding_dim). Each sample has 30 timesteps with 256-dim vectors.

EXAMPLE
# BEFORE: Input from BiLSTM
x.shape = (32, 30, 256)
# 32 samples × 30 timesteps × 256 features

# Visualize one sample:
x[0] = tensor([
    [0.1, -0.2, ..., 0.3],  # timestep 0 (256 values)
    [0.4,  0.1, ..., -0.1], # timestep 1 (256 values)
    ...                      # ... 28 more timesteps
    [0.2, -0.5, ..., 0.7],  # timestep 29 (256 values)
])
Step 3: Self-Attention

Query=Key=Value=x means each position attends to all positions (including itself). This is 'self' attention.

EXAMPLE
# Self-attention: Q = K = V = x
# Each timestep can attend to every other timestep

# BEFORE: x.shape = (32, 30, 256)
attn_output, attn_weights = self.attention(
    query=x,   # What am I looking for? (32, 30, 256)
    key=x,     # What do I match against? (32, 30, 256)
    value=x,   # What do I retrieve? (32, 30, 256)
)
# AFTER: attn_output.shape = (32, 30, 256)
#        attn_weights.shape = (32, 30, 30)

# attn_weights[b, i, j] = how much timestep i
#                          attends to timestep j
Step 4: Residual Connection

Add the original input to the attention output. This helps gradient flow and allows the model to learn 'corrections' rather than full transformations.

EXAMPLE
# BEFORE:
x.shape = (32, 30, 256)  # Original input
attn_output.shape = (32, 30, 256)  # Attention result

# Dropout randomly zeros some values (training only)
dropped = self.dropout(attn_output)

# Element-wise addition
x = x + dropped
# AFTER: x.shape = (32, 30, 256)

# Example for one element:
# x[0,0,0] = 0.5 (original)
# attn_output[0,0,0] = 0.2 (attention)
# after dropout: 0.2 (or 0 if dropped)
# result: 0.5 + 0.2 = 0.7
Step 5: Layer Normalization

Normalize each 256-dim vector to have mean≈0 and std≈1. Stabilizes training by keeping activations in a consistent range.

EXAMPLE
# BEFORE: x.shape = (32, 30, 256)
# Each of the 32×30=960 vectors has arbitrary mean/std

# LayerNorm normalizes EACH 256-dim vector independently
x = self.layer_norm(x)

# AFTER: x.shape = (32, 30, 256)
# Now each vector has mean≈0, std≈1

# Example for x[0, 0, :] (one 256-dim vector):
# Before: mean=0.7, std=2.3
# After:  mean≈0.0, std≈1.0
Step 6: Mean Pooling

Average across all 30 timesteps to get a single vector per sample. This aggregates information from the entire sequence.

EXAMPLE
# BEFORE: x.shape = (32, 30, 256)
# 32 samples, each with 30 timesteps of 256 features

x = x.mean(dim=1)  # Average over dimension 1 (timesteps)

# AFTER: x.shape = (32, 256)
# 32 samples, each with a single 256-dim vector

# What mean(dim=1) does:
# For each of the 256 features, average across 30 timesteps
# x[b, :, f].mean() for each feature f

# Visualization:
# Before: [[t0], [t1], ..., [t29]]  # 30 vectors
# After:  [average of t0-t29]       # 1 vector
Step 7: Return Output

Return the pooled 256-dim representation. This will be fed to the prediction heads for RUL and health classification.

EXAMPLE
# Final output shape: (32, 256)
# Each sample is now a single 256-dimensional vector
# Ready for: RUL head or Health classification head

return x  # (batch_size, 256)
The complete forward method:
🐍python
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass with self-attention.

    Args:
        x: BiLSTM output (batch, seq_len, embed_dim)
           Shape: (B, 30, 256)

    Returns:
        Pooled attention output (batch, embed_dim)
        Shape: (B, 256)
    """
    # Self-attention: Q = K = V = x
    # attn_output: (B, 30, 256)
    # attn_weights is None here because need_weights=False;
    # pass need_weights=True to get (B, 30, 30) weights for visualization
    attn_output, attn_weights = self.attention(
        query=x,
        key=x,
        value=x,
        need_weights=False
    )

    # Residual connection with dropout
    # y = x + Dropout(MHA(x))
    x = x + self.dropout(attn_output)

    # Layer normalization (post-norm)
    # y = LayerNorm(x + Sublayer(x))
    x = self.layer_norm(x)

    # Mean pooling over sequence dimension
    # (B, 30, 256) → (B, 256)
    x = x.mean(dim=1)

    return x

Dimension Flow

πŸ“text
Input x: (B, 30, 256)
         ↓
MultiheadAttention(Q=x, K=x, V=x)
         ↓
attn_output: (B, 30, 256)
         ↓
Dropout(attn_output)
         ↓
Add: x + Dropout(attn_output)
         ↓
x: (B, 30, 256)
         ↓
LayerNorm(x)
         ↓
x: (B, 30, 256)
         ↓
Mean over dim=1
         ↓
Output: (B, 256)
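The same flow can be verified with random tensors. This standalone sketch rebuilds the three sub-layers directly (rather than importing AttentionModule), so the shapes can be checked in isolation:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
attn = nn.MultiheadAttention(256, 8, batch_first=True)
norm = nn.LayerNorm(256)
drop = nn.Dropout(0.1)

x = torch.randn(32, 30, 256)               # simulated BiLSTM output
a, _ = attn(x, x, x, need_weights=False)   # (32, 30, 256)
x = norm(x + drop(a))                      # residual + post-norm
pooled = x.mean(dim=1)                     # (32, 256)
print(pooled.shape)                        # torch.Size([32, 256])
```
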

Integration with BiLSTM

The attention module receives the BiLSTM output and produces a single 256-dimensional vector for each sample.

Complete Encoder Architecture

🐍python
class AMNL_Encoder(nn.Module):
    """
    Complete encoder: CNN → BiLSTM → Attention → Output

    This combines all feature extraction components:
    1. CNN: Extract local patterns (17 sensors → 64 features)
    2. BiLSTM: Encode temporal dependencies (64 → 256)
    3. Attention: Focus on informative timesteps (256 → 256)
    """

    def __init__(
        self,
        input_dim: int = 17,
        cnn_channels: list = [64, 128, 64],
        lstm_hidden: int = 128,
        lstm_layers: int = 2,
        attention_heads: int = 8,
        dropout: float = 0.1
    ):
        super().__init__()

        # CNN Feature Extractor (from Chapter 5)
        self.cnn = CNN_FeatureExtractor(
            in_channels=input_dim,
            channels=cnn_channels,
            dropout=0.2
        )

        # BiLSTM Encoder (from Chapter 6)
        self.bilstm = BiLSTM_Encoder(
            input_size=cnn_channels[-1],  # 64
            hidden_size=lstm_hidden,      # 128
            num_layers=lstm_layers,       # 2
            dropout=0.3
        )

        # Attention Module (this chapter)
        self.attention = AttentionModule(
            embed_dim=lstm_hidden * 2,   # 256 (bidirectional)
            num_heads=attention_heads,   # 8
            dropout=dropout              # 0.1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through complete encoder.

        Args:
            x: Raw sensor data (batch, seq_len, sensors)
               Shape: (B, 30, 17)

        Returns:
            Encoded representation (batch, 256)
        """
        # CNN: (B, 30, 17) → (B, 30, 64)
        x = self.cnn(x)

        # BiLSTM: (B, 30, 64) → (B, 30, 256)
        x = self.bilstm(x)

        # Attention: (B, 30, 256) → (B, 256)
        x = self.attention(x)

        return x

Module Connections

| Stage | Input Shape | Output Shape | Key Operation |
|---|---|---|---|
| CNN | (B, 30, 17) | (B, 30, 64) | Local pattern extraction |
| BiLSTM | (B, 30, 64) | (B, 30, 256) | Temporal encoding |
| Attention | (B, 30, 256) | (B, 256) | Weighted aggregation |

Dimension Compatibility

Each module's output dimension must match the next module's input. CNN outputs 64 (matching BiLSTM input_size), BiLSTM outputs 256 (matching attention embed_dim), and attention outputs 256 (ready for the prediction head).
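These constraints can be written down as plain assertions, a small sketch over the default hyperparameters quoted in this chapter:

```python
# Default hyperparameters from the encoder definition
cnn_channels = [64, 128, 64]
lstm_hidden = 128
attention_heads = 8

lstm_input = cnn_channels[-1]   # BiLSTM input_size: 64
embed_dim = lstm_hidden * 2     # bidirectional output: 2 x 128 = 256

# nn.MultiheadAttention requires embed_dim % num_heads == 0
assert embed_dim % attention_heads == 0
head_dim = embed_dim // attention_heads
print(lstm_input, embed_dim, head_dim)   # 64 256 32
```
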


Forward Pass Flow

Let us trace a concrete example through the complete encoder.
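Since CNN_FeatureExtractor and BiLSTM_Encoder are defined in earlier chapters, the sketch below traces the shapes with simple stand-in layers (a single Conv1d and a two-layer bidirectional LSTM, chosen only to reproduce the shapes, not the real modules):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for the Chapter 5/6 modules; same input/output shapes
cnn = nn.Conv1d(17, 64, kernel_size=3, padding=1)
bilstm = nn.LSTM(64, 128, num_layers=2, batch_first=True, bidirectional=True)
attn = nn.MultiheadAttention(256, 8, batch_first=True)

x = torch.randn(32, 30, 17)                 # raw sensor windows
x = cnn(x.transpose(1, 2)).transpose(1, 2)  # Conv1d wants (B, C, T) → (32, 30, 64)
x, _ = bilstm(x)                            # (32, 30, 256)
a, _ = attn(x, x, x, need_weights=False)    # (32, 30, 256)
out = a.mean(dim=1)                         # (32, 256)
print(out.shape)                            # torch.Size([32, 256])
```
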

Attention Weight Visualization

To understand what the model focuses on, we can extract attention weights:

🐍python
def forward_with_attention(self, x: torch.Tensor):
    """
    Forward pass returning attention weights for visualization.
    """
    # Self-attention with weights
    attn_output, attn_weights = self.attention.attention(
        query=x, key=x, value=x,
        need_weights=True,
        average_attn_weights=True  # Average across heads
    )
    # attn_weights: (B, 30, 30)
    # Each row i shows how position i attends to all positions

    # Complete forward pass
    x = x + self.attention.dropout(attn_output)
    x = self.attention.layer_norm(x)
    x = x.mean(dim=1)

    return x, attn_weights

Interpreting Attention Weights

Row i of the attention matrix shows which timesteps position i attends to. For RUL prediction, we often see higher attention on later timesteps (degradation signals) and transition points.
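Each row of the weight matrix is a softmax distribution over the keys, so it sums to 1. A quick standalone sketch (batch size 2 is arbitrary) confirms this:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
mha = nn.MultiheadAttention(256, 8, batch_first=True)

x = torch.randn(2, 30, 256)
_, w = mha(x, x, x, need_weights=True, average_attn_weights=True)

print(w.shape)              # torch.Size([2, 30, 30])
row_sums = w.sum(dim=-1)    # every row ≈ 1.0 (softmax over keys)
print(torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5))
```
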


Parameter Verification

Let us verify our parameter calculations match the implementation.

Attention Module Parameters

🐍python
def count_attention_params(model):
    """Count parameters in attention module."""
    params = {}

    for name, p in model.attention.named_parameters():
        params[name] = p.numel()

    for name, p in model.layer_norm.named_parameters():
        params[name] = p.numel()

    return params

# Create module
attn = AttentionModule(embed_dim=256, num_heads=8, dropout=0.1)
params = count_attention_params(attn)

# Expected output:
# in_proj_weight: 196608 (3 × 256 × 256 for Q, K, V)
# in_proj_bias: 768 (3 × 256)
# out_proj.weight: 65536 (256 × 256)
# out_proj.bias: 256
# layer_norm.weight: 256
# layer_norm.bias: 256

Parameter Breakdown

| Component | Shape | Parameters |
|---|---|---|
| in_proj_weight (Q,K,V) | (768, 256) | 196,608 |
| in_proj_bias | (768,) | 768 |
| out_proj.weight | (256, 256) | 65,536 |
| out_proj.bias | (256,) | 256 |
| layer_norm.weight | (256,) | 256 |
| layer_norm.bias | (256,) | 256 |
| Total | | 263,680 |

in_proj Packing

PyTorch packs the Q, K, V projections into a single in_proj matrix for efficiency. The shape (768, 256) = (3 × 256, 256) represents all three projections combined.
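The packed matrix can be sliced back into the three projections; the row blocks are ordered Q, then K, then V:

```python
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=256, num_heads=8)

# in_proj_weight stacks the three projection matrices row-wise: Q, K, V
W = mha.in_proj_weight
W_q, W_k, W_v = W[:256], W[256:512], W[512:]
print(W.shape)    # torch.Size([768, 256])
print(W_q.shape)  # torch.Size([256, 256])
```
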



Summary

In this section, we implemented the complete attention module:

  1. AttentionModule: Wraps MultiheadAttention with residual and layer norm
  2. Integration: CNN → BiLSTM → Attention pipeline
  3. Forward pass: (B, 30, 17) → (B, 256) encoding
  4. Parameters: ~264K for attention, ~909K total encoder
| Component | Input | Output | Parameters |
|---|---|---|---|
| CNN | (B, 30, 17) | (B, 30, 64) | ~53K |
| BiLSTM | (B, 30, 64) | (B, 30, 256) | ~592K |
| Attention | (B, 30, 256) | (B, 256) | ~264K |
| Total Encoder | (B, 30, 17) | (B, 256) | ~909K |
Chapter Complete: We have designed and implemented the complete multi-head self-attention mechanism. The encoder now transforms raw sensor sequences into rich 256-dimensional representations, with attention focusing on the most predictive timesteps. The next chapter introduces the prediction head and loss function for RUL regression.

With the attention-enhanced encoder complete, we now design the output layer for RUL prediction.