Learning Objectives
By the end of this section, you will:
- Review weight decay as a regularization technique
- Apply layer-specific weight decay for different parameter types
- Understand weight decay scheduling strategies
- Implement adaptive weight decay in PyTorch
- Choose appropriate decay values for RUL prediction
Why This Matters: Weight decay prevents overfitting by penalizing large weights. However, applying the same decay to all parameters is suboptimal—biases, normalization parameters, and embeddings often benefit from zero or reduced decay. Adaptive strategies can improve both training stability and final performance.
Weight Decay Review
Weight decay adds a penalty proportional to the squared norm of weights.
Mathematical Formulation
The regularized loss:

L_reg(w) = L(w) + (λ / 2) ‖w‖²

In AdamW, weight decay is decoupled from the gradient-based update and applied directly to the parameters:

w ← w − η · m̂ / (√v̂ + ε) − η λ w
Effect on Weights
Each update shrinks weights by a factor of (1 − ηλ):

w ← (1 − ηλ) w − η · m̂ / (√v̂ + ε)

With η = 1e-3 and λ = 1e-4, each step multiplies the weights by 1 − 1e-7—negligible per step, but substantial over many thousands of updates.
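The decoupled shrinkage can be checked directly in PyTorch. In this minimal sketch, the gradient is set to zero so only the decay term acts; the values are illustrative, not tuned for any model:

```python
import torch

# With a zero gradient, AdamW's decoupled decay shrinks the weight
# by exactly (1 - lr * weight_decay) per step.
lr, wd = 1e-3, 1e-4
w = torch.nn.Parameter(torch.tensor([2.0]))
opt = torch.optim.AdamW([w], lr=lr, weight_decay=wd)

w.grad = torch.zeros_like(w)  # no data gradient: only decay acts
opt.step()

expected = 2.0 * (1 - lr * wd)
print(w.item(), expected)  # both ≈ 1.9999998
```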
Regularization Trade-off
| λ Value | Effect | Risk |
|---|---|---|
| 0 | No regularization | Overfitting |
| 1e-5 | Light regularization | Minimal constraint |
| 1e-4 | Standard regularization | Good balance |
| 1e-3 | Strong regularization | Underfitting risk |
| 1e-2 | Very strong | Severe underfitting |
Layer-Specific Weight Decay
Not all parameters benefit equally from weight decay.
Parameters That Should NOT Have Decay
| Parameter Type | Why No Decay | Example |
|---|---|---|
| Bias terms | Already low-dimensional | Linear.bias |
| LayerNorm weights | Learned scales, not weights | LayerNorm.weight |
| LayerNorm biases | Shift parameters | LayerNorm.bias |
| BatchNorm parameters | Normalization statistics | BatchNorm.* |
| Embeddings | Discrete lookups, different dynamics | Embedding.weight |
Why Biases Should Not Decay
Bias terms shift activations rather than scale inputs, and each layer has only one bias value per output unit, so they add very few degrees of freedom. Decaying them pushes activations toward zero without meaningfully reducing overfitting, and can slow convergence.
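A quick check on a hypothetical two-layer network (the architecture below is purely illustrative) shows how small the bias count is relative to the weight count:

```python
import torch.nn as nn

# Illustrative model: biases are a tiny fraction of total parameters,
# so decaying them buys almost no regularization.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 1))

n_bias = sum(p.numel() for n, p in model.named_parameters() if 'bias' in n)
n_weight = sum(p.numel() for n, p in model.named_parameters() if 'bias' not in n)
print(n_bias, n_weight)  # 257 biases vs 33,024 weights
```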
Weight Decay Scheduling
Weight decay can be adjusted during training.
Constant Weight Decay (Default)
Simple and effective for most cases. We use this for AMNL.
Scaled with Learning Rate
Some practitioners scale weight decay with the learning rate:

λ_t = λ₀ · (η_t / η₀)
This keeps the relative regularization strength constant as LR decays. However, this is not recommended with AdamW—the decoupled weight decay is designed to work with constant λ.
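If you did want to experiment with this coupling despite the caveat, it can be sketched by updating the optimizer's parameter groups after each scheduler step. The model, base_lr, and base_wd below are placeholders:

```python
import torch
import torch.nn as nn

# Sketch of LR-coupled weight decay (not recommended with AdamW).
model = nn.Linear(16, 1)
base_lr, base_wd = 1e-3, 1e-4
opt = torch.optim.AdamW(model.parameters(), lr=base_lr, weight_decay=base_wd)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=100)

for step in range(100):
    loss = model(torch.randn(4, 16)).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
    sched.step()
    for group in opt.param_groups:
        # Keep the ratio λ_t / η_t constant: λ_t = λ₀ · (η_t / η₀)
        group['weight_decay'] = base_wd * group['lr'] / base_lr
```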
Warmup Weight Decay
Start with lower weight decay and increase:
```
Weight Decay Warmup:

λ
λ_max ─┤        ────────────────────
       │       ╱
       │      ╱
       │     ╱
       │    ╱
λ_min ─┤───╱───────────────────────────────
       └──┬──────┬─────────────────────────
          0   T_warmup                Epochs
```

This allows early training to focus on fitting the data before regularization kicks in. Less commonly used but can help in some cases.
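A linear warmup of λ can be implemented by mutating the optimizer's parameter groups each step. The function name and values below are illustrative, not from any specific library:

```python
import torch
import torch.nn as nn

def warmup_weight_decay(optimizer, step, warmup_steps, wd_min, wd_max):
    """Linearly ramp weight decay from wd_min to wd_max over warmup_steps."""
    frac = min(step / warmup_steps, 1.0)
    wd = wd_min + frac * (wd_max - wd_min)
    for group in optimizer.param_groups:
        group['weight_decay'] = wd
    return wd

model = nn.Linear(8, 1)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.0)
print(warmup_weight_decay(opt, 0, 1000, 1e-6, 1e-4))     # 1e-06 at start
print(warmup_weight_decay(opt, 500, 1000, 1e-6, 1e-4))   # ≈ 5.05e-05 midway
print(warmup_weight_decay(opt, 2000, 1000, 1e-6, 1e-4))  # ≈ 1e-04 (capped)
```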
Our Recommendation
For RUL prediction with AMNL, use constant weight decay (λ = 1e-4) applied only to weight matrices, not biases or normalization layers. This simple approach works well and avoids additional hyperparameters.
Implementation
PyTorch implementation of layer-specific weight decay.
Separating Parameter Groups
```python
import torch
import torch.nn as nn

def configure_optimizer_with_layer_decay(
    model: nn.Module,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-4
) -> torch.optim.AdamW:
    """
    Configure AdamW with layer-specific weight decay.

    Applies weight decay only to weight matrices, not to:
    - Bias terms
    - LayerNorm/BatchNorm parameters
    - Embedding weights (optional)

    Args:
        model: Neural network model
        learning_rate: Learning rate
        weight_decay: Weight decay for weight matrices

    Returns:
        Configured AdamW optimizer
    """
    # Keywords for parameters that should not have weight decay
    no_decay_keywords = [
        'bias',
        'LayerNorm',
        'layer_norm',
        'BatchNorm',
        'batch_norm',
        'ln_',
        'bn_',
    ]

    # Separate parameters into decay and no-decay groups
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # Check if this parameter should have no decay
        if any(kw in name for kw in no_decay_keywords):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    # Create parameter groups
    param_groups = [
        {
            'params': decay_params,
            'weight_decay': weight_decay,
            'name': 'decay'
        },
        {
            'params': no_decay_params,
            'weight_decay': 0.0,
            'name': 'no_decay'
        }
    ]

    # Log parameter counts
    n_decay = sum(p.numel() for p in decay_params)
    n_no_decay = sum(p.numel() for p in no_decay_params)
    print(f"Parameters with decay: {n_decay:,}")
    print(f"Parameters without decay: {n_no_decay:,}")

    optimizer = torch.optim.AdamW(
        param_groups,
        lr=learning_rate,
        betas=(0.9, 0.999),
        eps=1e-8
    )

    return optimizer
```

Advanced: Per-Layer Weight Decay
```python
def configure_layerwise_decay(
    model: nn.Module,
    base_lr: float = 1e-3,
    base_decay: float = 1e-4,
    decay_rate: float = 0.9
) -> torch.optim.AdamW:
    """
    Configure optimizer with layer-wise learning rate and weight decay.

    The deepest (output-side) layers get the base learning rate and
    weight decay; earlier layers are scaled down geometrically.
    Useful for fine-tuning pretrained models.

    Args:
        model: Neural network model
        base_lr: Learning rate for the deepest layer
        base_decay: Weight decay for the deepest layer
        decay_rate: Multiplicative factor per layer (< 1 for decay)

    Returns:
        Configured optimizer
    """
    # Get all named modules in definition order (shallow to deep)
    named_modules = list(model.named_modules())

    param_groups = []
    seen_params = set()
    layer_idx = 0  # count only modules that directly own parameters

    # Reverse order: deepest layers first (highest LR and decay)
    for name, module in reversed(named_modules):
        if not list(module.parameters(recurse=False)):
            continue

        layer_lr = base_lr * (decay_rate ** layer_idx)
        layer_decay = base_decay * (decay_rate ** layer_idx)
        layer_idx += 1

        for param_name, param in module.named_parameters(recurse=False):
            if id(param) in seen_params:
                continue
            seen_params.add(id(param))

            # No decay for biases
            wd = 0.0 if 'bias' in param_name else layer_decay

            param_groups.append({
                'params': [param],
                'lr': layer_lr,
                'weight_decay': wd,
                'name': f'{name}.{param_name}'
            })

    return torch.optim.AdamW(param_groups)
```

Usage Example
```python
# Simple layer-specific decay (recommended)
optimizer = configure_optimizer_with_layer_decay(
    model=model,
    learning_rate=1e-3,
    weight_decay=1e-4
)

# Verify configuration
for i, group in enumerate(optimizer.param_groups):
    print(f"Group {i}: {group.get('name', 'unnamed')}")
    print(f"  Weight decay: {group['weight_decay']}")
    print(f"  Params: {sum(p.numel() for p in group['params']):,}")
```

Summary
In this section, we covered adaptive weight decay:
- Weight decay: Regularization via parameter shrinkage
- No decay for biases: Biases are few and contribute little to overfitting
- No decay for norms: LayerNorm/BatchNorm parameters excluded
- Constant λ: Recommended for AdamW (no scheduling)
- Parameter groups: PyTorch mechanism for per-group settings
| Parameter Type | Weight Decay |
|---|---|
| Weight matrices | 1e-4 |
| Bias terms | 0 |
| LayerNorm parameters | 0 |
| BatchNorm parameters | 0 |
| Attention weights | 1e-4 |
Looking Ahead: We have configured regularization via weight decay. The final section covers gradient clipping—a technique for preventing exploding gradients that can destabilize training.
With weight decay configured, we address gradient stability.