Chapter 9

CLIP and Contrastive Learning

Text-to-Image Foundations

Learning Objectives

By the end of this section, you will be able to:

  1. Explain the principles of contrastive learning and why it creates useful representations
  2. Describe CLIP's dual-encoder architecture and training procedure
  3. Derive the InfoNCE contrastive loss function
  4. Understand why CLIP embeddings are well-suited for text-to-image diffusion
  5. Implement CLIP text encoding for diffusion models

Contrastive Learning Fundamentals

Contrastive learning is a self-supervised technique that learns representations by comparing similar (positive) and dissimilar (negative) pairs. The key insight:

Core Principle: Learn embeddings where similar items are close together and dissimilar items are far apart in the embedding space. No explicit labels needed - just pairs of related data.

Why Contrastive Learning Works

Instead of predicting specific outputs (like class labels), contrastive learning teaches the model to:

  1. Identify invariances: Learn what makes two items "the same" despite surface differences
  2. Capture semantics: Group items by meaning, not pixel-level similarity
  3. Scale efficiently: Use the batch itself as negative examples
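The "batch as negatives" idea can be made concrete: for a batch of N pairs, the N×N similarity matrix contains N positives on the diagonal and N·(N−1) negatives off it, so negatives come for free with larger batches. A minimal sketch:

```python
import torch

N = 4  # batch size
# Positive pairs sit on the diagonal: (image_i, text_i)
positive_mask = torch.eye(N, dtype=torch.bool)

num_positives = int(positive_mask.sum())     # N
num_negatives = int((~positive_mask).sum())  # N * (N - 1)
print(num_positives, num_negatives)  # 4 12
```

Doubling the batch size doubles the positives but roughly quadruples the negatives, which is why contrastive pre-training benefits from very large batches.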

The Image-Text Alignment Problem

For text-to-image generation, we need embeddings where:

  • Similar descriptions map to similar embedding regions
  • Text embeddings are compatible with visual concepts
  • The space captures compositional meaning

CLIP solves this by training on 400 million image-text pairs from the internet, learning a joint embedding space for both modalities.


CLIP Architecture

CLIP (Contrastive Language-Image Pre-training) consists of two encoders that map images and text to a shared embedding space:

Dual Encoder Design

| Component | Architecture | Output |
| --- | --- | --- |
| Image Encoder | Vision Transformer (ViT) or ResNet | [B, D] pooled or [B, N, D] patches |
| Text Encoder | Transformer (GPT-style) | [B, L, D] sequence + [B, D] pooled |
| Projection | Linear layers | Shared D-dimensional space |

Text Encoder Details

The text encoder is a 12-layer Transformer with:

  • Tokenizer: BPE (Byte Pair Encoding) with a 49,152-token vocabulary
  • Max length: 77 tokens (including special tokens)
  • Architecture: Decoder-only Transformer (like GPT)
  • Output: 512-dim (ViT-B) or 768-dim (ViT-L)
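The fixed 77-token context means every prompt is padded or truncated to the same length. The tokenizer handles this automatically, but the logic amounts to the following sketch (the BOS/EOS ids 49406/49407 are CLIP's special tokens as exposed by the HuggingFace tokenizer; the prompt token ids are made up for illustration):

```python
import torch

MAX_LENGTH = 77
BOS, EOS, PAD = 49406, 49407, 49407  # CLIP special tokens; padding reuses EOS

def pad_or_truncate(token_ids: list[int]) -> torch.Tensor:
    """Wrap token ids with BOS/EOS, then pad or truncate to MAX_LENGTH."""
    ids = [BOS] + token_ids[: MAX_LENGTH - 2] + [EOS]
    ids = ids + [PAD] * (MAX_LENGTH - len(ids))
    return torch.tensor(ids)

short = pad_or_truncate([320, 1125, 539, 320, 2368])  # short prompt gets padded
long = pad_or_truncate(list(range(500)))              # over-long prompt gets truncated
print(short.shape, long.shape)  # torch.Size([77]) torch.Size([77])
```

Truncation is silent, so any prompt content beyond 75 real tokens is simply dropped.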

Two Types of Text Embeddings

CLIP provides both:
  • Sequence embeddings [B, L, D]: Per-token representations, used for cross-attention
  • Pooled embedding [B, D]: The end-of-text token embedding, used for global conditioning
Diffusion models use both: sequence for cross-attention, pooled for AdaGN/AdaLN.

Image Encoder Details

While not directly used in text-to-image generation, understanding the image encoder helps explain why CLIP embeddings work:

  • ViT-B/32: Patches of 32x32, 12 layers, 768-dim
  • ViT-L/14: Patches of 14x14, 24 layers, 1024-dim (most common for SD)
  • CLS token: Aggregates global image information

The Contrastive Training Objective

CLIP uses InfoNCE loss (also called NT-Xent) to align image and text embeddings:

Setup

Given a batch of $N$ image-text pairs:

  • $\mathbf{I}_i$ = image encoder output for image $i$
  • $\mathbf{T}_i$ = text encoder output for text $i$
  • Positive pairs: $(\mathbf{I}_i, \mathbf{T}_i)$ for same $i$
  • Negative pairs: $(\mathbf{I}_i, \mathbf{T}_j)$ for $i \neq j$

Similarity Matrix

First, compute cosine similarities between all pairs:

$$S_{ij} = \frac{\mathbf{I}_i \cdot \mathbf{T}_j}{\|\mathbf{I}_i\| \|\mathbf{T}_j\|} \cdot \exp(\tau)$$

where $\tau$ is a learned temperature parameter. This creates an $N \times N$ similarity matrix.

InfoNCE Loss

The loss maximizes similarity of positive pairs relative to negatives:

$$\mathcal{L}_{\text{i2t}} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(S_{ii})}{\sum_{j=1}^{N} \exp(S_{ij})}$$

$$\mathcal{L}_{\text{t2i}} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(S_{ii})}{\sum_{j=1}^{N} \exp(S_{ji})}$$

$$\mathcal{L}_{\text{CLIP}} = \frac{1}{2}(\mathcal{L}_{\text{i2t}} + \mathcal{L}_{\text{t2i}})$$

Why Symmetric Loss?

The symmetric loss ensures both modalities are equally aligned:
  • Image-to-text: Each image should find its text among all texts
  • Text-to-image: Each text should find its image among all images
This creates a truly joint embedding space.

Implementation

The implementation below computes the loss in a few steps:

  • Similarity matrix: the dot product of L2-normalized embeddings gives cosine similarity; dividing by the temperature controls the sharpness of the softmax distribution.
  • Diagonal labels: the i-th image should match the i-th text, so the labels are simply [0, 1, ..., B-1].
  • Image-to-text loss: row-wise softmax - for each image, maximize the probability of its paired text among all texts in the batch.
  • Text-to-image loss: column-wise softmax - for each text, maximize the probability of its paired image among all images in the batch.
  • Symmetric loss: averaging both directions makes the embedding space symmetric - both modalities pull toward each other.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(
    image_embeddings: torch.Tensor,  # [B, D]
    text_embeddings: torch.Tensor,   # [B, D]
    temperature: float = 0.07,
) -> torch.Tensor:
    """
    CLIP-style contrastive loss (InfoNCE).

    Positive pairs: (image_i, text_i) for same index i
    Negative pairs: All other combinations

    Args:
        image_embeddings: L2-normalized image embeddings [batch, dim]
        text_embeddings: L2-normalized text embeddings [batch, dim]
        temperature: Softmax temperature (learnable in CLIP)

    Returns:
        Average of image-to-text and text-to-image losses
    """
    batch_size = image_embeddings.shape[0]

    # Compute similarity matrix
    # [B, D] @ [D, B] -> [B, B]
    logits = image_embeddings @ text_embeddings.T / temperature

    # Labels: diagonal elements are positive pairs
    labels = torch.arange(batch_size, device=logits.device)

    # Image-to-text: for each image, predict correct text
    loss_i2t = F.cross_entropy(logits, labels)

    # Text-to-image: for each text, predict correct image
    loss_t2i = F.cross_entropy(logits.T, labels)

    # Symmetric loss
    return (loss_i2t + loss_t2i) / 2
```
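As a quick sanity check (re-stating the loss computation inline so the snippet is self-contained): perfectly aligned embeddings should give a much lower loss than mismatched pairs, since the diagonal then carries the largest logits.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(img, txt, temperature=0.07):
    # Same InfoNCE loss as above, condensed
    logits = img @ txt.T / temperature
    labels = torch.arange(img.shape[0])
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2

torch.manual_seed(0)
img = F.normalize(torch.randn(8, 64), dim=-1)  # 8 random unit embeddings

aligned_loss = contrastive_loss(img, img)                   # text == image embedding
shifted_loss = contrastive_loss(img, img.roll(1, dims=0))   # every pair mismatched

print(aligned_loss < shifted_loss)  # tensor(True)
```

Shifting the "text" embeddings by one position puts each pair's large similarity off the diagonal, so cross-entropy against the diagonal labels blows up.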

Using CLIP for Diffusion

CLIP embeddings are ideal for text-to-image diffusion for several reasons:

1. Image-Aligned Text Representations

Unlike pure language models (BERT, GPT), CLIP text embeddings are trained to correlate with visual concepts:

  • "red car" embeds near images of red cars
  • "sunset over mountains" embeds near such scenes
  • Visual attributes (color, texture, composition) are well-represented

2. Open Vocabulary

CLIP handles arbitrary text, not just predefined classes:

  • Novel combinations: "astronaut riding a horse"
  • Specific styles: "in the style of Van Gogh"
  • Detailed descriptions: Multiple attributes composed

3. Semantic Interpolation

The embedding space supports smooth interpolation:

$$\mathbf{c}_{\text{interp}} = (1-\alpha) \cdot \mathbf{c}_{\text{cat}} + \alpha \cdot \mathbf{c}_{\text{dog}}$$

This produces semantically meaningful intermediate concepts, enabling prompt interpolation and blending.
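A minimal sketch of this blend, using random unit vectors as stand-ins for the encoded cat/dog embeddings (in practice they come from the CLIP text encoder). Re-normalizing after the linear blend keeps the result on the unit sphere that normalized CLIP embeddings occupy:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
c_cat = F.normalize(torch.randn(768), dim=-1)  # stand-in for encoded "a cat"
c_dog = F.normalize(torch.randn(768), dim=-1)  # stand-in for encoded "a dog"

def interpolate(c_a: torch.Tensor, c_b: torch.Tensor, alpha: float) -> torch.Tensor:
    """Linear blend of two embeddings, re-normalized to unit length."""
    return F.normalize((1 - alpha) * c_a + alpha * c_b, dim=-1)

c_mid = interpolate(c_cat, c_dog, 0.5)
# The midpoint is equally similar to both endpoints
print(torch.dot(c_mid, c_cat).item(), torch.dot(c_mid, c_dog).item())
```

Sweeping alpha from 0 to 1 and generating at each step yields a smooth visual morph between the two prompts.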

Why Freeze CLIP During Diffusion Training?

| Approach | Pros | Cons |
| --- | --- | --- |
| Frozen CLIP | Preserves alignment, faster training | Can't adapt to new domains |
| Fine-tuned CLIP | Domain adaptation possible | Risk of forgetting, expensive |
| LoRA on CLIP | Efficient adaptation | Limited capacity |

Standard practice is to freeze CLIP because:

  • Its representations are already excellent for general images
  • The diffusion model learns to use these fixed representations
  • Prevents catastrophic forgetting of image-text alignment
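Freezing is a one-liner per parameter. A quick sketch with a stand-in module (the same pattern applies to a loaded `CLIPTextModel`) shows how to verify that nothing will receive gradients:

```python
import torch.nn as nn

# Stand-in for the CLIP text encoder
text_model = nn.Sequential(nn.Linear(768, 768), nn.GELU(), nn.Linear(768, 768))

for param in text_model.parameters():
    param.requires_grad = False
text_model.eval()  # also disables dropout, etc., for deterministic features

# Verify: no trainable parameters remain
trainable = sum(p.numel() for p in text_model.parameters() if p.requires_grad)
print(trainable)  # 0
```

Checking the trainable-parameter count before launching a long training run is a cheap safeguard against accidentally fine-tuning the encoder.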

PyTorch Implementation

Here's how to use CLIP for text conditioning in diffusion:

The wrapper below extracts both sequence embeddings (for cross-attention) and pooled embeddings (for global conditioning):

  • Configuration: defaults to CLIP ViT-L/14, the text encoder used in Stable Diffusion; max length 77 is CLIP's standard context size.
  • Loading: the tokenizer and text model come from HuggingFace, pre-trained on 400M image-text pairs.
  • Freezing: critical - the text encoder is frozen during diffusion training and used as a fixed feature extractor.
  • Forward pass: takes a list of strings, pads to 77 tokens (truncating longer prompts), runs the frozen model under no_grad, and returns last_hidden_state (per-token embeddings) and pooler_output (the end-of-text token embedding).
  • Unconditional embeddings: a helper for CFG encodes empty prompts - the "null condition" for classifier-free guidance.

```python
import torch
import torch.nn as nn
from transformers import CLIPTextModel, CLIPTokenizer

class CLIPTextEncoder(nn.Module):
    """
    Wrapper for CLIP text encoder used in diffusion models.
    Extracts both sequence and pooled embeddings.
    """

    def __init__(
        self,
        model_name: str = "openai/clip-vit-large-patch14",
        max_length: int = 77,
        device: str = "cuda",
    ):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.text_model = CLIPTextModel.from_pretrained(model_name).to(device)
        self.max_length = max_length
        self.device = device

        # Freeze the text encoder
        for param in self.text_model.parameters():
            param.requires_grad = False

    def forward(self, text: list[str]) -> dict[str, torch.Tensor]:
        """
        Encode text prompts to embeddings.

        Args:
            text: List of text prompts

        Returns:
            Dictionary with:
            - 'last_hidden_state': [B, L, 768] sequence embeddings
            - 'pooler_output': [B, 768] pooled embeddings (EOS token)
        """
        # Tokenize: pad to max_length, truncate longer prompts
        tokens = self.tokenizer(
            text,
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )
        tokens = {k: v.to(self.device) for k, v in tokens.items()}

        # Get embeddings from the frozen model
        with torch.no_grad():
            outputs = self.text_model(**tokens)

        return {
            "last_hidden_state": outputs.last_hidden_state,
            "pooler_output": outputs.pooler_output,
        }

    def get_unconditional_embeddings(self, batch_size: int) -> dict[str, torch.Tensor]:
        """Get embeddings for empty/null prompts (for CFG)."""
        empty_prompts = [""] * batch_size
        return self.forward(empty_prompts)
```

Usage in Diffusion Training

```python
# Initialize
clip_encoder = CLIPTextEncoder(
    model_name="openai/clip-vit-large-patch14",
    device="cuda",
)

# During training
prompts = ["a photo of a cat", "a painting of mountains"]
text_embeddings = clip_encoder(prompts)

# For cross-attention: use sequence embeddings
context = text_embeddings["last_hidden_state"]  # [B, 77, 768]

# For global conditioning: use pooled embeddings
pooled = text_embeddings["pooler_output"]  # [B, 768]

# CFG: also get unconditional embeddings
uncond_embeddings = clip_encoder.get_unconditional_embeddings(batch_size=2)
uncond_context = uncond_embeddings["last_hidden_state"]

# Concatenate for batched CFG inference
context_cfg = torch.cat([uncond_context, context], dim=0)  # [2B, 77, 768]
```

Memory Efficiency

Since CLIP is frozen, you can:
  • Pre-compute embeddings for your dataset offline
  • Use torch.no_grad() during forward pass
  • Move to CPU after encoding if GPU memory is tight
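A sketch of offline pre-computation, using a stand-in `encode` function in place of the CLIP wrapper (the cache format and filename here are illustrative, not a standard):

```python
import os
import tempfile
import torch

def precompute_embeddings(prompts, encode, cache_path):
    """Encode every prompt once and cache the results to disk."""
    embeddings = {p: encode(p) for p in prompts}
    torch.save(embeddings, cache_path)
    return embeddings

# Stand-in encoder: in practice this calls the frozen CLIP text model
encode = lambda prompt: torch.randn(77, 768)

cache_path = os.path.join(tempfile.gettempdir(), "text_embeddings.pt")
prompts = ["a photo of a cat", "a painting of mountains"]
precompute_embeddings(prompts, encode, cache_path)

# Later runs load the cache instead of re-running CLIP every epoch
cached = torch.load(cache_path)
print(cached["a photo of a cat"].shape)  # torch.Size([77, 768])
```

For large datasets a memory-mapped store or per-sample files scales better than one dictionary, but the principle is the same: since the encoder is frozen, its outputs never change.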

OpenCLIP and Variants

Several CLIP variants are used in modern diffusion models:

Model Comparison

| Model | Dim | Training Data | Used In |
| --- | --- | --- | --- |
| CLIP ViT-L/14 | 768 | WIT-400M (OpenAI) | SD 1.x, SDXL (first encoder) |
| OpenCLIP ViT-H/14 | 1024 | LAION-2B | SD 2.x |
| OpenCLIP ViT-bigG | 1280 | LAION-2B | SDXL (second encoder) |

OpenCLIP

OpenCLIP is an open-source reproduction of CLIP trained on public datasets (LAION). Benefits:

  • Larger models: ViT-H, ViT-G, ViT-bigG
  • More training data: LAION-2B vs WIT-400M
  • Open weights: Fully reproducible

SDXL: Dual Text Encoders

Stable Diffusion XL uses two text encoders:

  • OpenCLIP ViT-bigG: Image-aligned, 1280-dim
  • CLIP ViT-L: Standard encoder, 768-dim

The embeddings are concatenated:

$$\mathbf{c} = [\mathbf{c}_{\text{OpenCLIP}}; \mathbf{c}_{\text{CLIP}}] \in \mathbb{R}^{L \times 2048}$$

Why Two Encoders? Different encoders capture different aspects. The larger OpenCLIP model has better visual grounding while the standard CLIP provides consistent baseline quality. Together they improve prompt following.
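Shape-wise, the combination is a per-token concatenation along the feature dimension (random tensors stand in for the actual encoder outputs here):

```python
import torch

B, L = 2, 77
c_openclip = torch.randn(B, L, 1280)  # OpenCLIP ViT-bigG sequence embeddings
c_clip = torch.randn(B, L, 768)       # CLIP ViT-L sequence embeddings

# Concatenate each token's two embeddings: 1280 + 768 = 2048
context = torch.cat([c_openclip, c_clip], dim=-1)
print(context.shape)  # torch.Size([2, 77, 2048])
```

The cross-attention layers in the UNet are then sized for a 2048-dim context instead of 768.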

Key Takeaways

  1. Contrastive learning creates aligned embedding spaces by pulling positive pairs together and pushing negatives apart
  2. CLIP's dual encoder maps images and text to a shared space using InfoNCE loss
  3. CLIP provides two outputs: sequence embeddings [B, L, D] for cross-attention, pooled [B, D] for global conditioning
  4. Image-aligned text embeddings are crucial - CLIP text representations correlate with visual concepts
  5. Freeze CLIP during diffusion training to preserve its learned alignment
  6. Modern systems use multiple encoders (OpenCLIP + CLIP) for better coverage
Looking Ahead: We've covered text encoding and cross-attention. The final piece of the text-to-image puzzle is efficiency: in the next section, we'll explore latent diffusion - how VAEs compress images to make high-resolution generation tractable.