Chapter 10

Label Smoothing Loss

Training Pipeline

Introduction

Label smoothing is a regularization technique that discourages the model from becoming overconfident. Instead of training against hard one-hot labels (1 for the correct class, 0 for every other class), we train against soft labels that move a small amount of probability mass ε onto the incorrect classes. In practice this tends to improve generalization and translation quality.


The Problem: Overconfidence

Hard Labels

πŸ“text
1Target word: "dog" (vocabulary index 45)
2
3Hard label distribution:
4  P(dog) = 1.0
5  P(cat) = 0.0
6  P(the) = 0.0
7  P(runs) = 0.0
8  ... (all others = 0.0)
9
10The model is trained to be 100% confident in "dog".

Why This is a Problem

πŸ“text
1Issues with hard labels:
21. Overconfidence: Model outputs near-1.0 probabilities
32. Poor calibration: Predicted confidence β‰  actual accuracy
43. Overfitting: Model memorizes exact target distribution
54. Gradient saturation: Very confident predictions have small gradients
6
7Example:
8  Model predicts P(dog) = 0.99
9  True label: P(dog) = 1.0
10  Cross-entropy loss is very small
11
12  But maybe "hound" or "puppy" would also be valid translations!
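
To put a number on "very small", here is a minimal sketch in plain Python (no PyTorch, so the arithmetic stays visible). The helper name `hard_label_loss` is made up for this illustration. It evaluates the one-hot cross-entropy, which reduces to -log of the probability given to the correct word.

```python
import math


def hard_label_loss(p_target: float) -> float:
    """Cross-entropy against a one-hot label: -log(p_target)."""
    return -math.log(p_target)


# The loss keeps shrinking as the prediction approaches 1.0,
# but it only reaches zero at exactly P(dog) = 1.0.
for p in (0.5, 0.9, 0.99, 0.999):
    print(f"P(dog) = {p:<5}  loss = {hard_label_loss(p):.4f}")
```

At P(dog) = 0.99 the loss is already about 0.01, yet the objective still rewards pushing the probability even closer to 1.0.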

Label Smoothing Solution

Soft Labels

πŸ“text
1Target word: "dog" (index 45)
2Smoothing: Ξ΅ = 0.1
3Vocabulary size: V = 32000
4
5Smoothed label distribution:
6  P(dog) = 1 - Ξ΅ + Ξ΅/V = 0.9 + 0.1/32000 β‰ˆ 0.9
7  P(cat) = Ξ΅/V = 0.1/32000 β‰ˆ 0.000003
8  P(the) = Ξ΅/V β‰ˆ 0.000003
9  ... (all non-target = Ξ΅/V)
10
11Now model is trained to be ~90% confident, not 100%.
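
The same distribution can be built in a few lines of plain Python. This is an illustrative sketch: `smoothed_targets` is a hypothetical helper, not part of the training code. It confirms that the smoothed labels still sum to 1.

```python
import math
from typing import List


def smoothed_targets(target_index: int, vocab_size: int, epsilon: float) -> List[float]:
    """Mix a one-hot label with a uniform distribution over the vocabulary."""
    uniform_share = epsilon / vocab_size           # ε / V for every class
    dist = [uniform_share] * vocab_size
    dist[target_index] = 1.0 - epsilon + uniform_share
    return dist


y = smoothed_targets(target_index=45, vocab_size=32000, epsilon=0.1)
print(y[45])         # probability assigned to "dog"
print(y[0])          # probability assigned to any other token
print(math.fsum(y))  # total probability mass
```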

Mathematical Formulation

Standard cross-entropy:

$$L = -\sum_{i=1}^{V} y_i \times \log(p_i)$$

With hard labels (one-hot):

$$y_{\text{target}} = 1, \quad y_{\text{other}} = 0$$

$$L = -\log(p_{\text{target}})$$

With label smoothing:

$$y_{\text{target}} = 1 - \varepsilon + \varepsilon/V$$

$$y_{\text{other}} = \varepsilon/V$$

$$L = -(1 - \varepsilon + \varepsilon/V) \times \log(p_{\text{target}}) - \sum_{i \neq \text{target}} (\varepsilon/V) \times \log(p_i)$$

Simplification (move the target's ε/V term into the sum, which then runs over all V classes):

$$L = (1 - \varepsilon) \times \text{CE}(p, \text{target}) + \varepsilon \times \text{CE}(p, \text{uniform})$$

where $\text{CE}(p, \text{target}) = -\log(p_{\text{target}})$ and $\text{CE}(p, \text{uniform}) = -\frac{1}{V} \sum_{i=1}^{V} \log(p_i)$.
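
As a quick sanity check on the simplification, the sketch below evaluates both forms on a toy five-word vocabulary in plain Python. The function names and the example logits are invented for illustration.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def smoothed_loss_expanded(p: List[float], target: int, eps: float) -> float:
    """-sum_i y_i * log(p_i) with the smoothed labels written out."""
    vocab = len(p)
    y = [eps / vocab] * vocab
    y[target] = 1.0 - eps + eps / vocab
    return -sum(y_i * math.log(p_i) for y_i, p_i in zip(y, p))


def smoothed_loss_simplified(p: List[float], target: int, eps: float) -> float:
    """(1 - eps) * CE(p, target) + eps * CE(p, uniform)."""
    ce_target = -math.log(p[target])
    ce_uniform = -sum(math.log(p_i) for p_i in p) / len(p)
    return (1.0 - eps) * ce_target + eps * ce_uniform


p = softmax([2.0, 0.5, -1.0, 0.1, 1.2])
a = smoothed_loss_expanded(p, target=0, eps=0.1)
b = smoothed_loss_simplified(p, target=0, eps=0.1)
print(a, b)  # the two values agree up to floating-point rounding
```

With `eps=0.0` both functions reduce to the hard-label loss -log(p_target).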


Implementation

Label Smoothing Loss Class

🐍python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingLoss(nn.Module):
    """
    Label smoothing loss for sequence-to-sequence models.

    Implements the label smoothing regularization technique where
    the target distribution is smoothed by mixing with a uniform distribution.

    Args:
        vocab_size: Size of the vocabulary
        smoothing: Label smoothing factor (default: 0.1)
        pad_id: Padding token ID to ignore in loss (default: 0)

    Example:
        >>> criterion = LabelSmoothingLoss(vocab_size=32000, smoothing=0.1)
        >>> loss = criterion(logits, target_ids)
    """

    def __init__(
        self,
        vocab_size: int,
        smoothing: float = 0.1,
        pad_id: int = 0
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.smoothing = smoothing
        self.pad_id = pad_id

        # Pre-compute the two values of the smoothed distribution.
        # Every class receives an equal share of the smoothing mass: ε / V
        self.smoothing_value = smoothing / vocab_size

        # The correct class keeps 1 - ε plus its own uniform share, so each
        # row sums to 1 and matches y_target = 1 - ε + ε/V
        self.confidence = 1.0 - smoothing + self.smoothing_value

    def forward(
        self,
        logits: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute label smoothing loss.

        Args:
            logits: Model predictions [batch, seq_len, vocab_size]
            target: Target token IDs [batch, seq_len]

        Returns:
            loss: Scalar loss value
        """
        vocab_size = logits.size(-1)

        # Flatten batch and time dimensions
        logits = logits.reshape(-1, vocab_size)  # [batch*seq, vocab]
        target = target.reshape(-1)              # [batch*seq]

        # Create smoothed target distribution (a constant, so no gradient)
        with torch.no_grad():
            # Start with ε / V everywhere
            smooth_target = torch.full_like(
                logits, self.smoothing_value
            )

            # Set correct class to confidence value
            # (assumes every target ID, including pad_id, is a valid index)
            smooth_target.scatter_(
                1,
                target.unsqueeze(1),
                self.confidence
            )

            # Zero out the rows at padding positions
            pad_mask = (target == self.pad_id)
            smooth_target[pad_mask] = 0

        # Per-token KL divergence between the smoothed target and the model.
        # It differs from the smoothed cross-entropy only by the entropy of
        # smooth_target, which is constant with respect to the model.
        log_probs = F.log_softmax(logits, dim=-1)
        loss = F.kl_div(
            log_probs,
            smooth_target,
            reduction='none'
        ).sum(dim=-1)

        # Average over non-padding tokens only
        non_pad = ~pad_mask
        loss = loss[non_pad].mean()

        return loss


class LabelSmoothingCrossEntropy(nn.Module):
    """
    Alternative implementation using cross-entropy formulation.

    Computes the smoothed loss directly from the log-probabilities,
    without building the dense smoothed target tensor.
    """

    def __init__(
        self,
        smoothing: float = 0.1,
        pad_id: int = 0
    ):
        super().__init__()
        self.smoothing = smoothing
        self.pad_id = pad_id

    def forward(
        self,
        logits: torch.Tensor,
        target: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute label smoothing loss.

        L = (1 - ε) × CE(p, target) + ε × CE(p, uniform)

        Args:
            logits: [batch, seq_len, vocab_size]
            target: [batch, seq_len]

        Returns:
            loss: Scalar
        """
        vocab_size = logits.size(-1)

        # Flatten batch and time dimensions
        logits = logits.reshape(-1, vocab_size)
        target = target.reshape(-1)

        # Log probabilities
        log_probs = F.log_softmax(logits, dim=-1)

        # CE(p, target): per-token negative log-likelihood
        # (positions equal to pad_id yield 0 here and are dropped below)
        nll_loss = F.nll_loss(
            log_probs,
            target,
            ignore_index=self.pad_id,
            reduction='none'
        )

        # CE(p, uniform): negative mean log-probability over the vocabulary
        smooth_loss = -log_probs.mean(dim=-1)

        # Combined loss
        loss = (1 - self.smoothing) * nll_loss + self.smoothing * smooth_loss

        # Average over non-padding tokens only
        non_pad = (target != self.pad_id)
        loss = loss[non_pad].mean()

        return loss
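
When testing classes like the ones above, a slow but hand-checkable reference can help. The sketch below is plain Python with a hypothetical function name, `reference_smoothed_ce`. It assumes the logits are already flattened to one row per token and follows the cross-entropy formulation, skipping padding positions entirely.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


def reference_smoothed_ce(
    logits: List[List[float]],   # [num_tokens, vocab], already flattened
    targets: List[int],          # [num_tokens]
    smoothing: float = 0.1,
    pad_id: int = 0,
) -> float:
    """Mean smoothed cross-entropy over non-padding tokens."""
    per_token = []
    for row, t in zip(logits, targets):
        if t == pad_id:
            continue  # padding adds neither loss nor token count
        lp = log_softmax(row)
        nll = -lp[t]                   # CE(p, target)
        uniform = -sum(lp) / len(lp)   # CE(p, uniform)
        per_token.append((1.0 - smoothing) * nll + smoothing * uniform)
    if not per_token:
        raise ValueError("batch contains only padding tokens")
    return sum(per_token) / len(per_token)


logits = [[0.0, 2.0, 0.5], [0.0, 0.1, 3.0], [5.0, 0.0, 0.0]]
targets = [1, 2, 0]  # the last position is padding (pad_id = 0)
loss = reference_smoothed_ce(logits, targets)
loss_without_pad_row = reference_smoothed_ce(logits[:2], targets[:2])
print(loss, loss_without_pad_row)  # identical: the padded row is ignored
```

The explicit error for an all-padding batch is a design choice of this sketch. In the PyTorch classes, taking `.mean()` over an empty selection would produce NaN instead.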

Effect on Training

Gradient Behavior

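Label smoothing changes which classes contribute to the loss. With hard labels, only the log-probability of the correct class appears in the loss, so the training signal on the log-probabilities is sparse. With smoothed labels, every class contributes a term weighted by ε/V, so the signal is dense. The gradient with respect to the logits is p − y in both cases: with hard labels it shrinks toward zero as the prediction approaches 1.0, while with smoothed labels it is zero only when p matches the smoothed target, so a prediction more confident than about 1 − ε is actively pushed back.

📝text
Key observations:
1. Hard labels: Loss uses only the correct class's log-probability (sparse signal)
2. Smooth labels: Every class's log-probability appears in the loss (dense signal)
3. Smooth labels keep a non-zero gradient for overconfident predictions
4. Non-target classes get a small but non-zero weight (ε/V) in the loss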
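
The effect on a confident prediction can be seen from the standard softmax cross-entropy gradient with respect to the logits, which is p − y whenever the labels y sum to 1. The following plain-Python sketch uses made-up logits and a four-word vocabulary, and the helper names are illustrative.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logit_gradient(logits: List[float], y: List[float]) -> List[float]:
    """Gradient of cross-entropy w.r.t. the logits: softmax(logits) - y."""
    p = softmax(logits)
    return [p_i - y_i for p_i, y_i in zip(p, y)]


logits = [6.0, 0.0, 0.0, 0.0]   # model is ~99% confident in class 0
eps, vocab, target = 0.1, 4, 0

hard = [1.0, 0.0, 0.0, 0.0]
smooth = [eps / vocab] * vocab
smooth[target] = 1.0 - eps + eps / vocab

grad_hard = logit_gradient(logits, hard)
grad_smooth = logit_gradient(logits, smooth)
print("hard:  ", [round(g, 4) for g in grad_hard])
print("smooth:", [round(g, 4) for g in grad_smooth])
```

With hard labels the gradient on the target logit is small and negative, so gradient descent still increases that logit. With smoothed labels it is positive and larger in magnitude, so the update lowers the logit and pulls the prediction back toward the smoothed target.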

Choosing Smoothing Value

Guidelines

πŸ“text
1SMOOTHING PARAMETER (Ξ΅):
2────────────────────────
3
4Ξ΅ = 0.0:  No smoothing (standard cross-entropy)
5          β†’ Can overfit, overconfident predictions
6
7Ξ΅ = 0.1:  Standard smoothing (original Transformer paper)
8          β†’ Good default for most tasks
9          β†’ Confidence = 0.9 for correct class
10
11Ξ΅ = 0.2:  More regularization
12          β†’ Useful for smaller datasets
13          β†’ May hurt if model needs high confidence
14
15Ξ΅ > 0.3:  Usually too much
16          β†’ Model may underfit
17          β†’ Predictions become too uncertain
18
19
20TASK-SPECIFIC RECOMMENDATIONS:
21─────────────────────────────
22
23Machine Translation:  Ξ΅ = 0.1 (standard, well-tested)
24Summarization:        Ξ΅ = 0.1 - 0.15
25Language Modeling:    Ξ΅ = 0.0 - 0.1 (less smoothing often better)
26Classification:       Ξ΅ = 0.1 - 0.2 (depends on class count)
27Small Dataset:        Ξ΅ = 0.15 - 0.2 (more regularization)
28Large Dataset:        Ξ΅ = 0.05 - 0.1 (less needed)
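
One practical consequence of the choice of ε follows from the formulas rather than from any experiment. The smoothed cross-entropy is smallest when the model's output equals the smoothed target, and that minimum is the entropy of the smoothed target, not zero. The sketch below computes this floor in nats for a few values of ε. The function name is hypothetical, and V = 32000 is the vocabulary size from the earlier example.

```python
import math


def smoothed_loss_floor(epsilon: float, vocab_size: int) -> float:
    """Entropy of the smoothed target: the lowest reachable smoothed loss."""
    y_other = epsilon / vocab_size
    y_target = 1.0 - epsilon + y_other
    floor = 0.0
    # One target class plus (vocab_size - 1) non-target classes.
    for prob, count in ((y_target, 1), (y_other, vocab_size - 1)):
        if prob > 0.0:  # 0 * log(0) is treated as 0
            floor -= count * prob * math.log(prob)
    return floor


for eps in (0.0, 0.05, 0.1, 0.2):
    print(f"eps = {eps:<4}  loss floor = {smoothed_loss_floor(eps, 32000):.4f}")
```

This means smoothed training loss cannot be compared directly across runs that use different ε values, which is one reason to log the unsmoothed loss as well.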

Complete Training Loss Setup

TranslationLoss Class

🐍python
import torch
import torch.nn as nn
import torch.nn.functional as F

# LabelSmoothingCrossEntropy is the class defined in the Implementation section.


class TranslationLoss(nn.Module):
    """
    Complete loss module for translation training.

    Combines:
    - Label smoothing
    - Proper handling of padding
    - Logging metrics (unsmoothed NLL, token accuracy, token count)

    Args:
        vocab_size: Target vocabulary size
        pad_id: Padding token ID
        smoothing: Label smoothing factor
    """

    def __init__(
        self,
        vocab_size: int,
        pad_id: int = 0,
        smoothing: float = 0.1
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.pad_id = pad_id
        self.smoothing = smoothing

        self.criterion = LabelSmoothingCrossEntropy(
            smoothing=smoothing,
            pad_id=pad_id
        )

    def forward(
        self,
        logits: torch.Tensor,
        target_ids: torch.Tensor
    ) -> dict:
        """
        Compute translation loss.

        Args:
            logits: Model predictions [batch, tgt_len-1, vocab]
            target_ids: Target sequence [batch, tgt_len]
                       (includes BOS at start, EOS at end)

        Returns:
            Dictionary with:
                - loss: Main training loss
                - nll_loss: Cross-entropy loss (for logging)
                - accuracy: Token accuracy
                - num_tokens: Number of non-pad tokens
        """
        # Shift targets: input is [BOS, ...], target is [..., EOS]
        # logits should already be for positions [0, ..., len-2]
        # target should be [1, ..., len-1]
        target = target_ids[:, 1:]  # Remove BOS

        # Compute main (smoothed) loss
        loss = self.criterion(logits, target)

        # Compute additional metrics (logging only, no gradient needed)
        with torch.no_grad():
            # Token accuracy over non-padding positions
            predictions = logits.argmax(dim=-1)
            mask = (target != self.pad_id)
            correct = (predictions == target) & mask
            accuracy = correct.sum().float() / mask.sum().float()

            # Unsmoothed NLL, averaged per non-pad token (for perplexity)
            nll_loss = F.cross_entropy(
                logits.reshape(-1, self.vocab_size),
                target.reshape(-1),
                ignore_index=self.pad_id
            )

            num_tokens = mask.sum()

        return {
            'loss': loss,
            'nll_loss': nll_loss,
            'accuracy': accuracy,
            'num_tokens': num_tokens
        }
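
The `nll_loss` and `num_tokens` entries are what a perplexity computation needs. Below is a minimal sketch of the aggregation, assuming each batch's values have already been converted to plain Python numbers (for example with `.item()`). The function name and the sample numbers are invented for illustration. Because `nll_loss` is a per-token mean, batches are weighted by their token counts before exponentiating.

```python
import math
from typing import Dict, List


def corpus_perplexity(batch_stats: List[Dict[str, float]]) -> float:
    """Token-weighted perplexity from per-batch mean NLL and token counts."""
    total_nll = sum(s["nll_loss"] * s["num_tokens"] for s in batch_stats)
    total_tokens = sum(s["num_tokens"] for s in batch_stats)
    return math.exp(total_nll / total_tokens)


# Made-up statistics for two batches of different sizes.
stats = [
    {"nll_loss": 2.0, "num_tokens": 100},
    {"nll_loss": 4.0, "num_tokens": 300},
]
ppl = corpus_perplexity(stats)
print(f"perplexity = {ppl:.2f}")  # exp(3.5), not exp(3.0)
```

Averaging the two batch losses directly would give 3.0 and understate the perplexity, since the larger batch has the higher loss.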

Summary

Label Smoothing Key Points

| Aspect | Hard Labels | Soft Labels |
|---|---|---|
| Target class prob | 1.0 | 1 - ε + ε/V (≈ 1 - ε) |
| Other class prob | 0.0 | ε/V |
| Confidence | Overconfident | Better calibrated |
| Gradients | Sparse | Dense |

Implementation Notes

  • Always ignore padding in loss computation
  • Use KL-divergence or CE formulation (they differ only by a constant, so the gradients are identical)
  • Default ε = 0.1 works well for translation
  • Monitor both smoothed and unsmoothed loss for comparison
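
The relationship between the two formulations can be checked numerically. In the plain-Python sketch below, with illustrative names and made-up predictions, the gap between cross-entropy and KL divergence is the entropy of the smoothed target regardless of what the model predicts, so the reported loss values differ while the gradients do not.

```python
import math
from typing import List


def cross_entropy(y: List[float], p: List[float]) -> float:
    return -sum(y_i * math.log(p_i) for y_i, p_i in zip(y, p) if y_i > 0.0)


def kl_divergence(y: List[float], p: List[float]) -> float:
    return sum(
        y_i * (math.log(y_i) - math.log(p_i))
        for y_i, p_i in zip(y, p)
        if y_i > 0.0
    )


def entropy(y: List[float]) -> float:
    return -sum(y_i * math.log(y_i) for y_i in y if y_i > 0.0)


y = [0.925, 0.025, 0.025, 0.025]   # smoothed target: ε = 0.1, V = 4
p_early = [0.40, 0.30, 0.20, 0.10]  # a poor prediction
p_late = [0.90, 0.05, 0.03, 0.02]   # a good prediction

for p in (p_early, p_late):
    gap = cross_entropy(y, p) - kl_divergence(y, p)
    print(f"CE - KL = {gap:.6f}   H(y) = {entropy(y):.6f}")
```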

Exercises

Implementation

  • Implement "confidence penalty" - an alternative that adds -H(p) to loss.
  • Add support for class-weighted label smoothing.
  • Implement adaptive label smoothing that increases with training.

Analysis

  • Compare BLEU scores with different smoothing values.
  • Plot probability calibration curves with/without smoothing.

In the next section, we'll implement Learning Rate Scheduling, including the crucial warmup phase that stabilizes transformer training.
