Chapter 26

Inference in Graphical Models

Probabilistic Graphical Models

Learning Objectives

By the end of this section, you will be able to:

📚 Core Knowledge

  • Understand the computational challenge of probabilistic inference
  • Explain variable elimination and its complexity
  • Describe belief propagation and message passing
  • Know when inference is tractable vs. intractable

🔧 Practical Skills

  • Implement variable elimination from scratch
  • Apply the sum-product algorithm to trees
  • Construct junction trees for exact inference
  • Choose appropriate inference algorithms

🧠 Deep Learning Connections

  • Graph Neural Networks — Message passing directly inspired by belief propagation
  • Variational Autoencoders — Use approximate inference via ELBO optimization
  • Structured Prediction — CRF layers in neural networks use inference algorithms
  • Neural Message Passing — Learned messages replace hand-designed factors
Where You'll Apply This: Medical diagnosis systems, protein structure prediction, image segmentation, natural language processing, robotics localization, and any application requiring reasoning under uncertainty with structured dependencies.

The Big Picture

Graphical models compactly represent joint probability distributions. But representation is only half the story—we need inference algorithms to answer questions about these distributions. Given a graphical model, inference lets us compute:

📊

Marginal Probabilities

$P(X_i)$: probability of a single variable, summing over all others

🎯

Conditional Probabilities

$P(X \mid E = e)$: posterior given evidence

MAP Assignment

$\arg\max_X P(X)$: most likely configuration

The Inference Challenge

Consider a joint distribution over $n$ binary variables. The naive approach to computing $P(X_1)$ requires summing over $2^{n-1}$ configurations, which is exponential in the number of variables. Graphical models exploit conditional independence to make this tractable, but even then inference can be NP-hard in the worst case.
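To make the blow-up concrete, here is a minimal sketch that computes $P(X_1)$ by brute-force enumeration on a small chain of binary variables. The helper name `brute_force_marginal` and the random potentials are assumptions made for illustration only.

```python
import itertools
import numpy as np

def brute_force_marginal(pairwise, n):
    """Marginal of X_1 on a chain X_1 - X_2 - ... - X_n of binary variables.

    pairwise[i] is a 2x2 table of potentials between consecutive variables
    (a hypothetical toy model). Returns (marginal, number_of_terms_summed).
    """
    marginal = np.zeros(2)
    terms = 0
    for x in itertools.product([0, 1], repeat=n):
        weight = 1.0
        for i in range(n - 1):
            weight *= pairwise[i][x[i], x[i + 1]]
        marginal[x[0]] += weight
        terms += 1
    return marginal / marginal.sum(), terms

rng = np.random.default_rng(0)
n = 10
pairwise = [rng.uniform(0.5, 1.5, size=(2, 2)) for _ in range(n - 1)]
marginal, terms = brute_force_marginal(pairwise, n)
# terms equals 2**n: 2**(n-1) configurations for each of the two values of X_1
print(marginal, terms)
```

Adding one more variable doubles the loop count, which is why the rest of this section looks for ways to avoid enumerating the joint.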

The Fundamental Trade-off

Tractable Cases
  • Tree-structured graphs: $O(n)$
  • Low treewidth: $O(n \cdot d^{w+1})$
  • Special structures (chains, polytrees)
Intractable Cases
  • Dense graphs, large grids
  • High treewidth
  • Requires approximate inference

Historical Development

📜
Judea Pearl (1982-1988)

Developed belief propagation and the message-passing paradigm for inference. His book "Probabilistic Reasoning in Intelligent Systems" (1988) established the foundations of graphical models. Won the Turing Award in 2011.

🔗
Junction Trees (late 1980s-1990s)

Lauritzen & Spiegelhalter developed the junction tree algorithm, which generalizes belief propagation to work on any graphical model. This became the standard for exact inference in practical systems.

🔄
Loopy BP & Variational Methods (2000s)

Researchers discovered that belief propagation often works well even on graphs with loops. Connections to variational methods and statistical physics led to new approximate inference algorithms like expectation propagation.

🧠
Neural Message Passing (2015-present)

Graph Neural Networks apply learnable message passing, directly inspired by belief propagation. Models like Graph Attention Networks (GAT) and Message Passing Neural Networks (MPNN) now power molecular property prediction, protein folding, and social network analysis.


Variable Elimination

Variable elimination is the most fundamental exact inference algorithm. It systematically removes variables from the model by summing them out, eventually leaving only the query variable(s).

The Algorithm

The key insight is that summations can be pushed inside products by the distributive law: any factor that does not mention the summed variable can be moved outside the sum. For example, with a third factor $\phi_3(C,D)$ that does not involve $A$:

$$\sum_A \phi_1(A,B)\,\phi_2(A,C)\,\phi_3(C,D) = \phi_3(C,D)\left[\sum_A \phi_1(A,B)\,\phi_2(A,C)\right]$$

We only need to group the factors containing $A$ and sum them out together, creating a new factor over $\{B, C\}$.
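The identity can be checked numerically. The sketch below uses random tables and adds a factor `phi3` over (C, D) that does not mention A; the factor shapes and names are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
phi1 = rng.random((2, 3))   # phi1[a, b]
phi2 = rng.random((2, 4))   # phi2[a, c]
phi3 = rng.random((4, 5))   # phi3[c, d], does not mention A

# Naive: build the full table over (A, B, C, D), then sum out A.
full = np.einsum("ab,ac,cd->abcd", phi1, phi2, phi3)
naive = full.sum(axis=0)                      # shape (3, 4, 5)

# Distributive law: sum out A using only the factors that mention it.
tau = np.einsum("ab,ac->bc", phi1, phi2)      # new factor tau[b, c]
pushed = tau[:, :, None] * phi3[None, :, :]   # shape (3, 4, 5)

print(np.allclose(naive, pushed))
```

The naive route materializes a table over all four variables, while the second route never builds anything larger than the factors that actually involve A.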

Variable Elimination Steps

  1. Choose elimination order. Order variables to eliminate (excluding query). Order affects complexity!
  2. For each variable to eliminate:
    • a. Find all factors containing this variable
    • b. Multiply these factors together
    • c. Sum out (marginalize) the variable
    • d. Add the new factor to the factor set
  3. Multiply remaining factors. The result is the marginal probability $P(Q)$.

Interactive: Variable Elimination

Watch variable elimination compute $P(D)$ step by step. Notice how factors are combined and variables are summed out in order, eventually leaving only the query variable.

[Interactive demo: variable elimination on the chain A - B - C - D. Query: P(D). Elimination order: A → B → C. Active factors at the start: φ₁(A), φ₂(A,B), φ₃(B,C), φ₄(C,D).]
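On the chain example with factors φ₁(A), φ₂(A,B), φ₃(B,C), φ₄(C,D) and elimination order A → B → C, each elimination step is a vector-matrix product. The sketch below uses made-up conditional probability tables (the numbers are assumptions, not values from the demo) and checks the result against the full joint.

```python
import numpy as np

# Hypothetical CPTs for the chain A -> B -> C -> D (all variables binary).
phi1 = np.array([0.6, 0.4])                 # P(A)
phi2 = np.array([[0.7, 0.3], [0.2, 0.8]])   # P(B | A), rows indexed by A
phi3 = np.array([[0.9, 0.1], [0.4, 0.6]])   # P(C | B), rows indexed by B
phi4 = np.array([[0.5, 0.5], [0.1, 0.9]])   # P(D | C), rows indexed by C

tau_b = phi1 @ phi2     # eliminate A: tau(B) = sum_A phi1(A) * phi2(A, B)
tau_c = tau_b @ phi3    # eliminate B: tau(C) = sum_B tau(B) * phi3(B, C)
p_d = tau_c @ phi4      # eliminate C: P(D)   = sum_C tau(C) * phi4(C, D)

# Brute-force check: build the full joint and sum out A, B, C.
joint = np.einsum("a,ab,bc,cd->abcd", phi1, phi2, phi3, phi4)
p_d_brute = joint.sum(axis=(0, 1, 2))

print(p_d, p_d_brute)   # both are approximately [0.36, 0.64]
```

No intermediate factor ever has more than one variable in its scope after summing, which is what treewidth 1 buys on a chain.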

Elimination Ordering

The elimination order dramatically affects computational cost. A poor order can create intermediate factors that are exponentially large.

Treewidth and Complexity

The complexity of variable elimination is $O(n \cdot d^{w+1})$, where $w$ is the treewidth of the graph (induced by the elimination order), $d$ is the maximum domain size, and $n$ is the number of variables.

Treewidth of common graphs

• Chain: treewidth 1
• Tree: treewidth 1
• Grid n×n: treewidth n (grows with the side length, so large grids are expensive)

Heuristics for ordering

• Min-neighbors: eliminate the variable with the fewest neighbors first
• Min-fill: eliminate the variable whose removal adds the fewest new (fill) edges
• Finding the optimal ordering is NP-hard!

Key Insight: The treewidth is the minimum over all elimination orderings of the maximum clique size minus 1. For trees, treewidth = 1, making inference linear. For dense graphs, treewidth can approach n-1, making inference intractable.
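The effect of ordering can be seen by simulating elimination on the graph alone, without any probability tables. This sketch (the helper `induced_width` and the star-graph example are assumptions for illustration) tracks how many neighbors each variable has when it is eliminated, which bounds the scope of the intermediate factor.

```python
def induced_width(edges, order):
    """Largest number of neighbors a variable has at the moment it is eliminated."""
    neighbors = {}
    for u, v in edges:
        neighbors.setdefault(u, set()).add(v)
        neighbors.setdefault(v, set()).add(u)
    width = 0
    for var in order:
        nbrs = neighbors.pop(var)
        width = max(width, len(nbrs))
        # Eliminating var creates a factor over its neighbors: connect them all.
        for u in nbrs:
            neighbors[u].discard(var)
            neighbors[u] |= nbrs - {u}
    return width

# Star graph: hub H connected to five leaves.
leaves = ["L1", "L2", "L3", "L4", "L5"]
star = [("H", leaf) for leaf in leaves]

leaves_first = induced_width(star, leaves + ["H"])
hub_first = induced_width(star, ["H"] + leaves)
print(leaves_first, hub_first)   # 1 and 5
```

Eliminating the hub first couples all five leaves into one factor, while eliminating leaves first never creates a factor over more than two variables.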

Belief Propagation

Belief propagation (BP), also called the sum-product algorithm, computes all marginals simultaneously through local message passing. Unlike variable elimination, which processes one query at a time, BP on a tree can answer every marginal query after a single round of message passing (one sweep toward the root and one back).

The Sum-Product Algorithm

In BP, nodes (variables) and factors exchange messages. Each message summarizes what the sender "believes" about the recipient based on its local information and messages it received from others.

Message Update Rules

Variable to Factor Message
$$\mu_{x \to f}(x) = \prod_{g \in \text{ne}(x) \setminus f} \mu_{g \to x}(x)$$

Product of all incoming factor messages except the recipient.

Factor to Variable Message
$$\mu_{f \to x}(x) = \sum_{\mathbf{y}} f(x, \mathbf{y}) \prod_{y \in \text{ne}(f) \setminus x} \mu_{y \to f}(y)$$

Marginalize the factor times incoming variable messages.

Final Belief (Marginal)
$$b(x) \propto \prod_{f \in \text{ne}(x)} \mu_{f \to x}(x)$$

Product of all incoming messages, normalized to be a valid probability.
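The three rules can be traced by hand on a tiny factor graph. The sketch below assumes a chain x1 - f12 - x2 - f23 - x3 with one unary factor per variable (all tables are random and purely illustrative) and computes the belief at x2.

```python
import numpy as np

rng = np.random.default_rng(1)
g1, g2, g3 = rng.random(2), rng.random(3), rng.random(2)   # unary factors g_i(x_i)
f12 = rng.random((2, 3))                                    # f12[x1, x2]
f23 = rng.random((3, 2))                                    # f23[x2, x3]

# Variable-to-factor: product of the other incoming factor messages.
# x1's only other neighbor is its unary factor, whose message is g1 itself.
mu_x1_to_f12 = g1
mu_x3_to_f23 = g3

# Factor-to-variable: multiply the factor by incoming variable messages
# and sum out every variable except the recipient.
mu_f12_to_x2 = mu_x1_to_f12 @ f12     # sum over x1
mu_f23_to_x2 = f23 @ mu_x3_to_f23     # sum over x3

# Belief: product of all incoming factor messages (including g2), normalized.
belief_x2 = g2 * mu_f12_to_x2 * mu_f23_to_x2
belief_x2 = belief_x2 / belief_x2.sum()

# Brute-force check against the full joint.
joint = np.einsum("a,b,c,ab,bc->abc", g1, g2, g3, f12, f23)
exact_x2 = joint.sum(axis=(0, 2)) / joint.sum()
print(np.allclose(belief_x2, exact_x2))
```

Because the graph is a tree, the belief matches the exact marginal.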

Interactive: Message Passing

Watch belief propagation in action. Observe how messages flow through the graph, carrying information from observed variables to update beliefs throughout the network. Notice how beliefs converge after a few iterations.

[Interactive demo: Belief Propagation (Sum-Product), message passing on a factor graph over A, B, C, D with D observed. At iteration 0 of 4: P(A=1) = 0.50, P(B=1) = 0.50, P(C=1) = 0.50, P(D=1) = 1.00. Legend: hidden variable, observed variable, message.]

Loopy Belief Propagation

On tree-structured graphs, belief propagation is exact and converges in two passes (leaves to root, root to leaves). But what about graphs with loops?

✓ When Loopy BP Works Well

  • Single loop or sparse graph structure
  • Weak correlations between variables
  • Strong evidence at some nodes
  • Attractive potentials (prefer agreement)

✗ When Loopy BP Fails

  • Many tight loops (dense graphs)
  • Strong frustrated interactions
  • Messages may oscillate or diverge
  • Beliefs can be overconfident
Double Counting: The main issue with loopy BP is that messages can travel around loops and count the same evidence multiple times. This leads to beliefs that are often more confident than the true marginals. Damping (averaging old and new messages) can help convergence but doesn't fix the approximation error.
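A small experiment makes the approximation visible. The following is a minimal sketch of damped loopy BP on a three-node cycle with attractive potentials, compared against exact marginals from enumeration. The helpers `loopy_bp` and `exact_marginals`, and all potential values, are assumptions for illustration.

```python
import itertools
import numpy as np

def loopy_bp(unary, pairwise, damping=0.5, iters=200):
    """Damped sum-product on a pairwise binary model.

    unary: dict node -> length-2 array; pairwise: dict (i, j) -> 2x2 array [x_i, x_j].
    """
    psi = dict(pairwise)
    psi.update({(j, i): p.T for (i, j), p in pairwise.items()})  # both directions
    msgs = {edge: np.full(2, 0.5) for edge in psi}
    for _ in range(iters):
        new = {}
        for (i, j) in psi:
            prod = unary[i].copy()
            for (k, t) in psi:
                if t == i and k != j:          # messages into i, except from j
                    prod = prod * msgs[(k, i)]
            m = prod @ psi[(i, j)]             # sum over the sender's states
            m = m / m.sum()
            # Damping: average the old and the new message.
            new[(i, j)] = damping * msgs[(i, j)] + (1 - damping) * m
        msgs = new
    beliefs = {}
    for i in unary:
        b = unary[i].copy()
        for (k, t) in psi:
            if t == i:
                b = b * msgs[(k, i)]
        beliefs[i] = b / b.sum()
    return beliefs

def exact_marginals(unary, pairwise):
    """Exact marginals by enumerating every joint configuration."""
    nodes = sorted(unary)
    marg = {i: np.zeros(2) for i in nodes}
    for x in itertools.product([0, 1], repeat=len(nodes)):
        a = dict(zip(nodes, x))
        w = np.prod([unary[i][a[i]] for i in nodes])
        w *= np.prod([p[a[i], a[j]] for (i, j), p in pairwise.items()])
        for i in nodes:
            marg[i][a[i]] += w
    return {i: m / m.sum() for i, m in marg.items()}

unary = {0: np.array([0.7, 0.3]), 1: np.array([0.5, 0.5]), 2: np.array([0.5, 0.5])}
attract = np.array([[2.0, 1.0], [1.0, 2.0]])
triangle = {(0, 1): attract, (1, 2): attract, (0, 2): attract}
approx = loopy_bp(unary, triangle)
exact = exact_marginals(unary, triangle)
for i in unary:
    print(i, approx[i], exact[i])
```

Comparing the two columns for each node shows how far the loopy beliefs drift from the exact marginals; on a tree the same routine would agree with enumeration.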

Junction Tree Algorithm

The junction tree algorithm (also called clique tree algorithm) provides exact inference on any graphical model. It works by transforming the original graph into a tree of cliques, then running belief propagation on this tree.

Building the Junction Tree

The construction involves several steps, each with a specific purpose:

[Interactive walkthrough: Junction Tree Algorithm, exact inference in arbitrary graphical models, on an example graph over A, B, C, D. Step 1 of 6 is shown.]

Step 1: Moralize the Graph

Connect all parents of each node (marry the parents). Convert directed edges to undirected.
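Moralization is simple enough to sketch directly. The example below assumes the DAG is given as a dictionary from each node to its list of parents; the function name `moralize` and the example graph are illustrative, not taken from the walkthrough.

```python
from itertools import combinations

def moralize(parents):
    """Moral graph of a DAG given as {node: [parents]}; returns a set of undirected edges."""
    edges = set()
    for child, pars in parents.items():
        for p in pars:
            edges.add(frozenset((p, child)))   # drop the edge direction
        for p, q in combinations(pars, 2):
            edges.add(frozenset((p, q)))       # marry co-parents
    return edges

# Hypothetical v-structure A -> C <- B, plus C -> D.
dag = {"A": [], "B": [], "C": ["A", "B"], "D": ["C"]}
moral = moralize(dag)
print(sorted(tuple(sorted(e)) for e in moral))
# [('A', 'B'), ('A', 'C'), ('B', 'C'), ('C', 'D')]
```

The new edge A - B appears because A and B share the child C, so the factor P(C | A, B) needs all three variables in one clique.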

The Running Intersection Property

A valid junction tree must satisfy the running intersection property: if a variable appears in two cliques, it must appear in every clique on the unique path between them.

If X ∈ C₁ and X ∈ C₃, then X ∈ C₂ for all C₂ on path(C₁, C₃)

This property ensures that information about each variable is propagated correctly throughout the tree, guaranteeing exact inference.
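The property is easy to test mechanically: in a tree, it holds for a variable exactly when the cliques containing that variable form a connected subtree. The checker below is a sketch under that observation; the function name and the two example trees are assumptions for illustration.

```python
def satisfies_running_intersection(cliques, tree_edges):
    """cliques: dict name -> set of variables; tree_edges: list of (name, name) forming a tree."""
    variables = set().union(*cliques.values())
    for var in variables:
        holders = {c for c, members in cliques.items() if var in members}
        # Search from one holder, moving only through cliques that contain var.
        start = next(iter(holders))
        seen, stack = {start}, [start]
        while stack:
            c = stack.pop()
            for u, v in tree_edges:
                for a, b in ((u, v), (v, u)):
                    if a == c and b in holders and b not in seen:
                        seen.add(b)
                        stack.append(b)
        if seen != holders:
            return False   # some holder is cut off by a clique lacking var
    return True

good = {"C1": {"A", "B"}, "C2": {"B", "C"}, "C3": {"C", "D"}}
bad = {"C1": {"A", "B"}, "C2": {"C", "D"}, "C3": {"B", "C"}}
path = [("C1", "C2"), ("C2", "C3")]

print(satisfies_running_intersection(good, path))  # True
print(satisfies_running_intersection(bad, path))   # False: B is in C1 and C3 but not C2
```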

Why Triangulation? Only chordal graphs (no cycles of length > 3 without a chord) have perfect tree decompositions. Triangulation adds edges to make the graph chordal, at the cost of potentially increasing treewidth.

Approximate Inference

When exact inference is intractable (high treewidth, large state spaces), we turn to approximate methods. The two main families are:

Sampling Methods

  • Gibbs Sampling: Iteratively sample each variable conditioned on current values of others
  • Importance Sampling: Weight samples by likelihood ratio
  • Particle Filtering: Sequential Monte Carlo for temporal models

Asymptotically exact but may require many samples
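As an illustration of the sampling family, here is a minimal Gibbs sampler for a short binary chain, with its estimates compared to exact marginals from enumeration. The model, the helper names, and the sweep counts are assumptions chosen to keep the example small.

```python
import itertools
import numpy as np

def gibbs_marginals(unary, coupling, sweeps=10000, burn_in=1000, seed=0):
    """Estimate P(X_i = 1) on a binary chain by Gibbs sampling.

    Unnormalized model: p(x) ∝ prod_i unary[i][x_i] * prod_i coupling[x_i, x_{i+1}].
    """
    rng = np.random.default_rng(seed)
    n = len(unary)
    x = rng.integers(0, 2, size=n)
    counts = np.zeros(n)
    for sweep in range(sweeps + burn_in):
        for i in range(n):
            # Conditional of x_i given its neighbors (its Markov blanket).
            w = unary[i].copy()
            if i > 0:
                w *= coupling[x[i - 1], :]
            if i < n - 1:
                w *= coupling[:, x[i + 1]]
            x[i] = rng.random() < w[1] / w.sum()
        if sweep >= burn_in:
            counts += x
    return counts / sweeps

def exact_p1(unary, coupling):
    """Exact P(X_i = 1) by enumerating all configurations."""
    n = len(unary)
    p1, z = np.zeros(n), 0.0
    for x in itertools.product([0, 1], repeat=n):
        w = np.prod([unary[i][x[i]] for i in range(n)])
        w *= np.prod([coupling[x[i], x[i + 1]] for i in range(n - 1)])
        z += w
        p1 += w * np.array(x)
    return p1 / z

unary = [np.array([0.3, 0.7]), np.array([0.5, 0.5]), np.array([0.6, 0.4])]
coupling = np.array([[1.5, 1.0], [1.0, 1.5]])
estimate = gibbs_marginals(unary, coupling)
exact = exact_p1(unary, coupling)
print(estimate, exact)
```

The estimates approach the exact values as the number of sweeps grows; with strongly coupled variables the chain mixes more slowly and more sweeps are needed.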

Variational Methods

  • Mean Field: Assume variables are independent, optimize per-variable distributions
  • Loopy BP: BP on graphs with cycles (approximate)
  • Expectation Propagation: Moment matching with tractable family

Deterministic optimization, fast but biased
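For the variational family, a naive mean-field update on the same kind of binary chain looks like this. It is a sketch of coordinate ascent with a fully factorized q; the function name and model layout are assumptions for illustration.

```python
import numpy as np

def mean_field(unary, coupling, iters=50):
    """Naive mean field for a binary chain: q(x) = prod_i q_i(x_i).

    Unnormalized model: p(x) ∝ prod_i unary[i][x_i] * prod_i coupling[x_i, x_{i+1}].
    """
    n = len(unary)
    q = [np.full(2, 0.5) for _ in range(n)]
    log_c = np.log(coupling)
    for _ in range(iters):
        for i in range(n):
            # log q_i(x_i) = log unary_i(x_i) + expected log coupling under neighbors' q
            logits = np.log(unary[i])
            if i > 0:
                logits = logits + q[i - 1] @ log_c    # sum over x_{i-1}
            if i < n - 1:
                logits = logits + log_c @ q[i + 1]    # sum over x_{i+1}
            q[i] = np.exp(logits - logits.max())      # subtract max for stability
            q[i] = q[i] / q[i].sum()
    return q

unary = [np.array([0.3, 0.7]), np.array([0.5, 0.5]), np.array([0.6, 0.4])]
coupling = np.array([[1.5, 1.0], [1.0, 1.5]])
q = mean_field(unary, coupling)
print([qi.round(3) for qi in q])
```

Each update is cheap and deterministic, but the factorized form cannot represent correlations between neighbors, which is the source of the bias noted above.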

Algorithm Comparison

Choosing the right inference algorithm depends on your graph structure, accuracy requirements, and computational budget. Use this interactive comparison to understand the trade-offs.

Inference Algorithm Comparison

Choose the right algorithm for your problem

Variable Elimination (exact)

  • Complexity: O(n · d^(w+1))
  • Graph types: Bayesian networks, Markov random fields
  • Best use case: one-time exact queries on moderate-sized graphs

Advantages

  • Exact inference
  • Works on any graph
  • Simple to implement

Limitations

  • Exponential in treewidth
  • Single query only
  • No structure reuse

Algorithm | Type | Complexity
Variable Elimination | exact | O(n · d^(w+1))
Belief Propagation | exact (on trees) | O(n · d²)
Junction Tree | exact | O(n · d^(w+1))
Gibbs Sampling | approximate | O(samples · n)
Variational Inference | approximate | O(iterations · n)

w = treewidth, d = domain size, n = nodes

Applications in Deep Learning

Inference algorithms from graphical models have profoundly influenced modern deep learning architectures:

🔗 Graph Neural Networks

GNNs directly implement message passing: each node aggregates messages from neighbors and updates its representation. The key difference is that message functions are learned rather than derived from probabilistic factors. Message Passing Neural Networks (MPNNs), Graph Attention Networks (GATs), and Graph Convolutional Networks (GCNs) all follow this paradigm.
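The aggregate-then-update pattern can be written in a few lines. The sketch below is a generic mean-aggregation layer in NumPy, not the exact formulation of MPNN, GAT, or GCN; the function name and the weight matrices `w_self` and `w_nbr` are assumptions for illustration.

```python
import numpy as np

def message_passing_layer(adj, h, w_self, w_nbr):
    """One round of message passing with mean aggregation and a ReLU update."""
    deg = adj.sum(axis=1, keepdims=True)
    messages = adj @ (h @ w_nbr)                  # sum of transformed neighbor features
    messages = messages / np.maximum(deg, 1)      # mean over neighbors (safe for isolated nodes)
    return np.maximum(0.0, h @ w_self + messages)

# Path graph 0 - 1 - 2 with one scalar feature per node.
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
h = np.array([[1.0], [2.0], [3.0]])
identity = np.array([[1.0]])
out = message_passing_layer(adj, h, identity, identity)
print(out.ravel())   # [3. 4. 5.]
```

In a trained network the two weight matrices are learned, which is the sense in which learned message functions replace hand-specified factors.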

🎯 Structured Prediction with CRFs

BiLSTM-CRF models for named entity recognition and POS tagging use inference algorithms at their core. The CRF layer computes the partition function and marginals using forward-backward (a special case of belief propagation on chains). This allows gradients to flow through the structured prediction layer.
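The partition function computation can be sketched as the forward pass of forward-backward in log space. The code below assumes a linear-chain CRF parameterized by per-position emission scores and a label-to-label transition matrix (names and shapes are illustrative) and checks the result by enumerating every label sequence.

```python
import itertools
import numpy as np

def crf_log_partition(emissions, transitions):
    """log Z of a linear-chain CRF via the forward algorithm.

    emissions: (T, K) per-position label scores; transitions: (K, K) scores [prev, next].
    """
    alpha = emissions[0]                                      # log-domain forward messages
    for t in range(1, len(emissions)):
        scores = alpha[:, None] + transitions + emissions[t][None, :]
        m = scores.max(axis=0)
        alpha = m + np.log(np.exp(scores - m).sum(axis=0))   # log-sum-exp over previous label
    m = alpha.max()
    return m + np.log(np.exp(alpha - m).sum())

def brute_force_log_partition(emissions, transitions):
    """Same quantity by summing over all K**T label sequences."""
    T, K = emissions.shape
    total = 0.0
    for labels in itertools.product(range(K), repeat=T):
        score = sum(emissions[t, y] for t, y in enumerate(labels))
        score += sum(transitions[a, b] for a, b in zip(labels, labels[1:]))
        total += np.exp(score)
    return np.log(total)

rng = np.random.default_rng(0)
emissions = rng.normal(size=(5, 3))
transitions = rng.normal(size=(3, 3))
log_z = crf_log_partition(emissions, transitions)
log_z_brute = brute_force_log_partition(emissions, transitions)
print(log_z, log_z_brute)
```

The forward recursion costs O(T · K²) instead of O(K^T), and every operation in it is differentiable, which is what lets a CRF layer sit on top of a neural encoder.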

🧬 Protein Structure Prediction

AlphaFold2 uses a form of message passing called "Evoformer" to reason about amino acid relationships. The attention mechanism can be seen as a soft, learnable version of belief propagation where "messages" are attention-weighted features between residue pairs.

🎲 Variational Autoencoders

VAEs perform approximate inference using the ELBO objective—the same framework used in variational inference for graphical models. The encoder network amortizes inference, learning to predict approximate posterior parameters directly from input data rather than running inference from scratch each time.
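The bound itself can be checked on a toy model with a discrete latent variable. In this sketch the joint values p(x, z) for one fixed observation x are made up for illustration; the point is that the ELBO never exceeds log p(x) and matches it when q equals the true posterior.

```python
import numpy as np

def elbo(q, joint):
    """ELBO = E_q[log p(x, z) - log q(z)] for a discrete latent z and a fixed x."""
    q = np.asarray(q, dtype=float)
    support = q > 0                      # terms with q(z) = 0 contribute nothing
    return float(np.sum(q[support] * (np.log(joint[support]) - np.log(q[support]))))

joint = np.array([0.10, 0.25, 0.05])     # hypothetical p(x, z) for z = 0, 1, 2
log_evidence = np.log(joint.sum())       # log p(x)
posterior = joint / joint.sum()          # p(z | x)

elbo_uniform = elbo([1 / 3, 1 / 3, 1 / 3], joint)
elbo_posterior = elbo(posterior, joint)
print(elbo_uniform, elbo_posterior, log_evidence)
```

The gap between the ELBO and log p(x) is the KL divergence from q to the true posterior, so a better encoder output means a tighter bound.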

Classical Algorithm | Deep Learning Analog | Key Innovation
Belief Propagation | Graph Neural Networks | Learned message functions
Forward-Backward | CRF Loss Layer | Differentiable structured prediction
Variational Inference | VAE + ELBO | Amortized inference with encoders
Junction Tree | Hierarchical Attention | Multi-scale message passing
Gibbs Sampling | Denoising Score Matching | Learned transition kernels

Python Implementation

Let's implement the core inference algorithms from scratch. This implementation covers variable elimination and the essence of belief propagation.

Inference Algorithms Implementation
🐍 graphical_model_inference.py

  • Import NumPy: NumPy provides efficient array operations for factor manipulation and marginalization.
  • Factor class: A factor is a function over a subset of variables. We represent it as a NumPy array with one axis per variable in its scope.
  • Constructor: Initialize with scope (variables), cardinalities (domain sizes), and values (the actual probabilities as a NumPy array).
  • Factor product: Multiply two factors by combining their scopes and multiplying corresponding values. This is the core operation in inference. Example: φ(A,B) × φ(B,C) = φ(A,B,C)
  • Marginalization: Sum out a variable by summing over all its values. This reduces the factor scope by one variable. Example: Σ_A φ(A,B) = φ(B)
  • Variable elimination: Main algorithm. Iterate through the elimination order, multiply all factors containing the variable, sum it out, and add the new factor to the set.
  • Find relevant factors: For each variable to eliminate, find all factors that mention it. These will be combined into one larger factor.
  • Multiply factors: Combine all relevant factors using repeated factor product. The result contains the variable plus any other variables in the factors.
  • Marginalize variable: Sum out the eliminated variable. The resulting factor is smaller and no longer depends on that variable.
  • Return result: After eliminating all non-query variables, multiply remaining factors and normalize to get the marginal probability.
  • Belief propagation: Message-passing algorithm. Initialize messages to uniform, then iteratively update until convergence.
  • Message update: The message from i to j is computed by multiplying the local factor with all incoming messages (except from j), multiplying by the edge potential, then marginalizing over i.
  • Compute beliefs: After convergence, the belief at each node is the product of its local factor and all incoming messages. Normalize to get a probability.

import numpy as np


class Factor:
    def __init__(self, scope, cardinalities, values):
        """
        scope: list of variable names (e.g., ['A', 'B'])
        cardinalities: dict of variable -> domain size
        values: numpy array of (possibly unnormalized) probabilities
        """
        self.scope = list(scope)
        self.cardinalities = cardinalities
        self.values = np.asarray(values, dtype=float).reshape(
            [cardinalities[v] for v in self.scope])

    def _expand(self, new_scope):
        """Return values with one axis per variable in new_scope (size 1 if absent)."""
        # Reorder our axes to follow new_scope, then insert size-1 axes for
        # variables we do not mention so NumPy broadcasting lines everything up.
        present = [v for v in new_scope if v in self.scope]
        values = np.transpose(self.values, [self.scope.index(v) for v in present])
        shape = [self.cardinalities[v] if v in self.scope else 1 for v in new_scope]
        return values.reshape(shape)

    def multiply(self, other):
        """Compute factor product φ1 × φ2."""
        # Deterministic scope order: our variables first, then the new ones.
        new_scope = self.scope + [v for v in other.scope if v not in self.scope]
        new_cards = {**self.cardinalities, **other.cardinalities}
        values = self._expand(new_scope) * other._expand(new_scope)
        return Factor(new_scope, new_cards, values)

    def marginalize(self, variable):
        """Sum out a variable: Σ_X φ(X, Y) = φ(Y)."""
        if variable not in self.scope:
            return self

        axis = self.scope.index(variable)
        new_values = np.sum(self.values, axis=axis)
        new_scope = [v for v in self.scope if v != variable]
        return Factor(new_scope, self.cardinalities, new_values)


def variable_elimination(factors, query, elimination_order):
    """
    Compute P(query) by eliminating variables in order.

    factors: list of Factor objects
    query: variable to compute marginal for
    elimination_order: list of variables to eliminate
    """
    working_factors = list(factors)

    for variable in elimination_order:
        if variable == query:
            continue

        # Find factors containing this variable
        relevant = [f for f in working_factors if variable in f.scope]
        remaining = [f for f in working_factors if variable not in f.scope]
        if not relevant:
            continue  # variable appears in no factor, nothing to sum out

        # Multiply all relevant factors
        product = relevant[0]
        for f in relevant[1:]:
            product = product.multiply(f)

        # Marginalize out the variable
        new_factor = product.marginalize(variable)
        working_factors = remaining + [new_factor]

    # Multiply remaining factors and normalize
    result = working_factors[0]
    for f in working_factors[1:]:
        result = result.multiply(f)
    return result.values / result.values.sum()


def belief_propagation(graph, unary, pairwise, max_iter=100, tol=1e-6):
    """
    Sum-product belief propagation on a pairwise model (exact on trees).

    graph: object exposing .nodes, .edges and .neighbors(i), e.g. a networkx.Graph
    unary: dict of node -> 1-D array of local potentials
    pairwise: dict of (i, j) -> 2-D array indexed [x_i, x_j], one entry per edge
    Returns: dict of variable -> marginal probability
    """
    def edge_potential(i, j):
        # Return the potential indexed [x_i, x_j] whichever way the edge was stored.
        return pairwise[(i, j)] if (i, j) in pairwise else pairwise[(j, i)].T

    # One message per direction of every edge, initialized to uniform
    messages = {}
    for i, j in graph.edges:
        messages[(i, j)] = np.ones(len(unary[j])) / len(unary[j])
        messages[(j, i)] = np.ones(len(unary[i])) / len(unary[i])

    for iteration in range(max_iter):
        old_messages = dict(messages)

        for (i, j) in old_messages:
            # Local factor times all incoming messages except the one from j
            product = np.array(unary[i], dtype=float)
            for k in graph.neighbors(i):
                if k != j:
                    product = product * old_messages[(k, i)]

            # Multiply by the edge potential and marginalize over the sender
            msg = product @ edge_potential(i, j)
            messages[(i, j)] = msg / msg.sum()  # Normalize

        # Check convergence
        if all(np.allclose(messages[k], old_messages[k], atol=tol)
               for k in messages):
            break

    # Compute beliefs
    beliefs = {}
    for node in graph.nodes:
        belief = np.array(unary[node], dtype=float)
        for k in graph.neighbors(node):
            belief = belief * messages[(k, node)]
        beliefs[node] = belief / belief.sum()

    return beliefs
Production Libraries: For real applications, use specialized libraries:
  • pgmpy — Python library for graphical models with full inference
  • Pyro / PyTorch — Probabilistic programming with GPU support
  • libDAI — C++ library with Python bindings, very efficient
  • OpenGM — Optimized for large-scale discrete inference


Summary

Key Takeaways

  • Variable elimination computes marginals by systematically summing out variables one at a time.
  • Belief propagation uses message passing to compute all marginals simultaneously: exact on trees, approximate on loopy graphs.
  • Junction tree algorithm provides exact inference on any graph by converting it to a tree of cliques.
  • Treewidth determines complexity: low treewidth means tractable exact inference, high treewidth requires approximations.
  • Elimination ordering dramatically affects efficiency: finding the optimal order is NP-hard, but heuristics work well.
  • Approximate inference (sampling, variational) is needed when exact methods are intractable.
  • Message passing directly inspired Graph Neural Networks, a key architectural paradigm in modern deep learning.
  • Running intersection property ensures correct information flow in junction trees.

Looking Ahead

In the next section, we'll explore Learning Graphical Model Structure—how to discover the graph structure itself from data. This is the inverse problem: instead of performing inference given a known structure, we infer which edges and dependencies exist. Structure learning is essential for discovering causal relationships and building interpretable models from observational data.
