Chapter 12

Adaptive Weight Decay

Optimization Strategy

Learning Objectives

By the end of this section, you will:

  1. Review weight decay as a regularization technique
  2. Apply layer-specific weight decay for different parameter types
  3. Understand weight decay scheduling strategies
  4. Implement adaptive weight decay in PyTorch
  5. Choose appropriate decay values for RUL prediction
Why This Matters: Weight decay prevents overfitting by penalizing large weights. However, applying the same decay to all parameters is suboptimal—biases, normalization parameters, and embeddings often benefit from zero or reduced decay. Adaptive strategies can improve both training stability and final performance.

Weight Decay Review

Weight decay adds a penalty proportional to the squared norm of weights.

Mathematical Formulation

The regularized loss:

\mathcal{L}_{\text{total}} = \mathcal{L} + \frac{\lambda}{2}\|\theta\|_2^2

In AdamW, weight decay is applied directly to parameters:

\theta_{t+1} = \theta_t - \eta \cdot \text{adam\_update} - \eta \lambda \theta_t
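A minimal sketch (illustrative values, not the chapter's training code) that makes the decoupling visible: with a zero gradient the Adam update is exactly zero, yet AdamW still shrinks the parameter by the factor (1 − ηλ):

```python
import torch

# Hypothetical values chosen so the shrink is easy to see
eta, lam = 1e-3, 1e-2

p = torch.nn.Parameter(torch.ones(4))
opt = torch.optim.AdamW([p], lr=eta, weight_decay=lam)

# A zero gradient makes the Adam term vanish,
# isolating the decoupled decay term -eta * lam * theta.
p.grad = torch.zeros_like(p)
opt.step()

expected = 1.0 * (1 - eta * lam)  # one step of pure decay
print(p.data[0].item(), expected)
```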

Effect on Weights

Each update shrinks the weights by a factor of (1 - \eta\lambda).
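To get a feel for the numbers (illustrative values, not tuned for any particular model): with η = 1e-3 and λ = 1e-4 the per-step factor is 1 − 1e-7, which only becomes visible compounded over many steps:

```python
# Illustrative hyperparameters
eta, lam = 1e-3, 1e-4
factor = 1 - eta * lam          # per-step shrink factor, 1 - 1e-7

steps = 1_000_000
cumulative = factor ** steps    # ~ exp(-eta * lam * steps), roughly a 10% shrink
print(f"per-step: {factor:.9f}, after {steps:,} steps: {cumulative:.3f}")
```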

Regularization Trade-off

| λ Value | Effect                  | Risk                |
|---------|-------------------------|---------------------|
| 0       | No regularization       | Overfitting         |
| 1e-5    | Light regularization    | Minimal constraint  |
| 1e-4    | Standard regularization | Good balance        |
| 1e-3    | Strong regularization   | Underfitting risk   |
| 1e-2    | Very strong             | Severe underfitting |

Layer-Specific Weight Decay

Not all parameters benefit equally from weight decay.

Parameters That Should NOT Have Decay

| Parameter Type       | Why No Decay                         | Example            |
|----------------------|--------------------------------------|--------------------|
| Bias terms           | Already low-dimensional              | `Linear.bias`      |
| LayerNorm weights    | Learned scales, not weights          | `LayerNorm.weight` |
| LayerNorm biases     | Shift parameters                     | `LayerNorm.bias`   |
| BatchNorm parameters | Normalization statistics             | `BatchNorm.*`      |
| Embeddings           | Discrete lookups, different dynamics | `Embedding.weight` |

Why Biases Should Not Decay

Biases shift activations; they do not multiply inputs, so large biases do not produce the input-amplifying behavior that weight decay guards against. They are also few in number relative to weight matrices, so decaying them adds almost no regularization while systematically pulling activations toward zero, which can hurt fit. A similar argument applies to LayerNorm and BatchNorm scale and shift parameters, whose natural resting values are 1 and 0, not 0.

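A quick check on a small illustrative model (shapes chosen arbitrarily, not the chapter's RUL architecture) shows how little is at stake: biases account for well under 1% of the parameters:

```python
import torch.nn as nn

# Arbitrary two-layer model for illustration only
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 1))

n_bias = sum(p.numel() for n, p in model.named_parameters() if 'bias' in n)
n_weight = sum(p.numel() for n, p in model.named_parameters() if 'bias' not in n)
print(f"biases: {n_bias}, weights: {n_weight}")  # biases: 257, weights: 33024
```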
Weight Decay Scheduling

Weight decay can be adjusted during training.

Constant Weight Decay (Default)

\lambda_t = \lambda \quad \text{(constant throughout training)}

Simple and effective for most cases. We use this for AMNL.

Scaled with Learning Rate

Some practitioners scale weight decay with learning rate:

\lambda_t = \lambda \cdot \frac{\eta_t}{\eta_0}

This keeps the relative regularization strength constant as LR decays. However, this is not recommended with AdamW—the decoupled weight decay is designed to work with constant λ.
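For completeness, the scaling is a one-liner; the helper name is hypothetical and, per the caveat above, this scheme is not recommended with AdamW:

```python
def scaled_weight_decay(base_decay: float, current_lr: float, base_lr: float) -> float:
    """lambda_t = lambda * eta_t / eta_0 (hypothetical helper, see caveat above)."""
    return base_decay * current_lr / base_lr

# As the LR decays from 1e-3 to 1e-4, the decay scales down by the same factor
print(scaled_weight_decay(1e-4, 1e-4, 1e-3))  # ≈ 1e-5
```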

Warmup Weight Decay

Start with lower weight decay and increase:

```text
Weight Decay Warmup:

λ
  λ_max ─┤              ────────────────────
         │           ╱
         │         ╱
         │       ╱
         │     ╱
  λ_min ─┤───╱───────────────────────────────
         └──┬──────┬─────────────────────────
            0   T_warmup              Epochs
```

This allows early training to focus on fitting the data before regularization kicks in. Less commonly used but can help in some cases.
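The ramp in the diagram can be sketched as a plain function (the linear shape and the λ_min/λ_max/T_warmup names mirror the figure; other ramp shapes are possible):

```python
def weight_decay_warmup(epoch: int, lam_min: float, lam_max: float,
                        t_warmup: int) -> float:
    """Linearly ramp weight decay from lam_min to lam_max over t_warmup epochs."""
    if epoch >= t_warmup:
        return lam_max
    return lam_min + (lam_max - lam_min) * epoch / t_warmup

# Applied each epoch by mutating the optimizer's parameter groups, e.g.:
#   for group in optimizer.param_groups:
#       if group['weight_decay'] > 0:
#           group['weight_decay'] = weight_decay_warmup(epoch, 1e-5, 1e-4, 10)
print(weight_decay_warmup(5, 0.0, 1e-4, 10))  # halfway through warmup: 5e-5
```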

Our Recommendation

For RUL prediction with AMNL, use constant weight decay (λ = 1e-4) applied only to weight matrices, not biases or normalization layers. This simple approach works well and avoids additional hyperparameters.


Implementation

PyTorch implementation of layer-specific weight decay.

Separating Parameter Groups

```python
import torch
import torch.nn as nn


def configure_optimizer_with_layer_decay(
    model: nn.Module,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-4
) -> torch.optim.AdamW:
    """
    Configure AdamW with layer-specific weight decay.

    Applies weight decay only to weight matrices, not to:
    - Bias terms
    - LayerNorm/BatchNorm parameters
    - Embedding weights (optional)

    Args:
        model: Neural network model
        learning_rate: Learning rate
        weight_decay: Weight decay for weight matrices

    Returns:
        Configured AdamW optimizer
    """
    # Keywords for parameters that should not have weight decay
    no_decay_keywords = [
        'bias',
        'LayerNorm',
        'layer_norm',
        'BatchNorm',
        'batch_norm',
        'ln_',
        'bn_',
    ]

    # Separate parameters into decay and no-decay groups
    decay_params = []
    no_decay_params = []

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # Check if this parameter should have no decay
        if any(kw in name for kw in no_decay_keywords):
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    # Create parameter groups
    param_groups = [
        {
            'params': decay_params,
            'weight_decay': weight_decay,
            'name': 'decay'
        },
        {
            'params': no_decay_params,
            'weight_decay': 0.0,
            'name': 'no_decay'
        }
    ]

    # Log parameter counts
    n_decay = sum(p.numel() for p in decay_params)
    n_no_decay = sum(p.numel() for p in no_decay_params)
    print(f"Parameters with decay: {n_decay:,}")
    print(f"Parameters without decay: {n_no_decay:,}")

    optimizer = torch.optim.AdamW(
        param_groups,
        lr=learning_rate,
        betas=(0.9, 0.999),
        eps=1e-8
    )

    return optimizer
```

Advanced: Per-Layer Weight Decay

```python
def configure_layerwise_decay(
    model: nn.Module,
    base_lr: float = 1e-3,
    base_decay: float = 1e-4,
    decay_rate: float = 0.9
) -> torch.optim.AdamW:
    """
    Configure optimizer with layer-wise learning rate and weight decay.

    Layers closer to the input get lower learning rates and weight decay.
    Useful for fine-tuning pretrained models.

    Args:
        model: Neural network model
        base_lr: Learning rate for the output-side layer
        base_decay: Weight decay for the output-side layer
        decay_rate: Multiplicative factor per layer (< 1 for decay)

    Returns:
        Configured optimizer
    """
    # All named modules in registration order (input side to output side)
    named_modules = list(model.named_modules())

    param_groups = []
    seen_params = set()

    # Reverse order: output-side layers first, so they get the highest LR/decay.
    # Count only modules with their own parameters, so containers do not
    # inflate the layer index.
    layer_idx = 0
    for name, module in reversed(named_modules):
        if not list(module.parameters(recurse=False)):
            continue

        layer_lr = base_lr * (decay_rate ** layer_idx)
        layer_decay = base_decay * (decay_rate ** layer_idx)
        layer_idx += 1

        for param_name, param in module.named_parameters(recurse=False):
            if id(param) in seen_params:
                continue
            seen_params.add(id(param))

            # No decay for biases
            wd = 0.0 if 'bias' in param_name else layer_decay

            param_groups.append({
                'params': [param],
                'lr': layer_lr,
                'weight_decay': wd,
                'name': f'{name}.{param_name}'
            })

    return torch.optim.AdamW(param_groups)
```
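To see what `decay_rate = 0.9` does across a hypothetical stack of four parameter groups (output-side group first), the per-group values fall geometrically:

```python
# Hypothetical depth of 4 groups; constants match the sketch above
base_lr, base_decay, decay_rate = 1e-3, 1e-4, 0.9

lrs = [base_lr * decay_rate ** i for i in range(4)]
wds = [base_decay * decay_rate ** i for i in range(4)]
for i, (lr, wd) in enumerate(zip(lrs, wds)):
    print(f"group {i} (output side first): lr={lr:.2e}, wd={wd:.2e}")
```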

Usage Example

```python
# Simple layer-specific decay (recommended)
optimizer = configure_optimizer_with_layer_decay(
    model=model,
    learning_rate=1e-3,
    weight_decay=1e-4
)

# Verify configuration
for i, group in enumerate(optimizer.param_groups):
    print(f"Group {i}: {group.get('name', 'unnamed')}")
    print(f"  Weight decay: {group['weight_decay']}")
    print(f"  Params: {sum(p.numel() for p in group['params']):,}")
```

Summary

In this section, we covered adaptive weight decay:

  1. Weight decay: Regularization via parameter shrinkage
  2. No decay for biases: Biases don't contribute to overfitting
  3. No decay for norms: LayerNorm/BatchNorm parameters excluded
  4. Constant λ: Recommended for AdamW (no scheduling)
  5. Parameter groups: PyTorch mechanism for per-group settings
| Parameter Type       | Weight Decay |
|----------------------|--------------|
| Weight matrices      | 1e-4         |
| Bias terms           | 0            |
| LayerNorm parameters | 0            |
| BatchNorm parameters | 0            |
| Attention weights    | 1e-4         |
Looking Ahead: We have configured regularization via weight decay. The final section covers gradient clipping—a technique for preventing exploding gradients that can destabilize training.

With weight decay configured, we address gradient stability.