Chapter 24

Pretext Tasks for Text

Self-Supervised Learning

Learning Objectives

By the end of this section, you will be able to:

  1. Understand Masked Language Modeling (MLM) and how BERT uses it to learn bidirectional representations
  2. Explain Causal Language Modeling (CLM) and why it's natural for text generation (GPT-style models)
  3. Compare sentence-level pretext tasks including Next Sentence Prediction (NSP) and Sentence Order Prediction (SOP)
  4. Understand Permutation Language Modeling and how XLNet combines the benefits of MLM and CLM
  5. Implement text pretext tasks in PyTorch and understand the training dynamics
  6. Choose appropriate pretext tasks based on downstream application requirements
Why This Matters: Text pretext tasks are the foundation of modern NLP. BERT, GPT, T5, and virtually every state-of-the-art language model starts with self-supervised pretraining on unlabeled text. Understanding these tasks is essential because: (1) they determine what representations the model learns, (2) different tasks are suited for different downstream applications, and (3) the choice between bidirectional (BERT) vs. autoregressive (GPT) approaches has profound implications for how models can be used.

The Story Behind Text Pretext Tasks

The quest for self-supervised learning in NLP began with a simple question: How can we leverage the vast amounts of unlabeled text on the internet to train better language models?

The Historical Context

Before 2018, NLP relied heavily on task-specific architectures and limited labeled data. Word embeddings like Word2Vec (2013) and GloVe (2014) showed that useful representations could be learned from unlabeled text, but these were static—the word "bank" had the same embedding whether referring to a financial institution or a riverbank.

| Year | Development | Key Innovation |
|------|-------------|----------------|
| 2013 | Word2Vec | Static word embeddings from context |
| 2014 | GloVe | Global + local context for embeddings |
| 2017 | Transformer | Self-attention replaces recurrence |
| 2018 | ELMo | Contextualized embeddings from a bidirectional LSTM |
| 2018 | GPT | Causal language modeling with Transformers |
| 2018 | BERT | Masked language modeling + bidirectional Transformers |
| 2019 | XLNet | Permutation language modeling |
| 2019 | RoBERTa | Robustly optimized BERT (no NSP) |
| 2020 | GPT-3 | Massive-scale CLM shows emergent abilities |

The Key Insight

The breakthrough came from realizing that language itself provides natural supervision signals. Every sentence contains implicit information about word relationships, syntax, and semantics. By designing tasks that exploit this structure—predicting masked words, next words, or sentence relationships—we can train models to understand language without any human labels.

Self-Supervision = Structure Exploitation

A pretext task is "self-supervised" because the supervision signal comes from the data's inherent structure. We don't need humans to label "the" as an article or "walked" as a past-tense verb—we can create supervision by masking words and asking the model to predict them.

Language Modeling Foundations

Before diving into specific pretext tasks, let's establish the mathematical foundation of language modeling.

The Language Modeling Problem

Given a sequence of tokens $\mathbf{x} = (x_1, x_2, \ldots, x_T)$, a language model estimates the probability of the sequence:

$$P(\mathbf{x}) = P(x_1, x_2, \ldots, x_T)$$

Using the chain rule of probability, we can factorize this in different ways, leading to different pretext tasks.

Two Fundamental Factorizations

| Approach | Factorization | Model Type |
|----------|---------------|------------|
| Autoregressive (left-to-right) | $P(\mathbf{x}) = \prod_{t=1}^{T} P(x_t \mid x_1, \ldots, x_{t-1})$ | GPT, GPT-2, GPT-3 |
| Bidirectional (masked) | $P(\mathbf{x}_{\text{masked}} \mid \mathbf{x}_{\text{visible}})$ | BERT, RoBERTa |

The autoregressive approach predicts each token given all previous tokens. The bidirectional approach masks some tokens and predicts them given the visible context (both left and right).
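The autoregressive factorization can be made concrete with a toy example. The sketch below multiplies hand-picked conditional probabilities (invented purely for illustration) for the sequence "the cat sat":

```python
import math

# Hypothetical conditional probabilities for the sequence "the cat sat":
# P(the), P(cat | the), P(sat | the, cat) -- values are made up for illustration.
conditionals = {
    ("the",): 0.05,
    ("the", "cat"): 0.02,
    ("the", "cat", "sat"): 0.10,
}

# Chain rule: P(the, cat, sat) = P(the) * P(cat | the) * P(sat | the, cat)
p_sequence = 1.0
for prefix, p in conditionals.items():
    p_sequence *= p

print(p_sequence)  # approximately 1e-4

# The CLM loss is the negative log of this product, i.e. a sum of -log terms:
nll = -sum(math.log(p) for p in conditionals.values())
print(nll)
```

Minimizing the negative log-likelihood of each conditional is therefore the same as maximizing the probability of the whole sequence.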

Quick Check

Why can't a bidirectional model like BERT directly generate text?


Masked Language Modeling (MLM)

Masked Language Modeling, introduced by BERT (Bidirectional Encoder Representations from Transformers), revolutionized NLP by enabling bidirectional pretraining.

The Core Idea

In MLM, we randomly mask some percentage of input tokens and train the model to predict the original tokens. This is analogous to the "cloze task" in psycholinguistics, where humans fill in blanks in sentences.

$$\mathcal{L}_{\text{MLM}} = -\mathbb{E}_{\mathbf{x} \sim \mathcal{D}} \left[ \sum_{i \in \mathcal{M}} \log P(x_i \mid \mathbf{x}_{\backslash \mathcal{M}}; \theta) \right]$$

Where:

  • $\mathcal{M}$ is the set of masked positions
  • $\mathbf{x}_{\backslash \mathcal{M}}$ is the input sequence with the tokens at masked positions corrupted (replaced or hidden)
  • $\theta$ are the model parameters

BERT's Masking Strategy

A crucial detail in BERT's success is its masking strategy. Rather than always replacing selected tokens with [MASK], BERT uses:

| Action | Probability | Purpose |
|--------|-------------|---------|
| Replace with [MASK] | 80% | Standard masking; model learns to predict from context |
| Replace with random token | 10% | Prevents the model from relying on the [MASK] marker |
| Keep original token | 10% | Trains the model to preserve/use unmasked representations |

Why Not 100% [MASK]?

If we always used [MASK], the model would learn that it only needs to predict when it sees [MASK]. During fine-tuning and inference, [MASK] tokens don't appear, creating a train-test mismatch. The 80-10-10 strategy ensures the model maintains good representations for all tokens.
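As a quick sanity check on what the 80-10-10 split means for the whole sequence, the arithmetic below works out each action's share of all tokens, assuming BERT's 15% selection rate:

```python
mask_prob = 0.15  # fraction of tokens selected for prediction (BERT default)

share_mask = mask_prob * 0.80    # replaced with [MASK]
share_random = mask_prob * 0.10  # replaced with a random token
share_keep = mask_prob * 0.10    # left unchanged (but still predicted)

print(f"[MASK]: {share_mask:.1%} of all tokens")   # 12.0%
print(f"random: {share_random:.1%} of all tokens") # 1.5%
print(f"kept:   {share_keep:.1%} of all tokens")   # 1.5%
# The loss is still computed at all 15% of selected positions.
```

So only about 12% of tokens in a batch actually become [MASK], while the model must also handle a small fraction of corrupted and unchanged-but-predicted tokens.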

Mathematical Details

For a sequence of $T$ tokens, BERT typically masks $0.15 \times T$ tokens. The prediction for each masked position uses the Transformer's output representation:

$$P(x_i \mid \mathbf{x}_{\backslash \mathcal{M}}) = \text{softmax}(\mathbf{W} \mathbf{h}_i + \mathbf{b})$$

Where $\mathbf{h}_i$ is the Transformer's hidden state at position $i$, and $\mathbf{W} \in \mathbb{R}^{|V| \times d}$ projects to the vocabulary.


Example: Masked Language Modeling

Consider the sentence: "The cat sat on the warm sunny windowsill watching the birds fly by"

How MLM works on this input:

  1. Randomly select ~15% of input tokens for masking
  2. The model predicts the original tokens using bidirectional context
  3. Loss is computed only on the masked positions
  4. The model learns rich contextual representations

Of the selected tokens, 80% are replaced with [MASK], 10% with a random word, and 10% are kept unchanged. This prevents the model from only learning to predict at [MASK] tokens and forces it to maintain good representations for all tokens.


Causal Language Modeling (CLM)

Causal Language Modeling, used by GPT-style models, takes an autoregressive approach where each token is predicted based only on previous tokens.

The Autoregressive Formulation

CLM factorizes the joint probability of a sequence as:

$$P(\mathbf{x}) = \prod_{t=1}^{T} P(x_t \mid x_1, x_2, \ldots, x_{t-1})$$

The training objective is to minimize the negative log-likelihood:

$$\mathcal{L}_{\text{CLM}} = -\sum_{t=1}^{T} \log P(x_t \mid \mathbf{x}_{<t}; \theta)$$

Causal Attention Mask

To enforce the left-to-right constraint in Transformers, CLM uses a causal attention mask. This mask ensures that position $t$ can only attend to positions $1, 2, \ldots, t$:

$$\text{Mask}_{ij} = \begin{cases} 0 & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}$$

After softmax, $e^{-\infty} = 0$, so future positions contribute zero attention weight.
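To see the mask's effect numerically, here is a minimal framework-free sketch: a plain softmax over masked attention scores, where future positions carry $-\infty$ (the attention scores are invented for illustration):

```python
import math

def softmax(scores):
    """Plain softmax; math.exp(-inf) evaluates to exactly 0.0."""
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Attention scores for token 2 over a 4-token sequence.
# Positions 3 and 4 are in the future, so the causal mask adds -inf.
raw_scores = [1.2, 0.7, 0.9, 0.3]
mask = [0.0, 0.0, float("-inf"), float("-inf")]
masked = [s + m for s, m in zip(raw_scores, mask)]

weights = softmax(masked)
print(weights)  # future positions receive exactly zero attention weight
```

The remaining weight is renormalized over the visible positions, so each token's representation depends only on itself and its predecessors.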

CLM vs MLM Trade-offs

| Aspect | CLM (GPT) | MLM (BERT) |
|--------|-----------|------------|
| Context | Left-only (unidirectional) | Both directions (bidirectional) |
| Training efficiency | All tokens used (100%) | Only masked tokens (~15%) |
| Natural fit | Text generation | Text understanding |
| Zero-shot capability | Strong (prompting) | Weak (needs fine-tuning) |
| Pretraining speed | Faster per token | Slower per token |

Quick Check

In a sequence of 100 tokens, how many tokens contribute to the training loss in CLM vs MLM?


Example: Causal Language Modeling Step by Step

GPT-style models predict tokens one at a time, using only the previous context. Given the prefix "The", the model produces a distribution over possible next tokens ("quick", "cat", "dog", ...), samples or selects one, appends it, and repeats. The causal attention mask ensures each token can only "see" itself and all previous tokens.

CLM (GPT) characteristics:
  • Unidirectional: left-to-right only
  • Predicts the next token given the previous context
  • Uses a causal attention mask
  • Natural for text generation
  • Training uses all positions (efficient)

vs. MLM (BERT):
  • Bidirectional: sees the full context
  • Predicts masked tokens
  • Uses a full attention mask
  • Better for understanding tasks
  • Only ~15% of tokens contribute to the loss

The CLM training objective, $\mathcal{L} = -\sum_t \log P(x_t \mid x_1, x_2, \ldots, x_{t-1}; \theta)$, maximizes the probability of each token given all previous tokens. This is equivalent to minimizing the negative log-likelihood (cross-entropy loss) summed over all positions.


Sentence-Level Pretext Tasks

Beyond token-level predictions, some pretext tasks operate at the sentence or document level.

Next Sentence Prediction (NSP)

BERT was originally trained with NSP as an auxiliary task. Given two sentences A and B:

  • IsNext (50%): B is the actual next sentence after A in the corpus
  • NotNext (50%): B is a random sentence from a different document

The model learns to classify whether B follows A using the [CLS] token representation:

$$P(\text{IsNext} \mid \text{[CLS]}) = \sigma(\mathbf{w}^\top \mathbf{h}_{\text{[CLS]}} + b)$$

NSP's Limited Effectiveness

Later research (RoBERTa, ALBERT) found that NSP provides minimal benefit and may even hurt performance. The task might be too easy—the model can often distinguish sentences by topic alone, without learning true discourse coherence.

Sentence Order Prediction (SOP)

ALBERT replaced NSP with Sentence Order Prediction. Instead of sampling random sentences:

  • Positive: Two consecutive sentences from the same document in correct order
  • Negative: Same two sentences but swapped (B before A)

This forces the model to understand inter-sentence coherence rather than just topic detection.

| Task | Positive Sample | Negative Sample | Difficulty |
|------|-----------------|-----------------|------------|
| NSP | Consecutive sentences | Random sentence from a different doc | Easy (topic detection) |
| SOP | Consecutive sentences | Same sentences, swapped order | Hard (coherence detection) |
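A minimal sketch of SOP pair construction, assuming a document is represented as a list of sentences; `make_sop_pairs` is a hypothetical helper name, not from any library:

```python
import random

def make_sop_pairs(document, swap_prob=0.5, rng=None):
    """Build (sent_a, sent_b, label) SOP examples from consecutive sentences.

    label 1 = correct order (positive), label 0 = swapped order (negative).
    """
    rng = rng or random.Random()
    pairs = []
    for a, b in zip(document, document[1:]):
        if rng.random() < swap_prob:
            pairs.append((b, a, 0))  # negative: same sentences, swapped
        else:
            pairs.append((a, b, 1))  # positive: original order
    return pairs

doc = ["The cat jumped onto the sill.",
       "It curled up and purred.",
       "Soon it was asleep."]
pairs = make_sop_pairs(doc, rng=random.Random(1))  # seeded for reproducibility
for a, b, label in pairs:
    print(label, "|", a, "->", b)
```

Note that, unlike NSP, both sentences in every pair come from the same document, so topic cues alone cannot solve the task.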

Example: Next Sentence Prediction

Sentence A: "The cat jumped onto the warm windowsill."
Sentence B: "It curled up and began to purr contentedly."

Here B naturally follows A, so this pair would be labeled IsNext. During training, 50% of pairs are genuine consecutive sentences from the corpus (IsNext) and 50% pair A with a randomly sampled sentence from another document (NotNext).

BERT's input format for NSP:

[CLS] Sentence A tokens [SEP] Sentence B tokens [SEP]

The [CLS] token representation feeds a binary classifier that predicts IsNext/NotNext.

As noted above, NSP's effectiveness is debated: RoBERTa removes NSP entirely and uses only MLM, while ALBERT replaces it with Sentence Order Prediction, since NSP may test topic detection rather than coherence.

Permutation Language Modeling

XLNet introduced Permutation Language Modeling (PLM) to combine the benefits of bidirectional context (like BERT) with autoregressive training (like GPT).

The Key Innovation

Instead of always predicting left-to-right, XLNet samples a random permutation $\mathbf{z}$ of the indices $[1, 2, \ldots, T]$ and factorizes the likelihood according to that order:

$$\mathcal{L}_{\text{PLM}} = -\mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_T} \left[ \sum_{t=1}^{T} \log P(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}}; \theta) \right]$$

For example, with sequence "The cat sat" and permutation [2, 3, 1]:

  1. Predict "cat" (position 2) with no context
  2. Predict "sat" (position 3) given "cat"
  3. Predict "The" (position 1) given "cat" and "sat"

Bidirectional Without [MASK]

Through permutation, each token can potentially attend to tokens both before and after it in the original sequence. When predicting "The" last, it sees both "cat" (originally after) and gets bidirectional information. Unlike BERT, there's no train-test mismatch from [MASK] tokens.
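The worked example above can be traced in a few lines of Python. This sketch prints, for each step of the permutation [2, 3, 1], which token is predicted and which tokens are visible as context:

```python
tokens = ["The", "cat", "sat"]  # original positions 1, 2, 3
permutation = [2, 3, 1]         # sampled factorization order

# At step t the model predicts the token at position z_t, conditioned on
# the tokens at the already-visited positions z_1 .. z_{t-1}.
steps = []
for t, pos in enumerate(permutation):
    context = [tokens[p - 1] for p in permutation[:t]]
    steps.append((tokens[pos - 1], context))
    print(f"predict {tokens[pos - 1]!r} given context {context}")
```

The final step predicts "The" with both "cat" and "sat" visible, which is how a token ends up conditioned on words that follow it in the original order.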

Two-Stream Self-Attention

XLNet uses a special "two-stream" attention mechanism to handle the position-content dependency:

  • Content stream: Standard self-attention, encodes both content and position
  • Query stream: Only sees position (not content) of the token to predict

This prevents the trivial solution where the model just copies the token it's trying to predict.

When to Use PLM vs MLM vs CLM

  • MLM (BERT): Best for understanding tasks, simplest to implement, widely supported
  • CLM (GPT): Best for generation tasks, zero-shot capabilities, most scalable
  • PLM (XLNet): Can outperform BERT on understanding tasks, but more complex

Pretext Task Comparison

Masked Language Modeling (MLM)
  • Models: BERT, RoBERTa, ALBERT
  • Objective: Predict randomly masked tokens
  • Input: The [MASK] sat on the [MASK]
  • Context: Bidirectional
  • Advantages: bidirectional context understanding; rich contextual embeddings; good for understanding tasks
  • Disadvantages: only ~15% of tokens provide training signal; [MASK] token never seen during fine-tuning; not natural for generation
  • Best for: text classification, NER, question answering, sentence embeddings

Causal Language Modeling (CLM)
  • Models: GPT, GPT-2, GPT-3
  • Objective: Predict the next token left-to-right
  • Input: The cat sat → on
  • Context: Unidirectional (left-to-right)
  • Advantages: natural for text generation; all tokens used for training; simple, efficient training; zero-shot capabilities
  • Disadvantages: sees only left context; cannot access future context; may need more parameters for understanding
  • Best for: text generation, code completion, few-shot learning, dialog systems
Key Insight: The Pretext-Downstream Gap

The choice of pretext task fundamentally shapes what a model learns. MLM's bidirectional nature makes it excellent for understanding tasks, while CLM's autoregressive nature makes it natural for generation. Modern approaches often combine multiple objectives or use contrastive learning to bridge this gap.


PyTorch Implementation

Let's implement the core pretext tasks in PyTorch to understand the training dynamics.

Text Pretext Tasks Implementation
text_pretext_tasks.py

A walkthrough of the implementation below:

  • MLM head architecture: the MLM head transforms encoder hidden states into vocabulary logits. It includes a dense layer, layer normalization, and a decoder projection to vocabulary size.
  • BERT's 80-10-10 strategy: `create_mlm_masks` implements BERT's masking strategy: 80% [MASK], 10% random, 10% unchanged. The labels tensor uses -100 for non-masked positions (ignored by CrossEntropyLoss).
  • Probabilistic masking: Bernoulli sampling randomly selects ~15% of positions to mask. Special tokens (like [CLS], [SEP]) are never masked.
  • MLM loss computation: CrossEntropyLoss ignores positions with label -100, so the loss is computed only on masked tokens.
  • Causal attention mask: `create_causal_mask` builds an upper-triangular matrix of -inf values. After softmax in attention, these become zeros, preventing tokens from attending to future positions.
  • CLM loss with shifting: for autoregressive training, we shift so that position i predicts the token at position i+1. This is the standard GPT-style training objective.
  • NSP head: NSP is a binary classification task using the [CLS] token's representation. The head is a simple linear layer mapping to 2 classes (IsNext/NotNext).
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class MLMHead(nn.Module):
    """Masked Language Modeling prediction head."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.decoder = nn.Linear(hidden_size, vocab_size)
        self.gelu = nn.GELU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Transform hidden states
        x = self.dense(hidden_states)
        x = self.gelu(x)
        x = self.layer_norm(x)
        # Project to vocabulary
        logits = self.decoder(x)
        return logits


def create_mlm_masks(
    input_ids: torch.Tensor,
    vocab_size: int,
    mask_token_id: int,
    mask_prob: float = 0.15,
    special_tokens_mask: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Create MLM masks following BERT's 80-10-10 strategy.

    Returns:
        masked_input_ids: Input with masked tokens
        labels: Original tokens at masked positions, -100 elsewhere
        mask_positions: Boolean tensor indicating masked positions
    """
    labels = input_ids.clone()
    batch_size, seq_len = input_ids.shape
    device = input_ids.device

    # Create probability matrix for masking
    prob_matrix = torch.full((batch_size, seq_len), mask_prob, device=device)

    # Don't mask special tokens
    if special_tokens_mask is not None:
        prob_matrix.masked_fill_(special_tokens_mask.bool(), value=0.0)

    # Select positions to mask
    masked_indices = torch.bernoulli(prob_matrix).bool()

    # Set labels to -100 for non-masked positions (ignored in loss)
    labels[~masked_indices] = -100

    # 80% of masked: replace with [MASK]
    indices_replaced = (
        torch.bernoulli(torch.full((batch_size, seq_len), 0.8, device=device)).bool()
        & masked_indices
    )
    input_ids[indices_replaced] = mask_token_id

    # 10% of masked: replace with random token
    # (half of the remaining 20%, hence the 0.5 probability)
    indices_random = (
        torch.bernoulli(torch.full((batch_size, seq_len), 0.5, device=device)).bool()
        & masked_indices
        & ~indices_replaced
    )
    random_words = torch.randint(vocab_size, (batch_size, seq_len), device=device)
    input_ids[indices_random] = random_words[indices_random]

    # 10% of masked: keep original (already in input_ids)

    return input_ids, labels, masked_indices


def compute_mlm_loss(
    logits: torch.Tensor,
    labels: torch.Tensor
) -> torch.Tensor:
    """Compute MLM loss only on masked positions."""
    # logits: (batch, seq_len, vocab_size)
    # labels: (batch, seq_len) with -100 for non-masked

    loss_fct = nn.CrossEntropyLoss()  # Ignores -100 by default
    loss = loss_fct(
        logits.view(-1, logits.size(-1)),
        labels.view(-1)
    )
    return loss


class CLMHead(nn.Module):
    """Causal Language Modeling prediction head."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.decoder(hidden_states)


def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    """Create causal attention mask for CLM."""
    # Upper triangular matrix of -inf (positions that shouldn't attend)
    mask = torch.triu(
        torch.full((seq_len, seq_len), float('-inf'), device=device),
        diagonal=1
    )
    return mask


def compute_clm_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    shift: bool = True
) -> torch.Tensor:
    """
    Compute CLM loss.

    If shift=True, logits at position i predict the token at position i+1.
    """
    if shift:
        # Shift so logits[i] predicts labels[i+1]
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
    else:
        shift_logits = logits
        shift_labels = labels

    loss_fct = nn.CrossEntropyLoss()
    loss = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1)
    )
    return loss


class NSPHead(nn.Module):
    """Next Sentence Prediction head."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, 2)  # IsNext / NotNext

    def forward(self, cls_hidden_state: torch.Tensor) -> torch.Tensor:
        return self.classifier(cls_hidden_state)


# Example training loop for MLM
def train_mlm_step(
    model: nn.Module,
    mlm_head: MLMHead,
    optimizer: torch.optim.Optimizer,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    vocab_size: int,
    mask_token_id: int
) -> float:
    """Single training step for MLM."""
    model.train()
    mlm_head.train()

    # Create MLM masks (clone so the caller's input_ids stay untouched)
    masked_ids, labels, _ = create_mlm_masks(
        input_ids.clone(),
        vocab_size,
        mask_token_id
    )

    # Forward pass through encoder
    hidden_states = model(masked_ids, attention_mask)

    # MLM prediction
    logits = mlm_head(hidden_states)

    # Compute loss
    loss = compute_mlm_loss(logits, labels)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()
```
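To make the shift inside `compute_clm_loss` concrete, this framework-free sketch pairs up each input position with its prediction target (the example sentence is arbitrary):

```python
tokens = ["The", "cat", "sat", "on", "the", "mat"]

# Shifting: the logits at position i are scored against the token at i+1,
# so the last token has no target and the first token is never predicted.
inputs = tokens[:-1]
targets = tokens[1:]

for inp, tgt in zip(inputs, targets):
    print(f"{inp!r:8} -> predict {tgt!r}")
```

A sequence of T tokens therefore yields T-1 (input, target) pairs, which is why every position except the last contributes to the CLM loss.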

Summary

Text pretext tasks are the foundation of modern self-supervised learning in NLP. Each task offers distinct advantages:

Key Concepts

| Pretext Task | Key Idea | Best For |
|--------------|----------|----------|
| MLM (BERT) | Predict masked tokens using bidirectional context | Understanding tasks, classification, QA |
| CLM (GPT) | Predict the next token autoregressively | Text generation, zero-shot, dialog |
| NSP | Predict whether sentence B follows A | Sentence-pair tasks (debated utility) |
| SOP (ALBERT) | Predict the correct sentence order | Discourse coherence; harder than NSP |
| PLM (XLNet) | Predict tokens in a random factorization order | Understanding + generation |

Key Equations

  1. MLM Objective: $\mathcal{L}_{\text{MLM}} = -\sum_{i \in \mathcal{M}} \log P(x_i \mid \mathbf{x}_{\backslash \mathcal{M}})$
  2. CLM Objective: $\mathcal{L}_{\text{CLM}} = -\sum_{t=1}^{T} \log P(x_t \mid \mathbf{x}_{<t})$
  3. PLM Objective: $\mathcal{L}_{\text{PLM}} = -\mathbb{E}_{\mathbf{z}} \sum_t \log P(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}})$
  4. BERT Masking: 80% [MASK] + 10% random + 10% unchanged

Looking Forward

In the next section, we'll explore pretext tasks for sequential data beyond text, including time series and audio. We'll see how the principles from text—predicting masked elements, autoregressive modeling, and contrastive objectives—transfer to other domains.


Knowledge Check

Test your understanding of text pretext tasks:

Text Pretext Tasks Quiz

Question 1: What is the primary advantage of Masked Language Modeling (MLM) over Causal Language Modeling (CLM)?

  A. It trains faster
  B. It uses bidirectional context for each prediction
  C. It requires less data
  D. It generates better text

Exercises

Conceptual Questions

  1. Explain why BERT's 80-10-10 masking strategy is preferred over always using [MASK]. What problems would arise with 100% [MASK]?
  2. XLNet claims to get "the best of both worlds" from BERT and GPT. What specific limitations of each does it address, and what trade-offs does it introduce?
  3. Why might RoBERTa's removal of NSP improve performance? What does this suggest about designing pretext tasks?
  4. Compare the sample efficiency of MLM vs CLM. Which uses more gradient information per training example, and why?

Mathematical Exercises

  1. MLM Gradient Flow: Derive the gradient $\frac{\partial \mathcal{L}_{\text{MLM}}}{\partial \mathbf{h}_i}$ for a masked position $i$. Show that only masked positions receive gradient signal.
  2. Permutation Counting: For a sequence of length 5, how many distinct factorization orders does XLNet consider? How does this scale with sequence length?
  3. Expected Mask Count: In a batch of 32 sequences, each 512 tokens long with 15% mask probability, what is the expected number of masked tokens per batch? What is the variance?

Coding Exercises

  1. MLM Accuracy Tracking: Extend the training loop to compute and log MLM accuracy (percentage of correctly predicted masked tokens) during training.
  2. Dynamic Masking: Implement dynamic masking where different masks are applied to the same sequence across epochs (used in RoBERTa).
  3. Whole Word Masking: Implement whole-word masking where if any subword token of a word is masked, all subwords of that word are masked together.
  4. SOP Implementation: Implement the Sentence Order Prediction task. Create positive/negative pairs from a corpus and train a classifier.

Solution Hints

  • Exercise 1: Compare argmax of logits to labels where labels != -100
  • Exercise 2: Move masking to the dataloader and call it fresh each epoch
  • Exercise 3: Use the tokenizer's word_ids() to identify subword boundaries
  • Exercise 4: Sample consecutive sentence pairs; for negatives, swap order with 50% probability

Challenge Project

Build a Pretext Task Ablation Study: Train small Transformer models (4-6 layers) with different pretext tasks on a corpus like WikiText-103. Compare:

  • MLM only vs MLM + NSP vs MLM + SOP
  • Different mask rates (10%, 15%, 20%, 30%)
  • Dynamic vs static masking
  • Whole-word vs subword masking

Evaluate on downstream tasks (sentiment classification, NLI) and analyze which pretext configurations work best for which tasks.


Now that you understand the major pretext tasks for text, you're ready to explore how these principles extend to sequential data beyond natural language in the next section.