Chapter 03

Reinforcement Learning and Generative AI

How AI learns from interaction instead of labeled data, and how it creates entirely new content. From multi-armed bandits to GANs and diffusion models.

1. From Predictive AI to Generative AI

In Chapter 2 we studied machine learning models that learn a function \( y = f(x) \). These models are deterministic transformations: given the same input \( x \), they always produce the same output \( y \). If the model gets a particular sample wrong, it will always get that sample wrong (until retrained).

Definition

Predictive AI (Pred-AI) refers to deterministic models that map inputs to outputs via a fixed function \( f \). This is also called traditional AI. Given the same input, Pred-AI always produces the same output.

Definition

Generative AI (Gen-AI) refers to models that can create new content. Rather than mapping each input to a single fixed output, Gen-AI models can produce different outputs for the same input, including content that never appeared in the training data.

Why This Matters

Pred-AI can classify images or predict prices, but it cannot write a new poem or generate a new image. Gen-AI breaks through this limitation. Reinforcement learning is one of the key techniques that makes Gen-AI possible—it provides a way for models to enhance their knowledge through interaction and feedback, rather than only from pre-labeled data.

2. Reinforcement Learning — Formal Definition

Definition

Reinforcement Learning (RL) is a type of machine learning where an agent learns to make decisions by interacting with an environment. The agent takes actions, receives feedback (rewards or penalties), and develops a strategy to maximize its total reward over time.

Why RL Exists

Many real-world problems cannot be solved with a labeled dataset alone. A robot learning to walk, a program learning to play chess, or an algorithm optimizing ad placement must try things, observe the results, and improve. RL provides the mathematical framework for this trial-and-error learning process.

The RL Ecosystem

Every RL problem has these components:

RL Components
  • Agent \( A \) — the learner/decision-maker.
  • Environment \( E \) — everything the agent interacts with.
  • States \( S = \{s_t = S(t)\} \) — the agent perceives the environment through states. At each time step \( t \), the environment is in some state \( s_t \).
  • Actions \( a \in \mathcal{A} \) — the agent interacts with the environment by choosing actions from a set of possible actions.
  • Reward \( R \) — after each action, the environment responds with a reward signal:
    • \( r > 0 \) means reward (the action was good).
    • \( r \leq 0 \) means a penalty (the action was bad).

The agent's goal is to learn a strategy (called a policy) that chooses the best action for each state, maximizing the total reward over time.

3. RL vs Supervised vs Unsupervised Learning

Machine learning has three main paradigms. Understanding how they differ clarifies when to use each one.

Supervised Learning

What it is: The model learns from labeled data—pairs of (input, correct output). It learns to map inputs to outputs by minimizing the difference between its predictions and the correct labels.

Used for: classification (is this email spam?), regression (what will the stock price be?).

Requires: a labeled dataset prepared in advance.

Unsupervised Learning

What it is: The model finds hidden patterns in data that has no labels. It discovers structure on its own—grouping similar items together or learning compressed representations.

Used for: clustering (group customers by behavior), auto-encoding (compress data into a compact representation).

Requires: data, but no labels.

Reinforcement Learning

What it is: The agent learns by interacting with an environment. It takes actions, receives rewards, and develops a strategy over time. There is no labeled dataset—the agent generates its own experience.

Used for: game playing, robotics, resource allocation, any problem where a sequence of decisions must be optimized.

Key difference: RL optimizes a long-term strategy, not just a single prediction. The reward for an action may not be immediately apparent—a chess move might only prove good or bad many moves later.

4. Formal RL Framework: \(\langle S, A, R \rangle\)

Formally, a reinforcement learning problem is defined as the triple \(\langle S, A, R \rangle\) over a timeline \( T \):

Definition
  • State function \( S: T \to \mathcal{S} \) — maps each time step to a state. At time \( t \), the environment is in state \( s_t = S(t) \).
  • Action function \( A: \mathcal{S} \to \mathcal{A} \) — maps each state to an action. This is the agent's policy—a deterministic function that says "in state \( s \), take action \( a \)."
  • Reward function \( R: \mathcal{S} \times \mathcal{A} \times \mathcal{S} \to \mathbb{R} \) — given a state \( s \), an action \( a \), and the resulting new state \( s' \), returns a numerical reward \( r \). This function encodes what is "good" and "bad" in the environment.
Why This Formalization Exists

This triple provides a complete mathematical specification of an RL problem. Once the states, the available actions, and the reward function \( R \) are specified, an RL algorithm searches for the action function \( A \) (the policy) that maximizes total reward. The formalization separates the problem definition from the solution method.

5. The Multi-Armed Bandit Problem

Definition

The multi-armed bandit problem is one of the simplest RL problems. It captures the fundamental tension in RL: exploration vs exploitation. Should the agent try something new (explore) or stick with what has worked so far (exploit)?

Setup

There are \( M \) slot machines, each with an unknown payout probability. A player has \( N \) coins. On each turn, the player inserts 1 coin into 1 machine:

  • If the machine pays out, the player receives up to \( K \) coins (drawn from that machine's payout distribution).
  • If the machine does not pay out, the player loses the coin.

The goal: maximize total winnings over all \( N \) plays.

RL Formulation

Bandit as RL

We can express this problem in the RL framework:

  • Environment \( E \): the \( M \) machines.
  • Agent \( A \): the player with \( N \) coins.
  • State \( S(t) = [m_1, \ldots, m_M] \): where \( m_i = 1 \) if machine \( i \) is randomly "hot" (ready to pay out) at time \( t \), and \( m_i = 0 \) otherwise. The agent does not observe the state directly.
  • Action \( A(S(t)) \): put a coin in machine \( a_t \).
  • Reward: \[ R(S(t), A(S(t))) = \begin{cases} -1 & \text{if machine } a_t \text{ is not hot} \\ \text{a random payout from the machine's distribution} & \text{otherwise} \end{cases} \]

The objective is:

\[ \max_{A} \sum_{t=1}^{N} R(S(t), A(S(t))) \]

6. Strategies for Multi-Armed Bandits

The core challenge: you don't know which machine is best. You must balance exploration (trying machines to learn about them) with exploitation (using the best machine you've found so far).

Basic Idea

Spend the first \( n < N \) turns exploring—trying each machine to estimate how good it is. Then exploit the best machine for the remaining \( N - n \) turns. The strategies below are more sophisticated versions of this idea.
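A minimal Python sketch of this explore-then-commit idea follows. It assumes details the text leaves open: each machine pays out with a fixed hidden probability (the hypothetical `payout_probs`, used only to simulate the environment), a payout is worth +1 and a miss costs the inserted coin (-1), and the exploration phase simply cycles through the machines in turn.

```python
import numpy as np


def explore_then_commit(payout_probs, n_explore, n_total, seed=0):
    """Explore for n_explore turns, then exploit the best-looking machine.

    payout_probs is hypothetical: the true win probability of each machine,
    used only to simulate the environment (the player never sees it).
    """
    rng = np.random.default_rng(seed)
    m = len(payout_probs)
    reward_sums = np.zeros(m)  # total reward observed per machine
    counts = np.zeros(m)       # number of plays per machine
    winnings = 0.0
    for t in range(n_total):
        if t < n_explore:
            a = t % m  # explore: cycle through the machines
        else:
            # exploit: machine with the highest average reward so far
            a = int(np.argmax(reward_sums / np.maximum(counts, 1)))
        r = 1.0 if rng.random() < payout_probs[a] else -1.0
        reward_sums[a] += r
        counts[a] += 1
        winnings += r
    return winnings, counts


winnings, counts = explore_then_commit([0.2, 0.5, 0.8], n_explore=150, n_total=1000)
print(winnings, counts)
```

Choosing `n_explore` is the whole difficulty: too small and the player may commit to the wrong machine; too large and many coins are spent on machines already known to be poor. The strategies below avoid fixing this split in advance.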

6.1 Greedy Strategy (\(\varepsilon\)-greedy)

Definition

The \(\varepsilon\)-greedy strategy works as follows on each turn:

  • With probability \( 1 - \varepsilon \): choose the machine with the highest average reward observed so far (exploit).
  • With probability \( \varepsilon \): choose a machine at random (explore).

The parameter \( \varepsilon \in [0, 1] \) controls the exploration-exploitation balance. A small \( \varepsilon \) means mostly exploit; a large \( \varepsilon \) means mostly explore.

Why It Works

Pure exploitation (always pick the current best) can get stuck on a suboptimal machine if early observations were misleading. The random exploration with probability \( \varepsilon \) ensures the agent keeps gathering information about all machines, preventing it from locking in on a bad choice.
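The same kind of simulated machines (hypothetical payout probabilities, +1 per payout, -1 per miss) can be played ε-greedily. This sketch keeps a running average per machine, which plays the role of "the highest average reward observed so far"; the default `epsilon=0.1` is an illustrative choice.

```python
import numpy as np


def epsilon_greedy(payout_probs, n_turns, epsilon=0.1, seed=0):
    """Simulate epsilon-greedy play; returns average-reward estimates and play counts."""
    rng = np.random.default_rng(seed)
    m = len(payout_probs)
    q = np.zeros(m)                  # average reward observed per machine
    counts = np.zeros(m, dtype=int)  # plays per machine
    for _ in range(n_turns):
        if rng.random() < epsilon:
            a = int(rng.integers(m))  # explore: any machine, uniformly at random
        else:
            a = int(np.argmax(q))     # exploit: best average so far
        r = 1.0 if rng.random() < payout_probs[a] else -1.0
        counts[a] += 1
        q[a] += (r - q[a]) / counts[a]  # incremental update of the running average
    return q, counts


q, counts = epsilon_greedy([0.2, 0.5, 0.8], n_turns=5000, epsilon=0.1)
print(q, counts)
```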

6.2 Upper Confidence Bound (UCB)

Definition

Instead of exploring uniformly at random, UCB directs exploration toward machines whose value is still uncertain. On each turn, choose the machine that maximizes:

\[ Q(t, a) + c \cdot \sqrt{\frac{\ln t}{N_t(a)}} \]

where:

  • \( Q(t, a) \) = average reward from machine \( a \) so far.
  • \( N_t(a) \) = number of times machine \( a \) has been played.
  • \( c \) = a constant controlling the exploration bonus.
  • \( t \) = current time step.

Why UCB Works

The second term \( c \cdot \sqrt{\ln t / N_t(a)} \) is a bonus for uncertainty. If machine \( a \) has been played many times, \( N_t(a) \) is large and the bonus is small: we already know a lot about it. If machine \( a \) has rarely been played, the bonus is large: we are uncertain about its true value, so it is worth exploring. As time goes on and all machines get tried, the bonuses shrink (\( N_t(a) \) keeps growing while \( \ln t \) grows only slowly), and the algorithm shifts from exploration to exploitation automatically.
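A UCB sketch on the same kind of simulated machines. One assumption goes beyond the formula: each machine is played once at the start, because the bonus is undefined while \( N_t(a) = 0 \). The default `c=2.0` is an arbitrary choice for rewards in \([-1, 1]\), not a recommended value.

```python
import numpy as np


def ucb(payout_probs, n_turns, c=2.0, seed=0):
    """Simulate UCB play; returns average-reward estimates and play counts."""
    rng = np.random.default_rng(seed)
    m = len(payout_probs)
    q = np.zeros(m)       # Q(t, a): average reward per machine
    counts = np.zeros(m)  # N_t(a): plays per machine
    for t in range(1, n_turns + 1):
        if t <= m:
            a = t - 1  # play each machine once first: the bonus is undefined while N_t(a) = 0
        else:
            bonus = c * np.sqrt(np.log(t) / counts)
            a = int(np.argmax(q + bonus))
        r = 1.0 if rng.random() < payout_probs[a] else -1.0
        counts[a] += 1
        q[a] += (r - q[a]) / counts[a]
    return q, counts


q, counts = ucb([0.2, 0.5, 0.8], n_turns=5000)
print(q, counts)
```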

6.3 Bayesian Strategy (Thompson Sampling)

Definition

Thompson Sampling uses Bayesian statistics. For each machine, the agent maintains a probability distribution over the machine's true payout rate. On each turn:

  1. Sample a value from each machine's probability distribution.
  2. Play the machine whose sampled value is highest.
  3. Observe the result and update the probability distribution for that machine.

Why It Works

When the agent has little data about a machine, its probability distribution is wide (high uncertainty), so the sampled values vary a lot—this naturally leads to exploration. When the agent has a lot of data, the distribution is narrow (low uncertainty), so the sampled values are close to the true mean—this naturally leads to exploitation. Thompson Sampling automatically balances exploration and exploitation through the width of probability distributions.
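A Thompson Sampling sketch, assuming each play either pays out or not (a Bernoulli outcome) and using a Beta distribution per machine as the "probability distribution over the payout rate". The Beta prior is a convenient choice because observing a payout or a miss simply adds 1 to one of its two parameters.

```python
import numpy as np


def thompson_sampling(payout_probs, n_turns, seed=0):
    """Beta-Bernoulli Thompson Sampling; returns posterior means and play counts."""
    rng = np.random.default_rng(seed)
    m = len(payout_probs)
    alpha = np.ones(m)  # Beta(1, 1) prior = uniform belief about each payout rate
    beta = np.ones(m)
    counts = np.zeros(m, dtype=int)
    for _ in range(n_turns):
        theta = rng.beta(alpha, beta)          # 1. one sample per machine's distribution
        a = int(np.argmax(theta))              # 2. play the highest sample
        paid = rng.random() < payout_probs[a]  # simulated outcome
        alpha[a] += paid                       # 3. Bayesian update: count a payout...
        beta[a] += not paid                    #    ...or a miss, for that machine only
        counts[a] += 1
    return alpha / (alpha + beta), counts


posterior_mean, counts = thompson_sampling([0.2, 0.5, 0.8], n_turns=2000)
print(posterior_mean, counts)
```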

7. Applications of Multi-Armed Bandits

Real-World Uses
  • Content recommendation: A website has multiple articles/videos to show users. Each piece of content is a "machine." The reward is whether the user clicks or engages. The bandit algorithm learns which content to recommend.
  • A/B testing: Testing multiple variations of a webpage, drug treatment, or marketing campaign. Instead of running a fixed experiment, a bandit algorithm dynamically allocates more traffic to better-performing variants.
  • Real-time bidding: In online advertising auctions, the bid amount is the action, and the reward is the return on ad spend. The bandit learns the optimal bid strategy.

8. RL Algorithms: Model-Based vs Model-Free

RL algorithms fall into two categories based on whether the agent tries to understand how the environment works.

Definition

Model-based RL: The agent builds an internal model of the environment by observing state transitions and rewards. It then uses this model to plan ahead—simulating future scenarios to choose the best action.

When to use: When the environment is well-defined and does not change over time. The agent can invest effort in building an accurate model and then exploit it efficiently.

Definition

Model-free RL: The agent does not build a model of the environment. Instead, it directly learns which actions lead to good outcomes through trial-and-error, storing only the learned values or policy.

When to use: When the environment is unknown, complex, or changes over time. Building an accurate model would be too difficult or unreliable.

Why This Distinction Matters

Model-based methods are often more sample-efficient (they learn from less data) because they can plan with their model. Model-free methods are more flexible and work in environments that are too complex to model. Many widely used RL algorithms, including Q-learning and REINFORCE, are model-free.

9. The REINFORCE Algorithm

Definition

REINFORCE is a policy gradient method. Instead of learning the value of each state-action pair (like Q-learning), it directly learns a policy \( \pi_\theta(a|s) \)—a function parameterized by \( \theta \) that outputs the probability of taking action \( a \) in state \( s \).

Why Policy Gradients Exist

Some RL problems have continuous action spaces (e.g., how much to turn a steering wheel), making it impractical to store a value for every possible action. Policy gradient methods handle this by learning a parameterized function that directly outputs actions or action probabilities. REINFORCE is the simplest such method.

Key Concepts

Trajectory

A trajectory \( \tau \) is a complete sequence of states and actions:

\[ \tau = s_0, a_0, s_1, a_1, \ldots, s_T \]

It records one full episode of the agent interacting with the environment.

Expected Reward

The expected total reward under policy parameters \( \theta \) is:

\[ \eta(\theta) = \mathbb{E}\left[\sum_{t=0}^{T-1} \gamma^t \cdot R(s_t, a_t)\right] \]

where \( \gamma \in [0, 1] \) is the discount factor. It controls how much the agent cares about future rewards relative to immediate ones: \( \gamma = 0 \) means the agent cares only about the immediate reward, while \( \gamma \) close to 1 means it values future rewards almost as much as immediate ones.
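The quantity inside the expectation is easy to compute for one sampled reward sequence; averaging it over many sampled trajectories gives a Monte Carlo estimate of \( \eta(\theta) \). A minimal sketch (the function name is illustrative):

```python
def discounted_return(rewards, gamma):
    """Sum of gamma**t * r_t over one sampled trajectory (t starts at 0)."""
    return sum(gamma ** t * r for t, r in enumerate(rewards))


# The same reward arriving t steps later is worth gamma**t as much.
print(discounted_return([1.0, 1.0, 1.0], gamma=0.5))  # 1 + 0.5 + 0.25 = 1.75
print(discounted_return([1.0, 1.0, 1.0], gamma=0.0))  # 1.0: only the immediate reward counts
```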

The goal is to find:

\[ \theta^* = \arg\max_\theta \;\eta(\theta) \]

The Policy Gradient Derivation

To maximize \( \eta(\theta) \), we need its gradient with respect to \( \theta \). The key mathematical insight is:

Key Identity
\[ \nabla_\theta \;\mathbb{E}[f(\tau)] = \mathbb{E}\left[ \left(\nabla_\theta \log P_\theta(\tau)\right) \cdot f(\tau) \right] \]

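This identity (often called the log-derivative or likelihood-ratio trick) is crucial because it lets us estimate the gradient by sampling trajectories from the environment: \( f \) only needs to be evaluated, never differentiated. We also don't need to know the environment's transition dynamics, as the expansion below shows.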
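The identity can be checked numerically on a toy case that is not from the text: a single action \( x \sim \mathcal{N}(\theta, 1) \) with \( f(x) = x^2 \). Since \( \mathbb{E}[x^2] = \theta^2 + 1 \), the exact gradient is \( 2\theta \); the Monte Carlo estimate uses only samples and the score \( \nabla_\theta \log P_\theta(x) = x - \theta \).

```python
import numpy as np

# Toy check: x ~ N(theta, 1), f(x) = x**2, exact gradient of E[f(x)] is 2 * theta.
rng = np.random.default_rng(0)
theta = 1.5
x = rng.normal(loc=theta, scale=1.0, size=1_000_000)  # samples only; f is never differentiated
score = x - theta                                      # grad_theta log N(x; theta, 1)
grad_estimate = np.mean(score * x**2)                  # Monte Carlo E[score * f(x)]
print(grad_estimate, 2 * theta)                        # estimate vs exact value 3.0
```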

The log probability of a trajectory can be expanded as:

\[ \log P_\theta(\tau) = \log \mu(s_0) + \sum_{t=0}^{T-1} \log \pi_\theta(a_t | s_t) + \sum_{t=0}^{T-1} \log P(s_{t+1} | s_t, a_t) \]

where \( \mu(s_0) \) is the initial state distribution and \( P(s_{t+1}|s_t,a_t) \) is the environment's transition probability.

Critical Simplification

Only the \( \pi_\theta(a_t|s_t) \) terms depend on \( \theta \). The initial state distribution and the environment transitions do not depend on \( \theta \). Therefore:

\[ \nabla_\theta \log P_\theta(\tau) = \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t | s_t) \]

This is why REINFORCE is model-free—the environment dynamics drop out of the gradient computation entirely.

Putting it all together, the policy gradient is:

\[ \nabla_\theta \;\eta(\theta) = \mathbb{E}\left[ \left(\sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t | s_t)\right) \cdot \left(\sum_{t=0}^{T-1} \gamma^t R(s_t, a_t)\right) \right] \]

Vanilla Policy Gradient Algorithm

Algorithm Steps

For each iteration:

  1. Collect trajectories: Execute the current policy \( \pi_\theta \) in the environment to generate trajectories \( \tau_1, \tau_2, \ldots \).
  2. Fit a baseline \( B \): A baseline reduces variance in the gradient estimate. Fit \( B(s_t) \) by minimizing:
    \[ \sum_t \left(R_{\geq t} - B(s_t)\right)^2 \]
    where \( R_{\geq t} \) is the total reward from time \( t \) onward.
  3. Update \( \theta \): Use the gradient:
    \[ \sum_t \nabla_\theta \log \pi_\theta(a_t | s_t) \cdot \left(R_{\geq t} - B(s_t)\right) \]

Why Subtract a Baseline

The baseline \( B(s_t) \) does not change the expected gradient (the estimate stays unbiased), but it reduces variance. Without a baseline, every action in a high-reward trajectory is reinforced, including mediocre ones. With the baseline, an action's probability increases only if it led to more reward than \( B(s_t) \) predicts for that state, and decreases if it led to less.
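A minimal Python sketch of the three steps, under assumptions that are not in the text: a hypothetical five-step task where the state is just the time step \( t \), each step offers two actions, and the agent earns reward 1 when it picks a hidden target action (`TARGET`). The policy is a tabular softmax, so \( \nabla_\theta \log \pi_\theta(a|s) \) is the one-hot vector of \( a \) minus the action probabilities. With one baseline value per state, the least-squares fit of step 2 is simply the mean reward-to-go at that state.

```python
import numpy as np

T = 5
TARGET = np.array([1, 0, 1, 1, 0])  # hypothetical rewarded action at each step


def softmax(logits):
    z = np.exp(logits - logits.max())
    return z / z.sum()


def run_episode(theta, rng):
    """Sample one trajectory with the current policy; the state is the step index."""
    actions, rewards = [], []
    for s in range(T):
        a = rng.choice(2, p=softmax(theta[s]))
        actions.append(a)
        rewards.append(1.0 if a == TARGET[s] else 0.0)
    return actions, rewards


def reinforce(n_iters=200, n_traj=20, lr=0.1, seed=0):
    rng = np.random.default_rng(seed)
    theta = np.zeros((T, 2))  # tabular softmax policy: one row of logits per state
    for _ in range(n_iters):
        # 1. Collect trajectories with the current policy.
        batch = [run_episode(theta, rng) for _ in range(n_traj)]
        # Reward-to-go R_{>=t} for every step of every trajectory, shape (n_traj, T).
        rtg = np.array([np.cumsum(r[::-1])[::-1] for _, r in batch])
        # 2. Fit B(s_t): with one value per state, least squares gives the mean R_{>=t}.
        baseline = rtg.mean(axis=0)
        # 3. Step along sum_t grad log pi(a_t|s_t) * (R_{>=t} - B(s_t)).
        grad = np.zeros_like(theta)
        for (actions, _), returns in zip(batch, rtg):
            for s, a in enumerate(actions):
                grad_log_pi = -softmax(theta[s])
                grad_log_pi[a] += 1.0
                grad[s] += grad_log_pi * (returns[s] - baseline[s])
        theta += lr * grad / n_traj
    return theta


theta = reinforce()
print(theta.argmax(axis=1))  # greedy action per state after training
```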

10. Q-Learning

Definition

Q-learning is a model-free RL algorithm that learns Q-values: the expected total reward for taking action \( a \) in state \( s \), then following the optimal policy afterward.

\[ Q(s, a) = \text{expected total reward for taking action } a \text{ in state } s \]

Why Q-values Exist

If we knew the Q-value of every state-action pair, the optimal policy would be trivial: in each state, just pick the action with the highest Q-value. Q-learning provides a way to learn these values through experience.

The Update Rule (Bellman Equation)

After taking action \( a \) in state \( s \), observing reward \( r \), and arriving in new state \( s' \), update:

\[ Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \cdot \max_{a'} Q(s', a') - Q(s, a) \right] \]

Parameters
  • \( \alpha \) = learning rate. Controls how much the Q-value changes with each update. Small \( \alpha \) = slow, stable learning. Large \( \alpha \) = fast, potentially unstable learning.
  • \( \gamma \) = discount factor. Balances immediate vs future rewards. \( \gamma = 0 \) means only immediate rewards matter. \( \gamma \) close to 1 means future rewards are nearly as important.

Action Selection

Q-learning can use the same UCB-style exploration as multi-armed bandits:

\[ a^* = \arg\max_a \left[ Q(s_t, a) + c \cdot \sqrt{\frac{\ln t}{N_t(s_t, a)}} \right] \]

Here \( N_t(s_t, a) \) counts how often action \( a \) has been tried in state \( s_t \). The first term exploits current knowledge; the second term encourages exploring actions that have been tried less often.

Q-Learning Algorithm

Algorithm Steps
  1. Define the environment: \( E = \langle \mathcal{S}, \mathcal{A}, a_{\text{goal}}, Q_{\text{table}} \rangle \). The Q-table stores \( Q(s, a) \) for every state-action pair, initialized to zero.
  2. Set parameters: discount factor \( \gamma \), learning rate \( \alpha \), exploration parameter \( \varepsilon \).
  3. Define state transitions: specify how the environment responds to actions (what new state results from each action).
  4. Iterate: Repeat many episodes—in each step, choose an action, observe the reward and new state, and update the Q-table using the Bellman equation above.
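A minimal tabular sketch of these steps. The environment is a hypothetical corridor (states 0 to 5, actions left/right, reward 1 for reaching state 5, which ends the episode); it stands in for \( a_{\text{goal}} \) and the state transitions of step 3. Treating the goal as terminal (no future value), choosing actions ε-greedily, and breaking ties at random are assumptions of this sketch.

```python
import numpy as np

N_STATES, GOAL = 6, 5  # hypothetical corridor: states 0..5, goal at the right end
MOVES = (-1, +1)       # action 0 = step left, action 1 = step right


def step(s, a):
    """State transition: move one cell (walls clip the move); reward 1 at the goal."""
    s_next = min(max(s + MOVES[a], 0), N_STATES - 1)
    reward = 1.0 if s_next == GOAL else 0.0
    return s_next, reward, s_next == GOAL


def q_learning(n_episodes=500, alpha=0.1, gamma=0.9, epsilon=0.1, seed=0):
    rng = np.random.default_rng(seed)
    q = np.zeros((N_STATES, len(MOVES)))  # Q-table initialised to zero
    for _ in range(n_episodes):
        s, done = 0, False
        while not done:
            # epsilon-greedy choice; ties are broken at random so early episodes still move
            if rng.random() < epsilon or q[s, 0] == q[s, 1]:
                a = int(rng.integers(len(MOVES)))
            else:
                a = int(np.argmax(q[s]))
            s_next, r, done = step(s, a)
            # Bellman update; the terminal goal state contributes no future value
            target = r + (0.0 if done else gamma * np.max(q[s_next]))
            q[s, a] += alpha * (target - q[s, a])
            s = s_next
    return q


q = q_learning()
print(q.argmax(axis=1)[:GOAL])  # greedy action in each non-goal state
```

With \( \gamma = 0.9 \), the learned values should approach \( Q(s, \text{right}) = 0.9^{4 - s} \), so the greedy policy moves right from every state.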

11. Deep Q-Network (DQN)

Definition

A Deep Q-Network replaces the Q-table with a neural network that estimates Q-values. The network takes a state as input and outputs Q-values for all possible actions.

Why DQN Exists

When the state space is very large (e.g., the pixels of a video game screen), a Q-table is impractical—it would need one entry for every possible state-action pair, which could be billions or more. A neural network can generalize across similar states, estimating Q-values for states it has never seen before. DQN made it possible for RL to solve problems with high-dimensional input like Atari games.

How DQN Works

The neural network \( Q(s, a; \theta) \) is trained to minimize the same Bellman error as tabular Q-learning:

\[ L(\theta) = \mathbb{E}\left[\left( r + \gamma \max_{a'} Q(s', a'; \theta^-) - Q(s, a; \theta) \right)^2\right] \]

where \( \theta^- \) is a periodically updated copy of the network parameters (called the target network). Holding the targets fixed between updates stabilizes training.
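To make the two parameter sets concrete, the sketch below replaces the neural network with a hypothetical linear model \( Q(s, \cdot\,; W) = sW \) (one column per action) and derives the gradient by hand. Only `weights` (\( \theta \)) is updated; `target_weights` (\( \theta^- \)) is held fixed, as it would be between periodic syncs. Terminal transitions and the sync schedule are omitted for brevity.

```python
import numpy as np


def q_values(states, weights):
    """Hypothetical linear stand-in for the network: Q(s, . ; theta) = s @ W."""
    return states @ weights


def td_targets(r, s_next, target_weights, gamma):
    """r + gamma * max_a' Q(s', a'; theta^-), computed with the frozen target copy."""
    return r + gamma * q_values(s_next, target_weights).max(axis=1)


def dqn_loss(weights, target_weights, s, a, r, s_next, gamma=0.99):
    predicted = q_values(s, weights)[np.arange(len(a)), a]  # Q(s, a; theta)
    return np.mean((td_targets(r, s_next, target_weights, gamma) - predicted) ** 2)


def dqn_gradient(weights, target_weights, s, a, r, s_next, gamma=0.99):
    """Gradient of dqn_loss w.r.t. weights only; the targets are treated as constants."""
    predicted = q_values(s, weights)[np.arange(len(a)), a]
    td_error = td_targets(r, s_next, target_weights, gamma) - predicted
    grad = np.zeros_like(weights)
    for i in range(len(a)):  # only the column of the action actually taken gets gradient
        grad[:, a[i]] += -2.0 / len(a) * td_error[i] * s[i]
    return grad


# Hypothetical batch of transitions with 4-dimensional states and 2 actions.
rng = np.random.default_rng(0)
s, s_next = rng.normal(size=(32, 4)), rng.normal(size=(32, 4))
a, r = rng.integers(0, 2, size=32), rng.normal(size=32)
weights = np.zeros((4, 2))
target_weights = weights.copy()  # theta^-: would be re-synced from weights periodically
start = dqn_loss(weights, target_weights, s, a, r, s_next)
for _ in range(50):
    weights -= 0.05 * dqn_gradient(weights, target_weights, s, a, r, s_next)
end = dqn_loss(weights, target_weights, s, a, r, s_next)
print(start, end)  # the Bellman error falls while theta^- stays frozen
```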

12. SARSA (State-Action-Reward-State-Action)

Definition

SARSA is a model-free RL algorithm similar to Q-learning. The name comes from the tuple it uses for each update: \( (s, a, r, s', a') \)—the current state, the action taken, the reward received, the next state, and the next action actually taken.

SARSA Update Rule
\[ Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \cdot Q(s', a') - Q(s, a) \right] \]

Key Difference from Q-Learning

Q-learning uses \( \max_{a'} Q(s', a') \)—the value of the best possible action in the next state (regardless of what action the agent actually takes). This is called off-policy learning.

SARSA uses \( Q(s', a') \)—the value of the action the agent actually takes in the next state. This is called on-policy learning.

SARSA tends to be more conservative because it accounts for the agent's actual behavior (including exploration), while Q-learning is more optimistic because it assumes optimal future actions.
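The difference is easiest to see with numbers. In this hypothetical Q-table, state 1 has one good action and one bad one; if exploration leads the agent to take the bad action in \( s' \), the two algorithms compute very different targets for the same transition.

```python
import numpy as np


def q_learning_target(q, r, s_next, gamma):
    """Off-policy target: bootstrap from the best action in s'."""
    return r + gamma * np.max(q[s_next])


def sarsa_target(q, r, s_next, a_next, gamma):
    """On-policy target: bootstrap from the action a' the agent actually takes in s'."""
    return r + gamma * q[s_next, a_next]


# Hypothetical Q-table: in state 1, action 0 is good (2.0) and action 1 is bad (-1.0).
q = np.array([[0.0, 0.0],
              [2.0, -1.0]])
r, s_next, gamma = 1.0, 1, 0.9
a_next = 1  # suppose exploration makes the agent pick the bad action in s'

print(q_learning_target(q, r, s_next, gamma))     # 1 + 0.9 * 2.0
print(sarsa_target(q, r, s_next, a_next, gamma))  # 1 + 0.9 * (-1.0)
```

Either target then enters the same update \( Q(s, a) \leftarrow Q(s, a) + \alpha\,[\text{target} - Q(s, a)] \). Because SARSA's target reflects the exploratory step, its Q-values include the cost of exploring, which is where its more conservative behavior comes from.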

13. Generative AI in the RL Context

Now we connect RL to generative AI—the technology behind models that create text, images, and other content.

Pred-AI vs Gen-AI
  • Predictive AI (Pred-AI): Deterministic. The function \( f(x) \) always gives the same output for the same input. Given a photo of a cat, a Pred-AI classifier always outputs "cat."
  • Generative AI (Gen-AI): Non-deterministic. The model can generate different outputs for the same input. Given the prompt "draw a cat," a Gen-AI model can produce many different cat images.

The key capability of Gen-AI is that it can create new content based on learned patterns—content that was never in the training data.

14. From Predictive AI to Generative AI

The Problem

Pred-AI uses a deterministic function \( y = f(x) \). To create Gen-AI, we need a stochastic (random) transformation—one that can produce different outputs for the same input.

The Solution: Add Randomness

Attach a random variable \( z \) to the input \( x \), where \( z \) is sampled from some distribution (typically a Gaussian). Because \( z \) is drawn afresh each time, the output generally differs from one run to the next.

The Reparameterization Trick

The challenge: if the output involves random sampling, how do we compute gradients for training? Sampling is not differentiable.

The solution is the reparameterization trick:

\[ y = \mu + \sigma \cdot z, \qquad z \sim \mathcal{N}(0, 1) \]

Instead of sampling \( y \) from \( \mathcal{N}(\mu, \sigma^2) \) directly, we sample \( z \) from the standard normal distribution and compute \( y \) as a deterministic function of \( \mu \), \( \sigma \), and \( z \).

Now \( \mu \) and \( \sigma \) are learnable parameters, and we can compute gradients with respect to them because the randomness is isolated in \( z \), which does not depend on the parameters.
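A numerical check of the trick on a toy objective chosen for illustration, \( L(\mu, \sigma) = \mathbb{E}[y^2] \) with \( y \sim \mathcal{N}(\mu, \sigma^2) \), whose exact gradients are known. The gradients are written out by hand via the chain rule through \( y = \mu + \sigma z \); an automatic-differentiation library would compute the same expressions.

```python
import numpy as np

# Toy objective: L(mu, sigma) = E[y**2] with y ~ N(mu, sigma**2), which equals
# mu**2 + sigma**2, so the exact gradients are dL/dmu = 2*mu and dL/dsigma = 2*sigma.
rng = np.random.default_rng(0)
mu, sigma = 0.5, 2.0
z = rng.standard_normal(1_000_000)  # all randomness lives in z, independent of mu and sigma
y = mu + sigma * z                  # y is a differentiable function of mu and sigma
# Chain rule: d(y**2)/dy = 2*y, dy/dmu = 1, dy/dsigma = z.
grad_mu = np.mean(2 * y)
grad_sigma = np.mean(2 * y * z)
print(grad_mu, grad_sigma)          # close to the exact values 1.0 and 4.0
```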

Gen-AI Output

The output of a Gen-AI model \( f_\theta \) is not a single value but a probability distribution. For input \( x \), the output \( y \) has probability \( p(y|x; \theta) \). We denote the output distribution as:

\[ \omega \cong \omega(\theta, x) \]

Each time we sample from this distribution, we may get a different \( y \).

15. Training Non-Deterministic Models

Dual Training Objective

Gen-AI training has two objectives that work together:

  1. Minimize prediction error (same as Pred-AI):
    \[ \theta^* = \arg\min_\theta L(\theta) \]
    This ensures the model's outputs are realistic and accurate.
  2. Maximize expected reward via RL:
    \[ \max \;\mathbb{E}\left[\sum_t \gamma^t R(s_t, a_t)\right] \]
    This ensures the model improves through feedback, generating content that scores well according to some reward signal.

The first objective grounds the model in reality; the second objective pushes it to create better and more useful content.

16. Two Basic Model Types: Discriminative and Generative

Definition

Discriminative model: Computes the conditional probability \( p(y|x) \)—the probability of label \( y \) given input \( x \). This is Pred-AI. Given an image, it outputs "cat" or "dog" with certain probabilities. It can only classify; it cannot generate new images.

Definition

Generative model: Computes the joint probability \( p(x, y) \)—the probability of both the data \( x \) and the label \( y \) occurring together. Because it models the full data distribution, it can generate new \( x \) that looks like real data.

Why This Distinction Matters

A discriminative model learns the boundary between classes. A generative model learns what the data itself looks like. Only the generative approach enables creating new data, because it models the underlying distribution of the data, not just how to separate categories.
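A small illustration under assumptions chosen for simplicity: one-dimensional data, two classes, and a generative model \( p(x, y) = p(y)\,p(x \mid y) \) with one Gaussian per class. Because it models the joint distribution, the same fitted model can generate new \( (x, y) \) pairs and can also classify via Bayes' rule, \( p(y \mid x) \propto p(x \mid y)\,p(y) \). All names and data are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical labelled data: class 0 centred at -2, class 1 centred at +3.
y_train = rng.integers(0, 2, size=2000)
x_train = np.where(y_train == 0, rng.normal(-2.0, 1.0, 2000), rng.normal(3.0, 1.0, 2000))

# Fit the joint p(x, y) = p(y) * p(x | y), with one Gaussian p(x | y) per class.
prior = np.array([np.mean(y_train == k) for k in (0, 1)])
means = np.array([x_train[y_train == k].mean() for k in (0, 1)])
stds = np.array([x_train[y_train == k].std() for k in (0, 1)])


def sample_joint(n, rng):
    """Generate new (x, y) pairs: draw y ~ p(y), then x ~ p(x | y)."""
    y = rng.choice(2, size=n, p=prior)
    return rng.normal(means[y], stds[y]), y


def predict_proba(x):
    """Use the same model discriminatively: p(y | x) is proportional to p(x | y) p(y)."""
    x = np.atleast_1d(np.asarray(x, dtype=float))[:, None]
    log_lik = -0.5 * ((x - means) / stds) ** 2 - np.log(stds)  # Gaussian log-density, constant dropped
    joint = np.exp(log_lik) * prior
    return joint / joint.sum(axis=1, keepdims=True)


x_new, y_new = sample_joint(5, rng)
print(x_new, y_new)                # five brand-new labelled samples
print(predict_proba([0.0, 2.5]))   # class probabilities for two inputs
```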

17. Generative Adversarial Networks (GANs)

Definition

A Generative Adversarial Network (GAN) consists of two neural networks trained together in opposition:

  • Generator \( G \): Takes random input \( z \) and creates fake data. Its goal is to produce data that is indistinguishable from real data.
  • Discriminator \( D \): Takes data (either real or fake) and classifies it as "real" or "fake." Its goal is to correctly identify fake data.

Why GANs Exist

Training a generator alone is difficult—how do you define "good output"? GANs solve this by making the discriminator act as an automatic quality judge. The generator is trained to fool the discriminator, and the discriminator is trained to not be fooled. This adversarial process drives both networks to improve, and the generator eventually produces highly realistic data.

GAN as Reinforcement Learning

RL Formulation of GAN

The Generator can be viewed as an RL agent:

  • Agent: the Generator \( G \).
  • State: the random input \( z \).
  • Action: the generated data \( x \mid z \) (output conditioned on \( z \)).
  • Reward: the Discriminator's judgment—how "real" \( D \) thinks the generated data is.

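In this chapter's formulation, the Generator is trained with reinforcement learning (specifically, the REINFORCE algorithm) to produce data that the Discriminator classifies as real. Standard GANs for continuous data such as images usually backpropagate the Discriminator's gradient directly into the Generator instead; the REINFORCE view is most useful when the generated data are discrete (for example, text tokens), so that no such gradient exists.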

Training the Discriminator

Discriminator Training Steps

Freeze the Generator (its weights do not change). Then:

  1. Create a training set: \( \Omega = \Omega^+ \cup \Omega^- \), where \( \Omega^+ \) = real data (labeled 1) and \( \Omega^- \) = fake data from \( G \) (labeled 0).
  2. Feed data through the Discriminator. Classify each sample as real or fake.
  3. Compute the loss gradient (how wrong the Discriminator was).
  4. Update the Discriminator's weights to improve its classification accuracy.

The Discriminator is a standard supervised learning classifier (Pred-AI). There is nothing unusual about its training.

Training the Generator

Generator Training Steps

Freeze the Discriminator (its weights do not change). Then:

  1. Sample random input: \( s = z \sim p(z) \) (for example, a standard Gaussian).
  2. Generate data \( x \mid z \) (the action) and obtain the Discriminator's judgment \( D(x \mid z) \).
  3. Compute the reward \( R(x \mid z, y) \) based on how "real" the Discriminator thinks the generated data is.
  4. Update the Generator's weights using the REINFORCE algorithm to increase the probability of actions (generated data) that received high rewards (fooled the Discriminator).

The Full GAN Training Algorithm

GAN Algorithm

Repeat until convergence:

  1. Step 1: Freeze \( G \). Train \( D \) for several epochs on a mix of real data and fake data from \( G \).
  2. Step 2: Freeze \( D \). Train \( G \) for several epochs using REINFORCE, with \( D \)'s judgment as the reward.

Convergence criterion: Training is complete when the Discriminator's accuracy drops to approximately 50%—meaning it can no longer distinguish real data from fake data. At this point, the Generator has learned to produce data that is indistinguishable from real data.
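A deliberately tiny version of the algorithm, under assumptions that are not from the text: the real data are one-dimensional samples from \( \mathcal{N}(3, 1) \), the Generator is \( x = \mu + z \) with a single parameter \( \mu \), and the Discriminator is logistic regression \( D(x) = \sigma(wx + b) \). Following this chapter's RL view, the Generator is updated with REINFORCE, using reward \( \log D(x) \) (one common choice) and the batch-mean reward as a baseline. It is a sketch for intuition, not a practical GAN.

```python
import numpy as np


def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))


def train_gan(rounds=300, d_steps=50, batch=512, lr_d=0.5, lr_g=0.05, seed=0):
    rng = np.random.default_rng(seed)
    real_mean = 3.0  # hypothetical real data: N(3, 1)
    mu = 0.0         # Generator G: x = mu + z with z ~ N(0, 1)
    w, b = 0.0, 0.0  # Discriminator D(x) = sigmoid(w * x + b)
    labels = np.concatenate([np.ones(batch), np.zeros(batch)])  # real = 1, fake = 0
    for _ in range(rounds):
        # Step 1: freeze G; train D as a logistic-regression classifier.
        for _ in range(d_steps):
            x = np.concatenate([rng.normal(real_mean, 1.0, batch),
                                mu + rng.standard_normal(batch)])
            err = labels - sigmoid(w * x + b)
            w += lr_d * np.mean(err * x)  # gradient ascent on the log-likelihood
            b += lr_d * np.mean(err)
        # Step 2: freeze D; update G with REINFORCE, reward = log D(x).
        z = rng.standard_normal(batch)
        reward = -np.logaddexp(0.0, -(w * (mu + z) + b))  # numerically stable log sigmoid
        score = z                                          # grad_mu log N(x; mu, 1) = x - mu
        mu += lr_g * np.mean(score * (reward - reward.mean()))  # mean reward as baseline
    return mu, w, b


mu, w, b = train_gan()
print(mu, w)  # mu should approach 3; w near 0 means D can no longer tell real from fake
```

Training \( D \) for many steps per round keeps it close to the best classifier for the current \( G \). Near \( \mu = 3 \) the two distributions coincide, \( w \) shrinks toward 0, and \( D \) outputs about 0.5 everywhere, which corresponds to the 50% accuracy criterion above.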

18. Diffusion Models

Definition

A diffusion model is a generative model that works in two phases: first, it gradually adds noise to real data until it becomes pure random noise; then, it learns to reverse this process, starting from pure noise and gradually removing noise to generate new data.

Why Diffusion Models Exist

GANs are powerful but notoriously difficult to train: the adversarial process can be unstable, and the generator can fail to learn diverse outputs (mode collapse). Diffusion models avoid these problems by pairing a fixed, mathematically well-defined noise-addition process with a learned reverse process that is trained with a simple, stable objective (predicting the added noise). They often produce higher-quality and more diverse outputs than GANs, which is why diffusion underlies modern image generation systems such as DALL-E 2 and Stable Diffusion.

Forward Process (Adding Noise)

Starting with real data \( X_0 \), we add Gaussian noise over \( T \) steps. At each step, the data becomes noisier:

\[ X_t = \sqrt{1 - \beta_t} \cdot X_{t-1} + \sqrt{\beta_t} \cdot \varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, I) \]

Understanding the Forward Process
  • \( \beta_t \) is the noise schedule. It controls how much noise is added at each step. \( \beta_t \) typically increases over time, so more noise is added in later steps.
  • \( \varepsilon_t \) is fresh random noise sampled at each step.
  • After many steps, \( X_T \) is approximately pure Gaussian noise—all structure from the original data has been destroyed.

The transition at each step can be expressed as a conditional Gaussian:

\[ q(X_t | X_{t-1}) = \mathcal{N}\left( X_t;\; \sqrt{\alpha_t} \cdot X_{t-1},\; (1 - \alpha_t) I \right), \qquad \alpha_t = 1 - \beta_t \]

Jumping Directly to Step \( t \)

A key mathematical result: we do not need to run all \( t \) steps sequentially. We can compute \( X_t \) directly from \( X_0 \):

Direct Computation
\[ X_t = \sqrt{\bar{\alpha}_t} \cdot X_0 + \sqrt{1 - \bar{\alpha}_t} \cdot \bar{\varepsilon}_t, \qquad \bar{\varepsilon}_t \sim \mathcal{N}(0, I) \]

where:

\[ \bar{\alpha}_t = \prod_{i=1}^{t} \alpha_i \]

This is computationally efficient: during training, we can sample any time step \( t \) and compute the noisy version \( X_t \) in one step.
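A sketch that checks the shortcut numerically, assuming a linear noise schedule from \( 10^{-4} \) to \( 0.02 \) over \( T = 1000 \) steps (a common choice, used here only for illustration). Running the per-step recursion and the one-shot formula on many copies of the same \( X_0 \) should give matching means and variances.

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)  # hypothetical linear schedule beta_1..beta_T
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)        # alpha_bars[t - 1] holds alpha_bar_t


def forward_step_by_step(x0, t, rng):
    """Apply X_i = sqrt(1 - beta_i) X_{i-1} + sqrt(beta_i) eps_i for i = 1..t."""
    x = x0
    for i in range(t):
        x = np.sqrt(alphas[i]) * x + np.sqrt(betas[i]) * rng.standard_normal(x.shape)
    return x


def forward_direct(x0, t, rng):
    """Jump straight to X_t = sqrt(alpha_bar_t) X_0 + sqrt(1 - alpha_bar_t) eps."""
    a_bar = alpha_bars[t - 1]
    return np.sqrt(a_bar) * x0 + np.sqrt(1.0 - a_bar) * rng.standard_normal(x0.shape)


rng = np.random.default_rng(0)
x0 = np.full(20_000, 2.0)  # many copies of one data point, to compare statistics
slow, fast = forward_step_by_step(x0, 300, rng), forward_direct(x0, 300, rng)
print(slow.mean(), fast.mean(), 2.0 * np.sqrt(alpha_bars[299]))  # means agree
print(slow.var(), fast.var(), 1.0 - alpha_bars[299])             # variances agree
```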

Backward Process (Removing Noise)

The backward process learns to reverse the forward process. Given noisy data \( X_t \), it estimates the slightly-less-noisy version \( X_{t-1} \).

Deriving the Backward Step

We need to estimate the conditional distribution:

\[ q(X_{t-1} | X_t, X_0) \]

Using Bayes' rule, this can be written as:

\[ q(X_{t-1} | X_t, X_0) = \frac{ q(X_t | X_{t-1}, X_0) \cdot q(X_{t-1} | X_0) }{ q(X_t | X_0) } \]

Since all three terms on the right are Gaussian distributions (from the forward process), the result is also a Gaussian, with computable mean and variance.

Backward Step Parameters

The backward step \( q(X_{t-1} | X_t, X_0) \) is a Gaussian with:

Variance:

\[ \sigma_q^2(t) = \frac{(1 - \alpha_t)(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} \]

Mean:

\[ \mu_q(X_t, X_0) = \frac{\sqrt{\alpha_t}(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t} X_t + \frac{\sqrt{\bar{\alpha}_{t-1}}(1 - \alpha_t)}{1 - \bar{\alpha}_t} X_0 \]

These formulas are derived entirely from the Gaussian properties of the forward process.
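The closed forms can be checked against the generic rule for conditioning a joint Gaussian, computed directly from the forward process. The sketch works with scalars (each coordinate behaves independently, since all covariances are multiples of \( I \)), uses a hypothetical linear schedule, and takes \( \bar{\alpha}_0 = 1 \) so that \( t = 1 \) is covered.

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)  # hypothetical linear schedule beta_1..beta_T
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)


def alpha_bar_prev(t):
    """alpha_bar_{t-1} for 1-indexed t, with alpha_bar_0 = 1 (empty product)."""
    return alpha_bars[t - 2] if t > 1 else 1.0


def posterior_closed_form(x_t, x0, t):
    """mu_q(X_t, X_0) and sigma_q^2(t) exactly as in the formulas above."""
    a_t, ab_t, ab_prev = alphas[t - 1], alpha_bars[t - 1], alpha_bar_prev(t)
    var = (1 - a_t) * (1 - ab_prev) / (1 - ab_t)
    mean = (np.sqrt(a_t) * (1 - ab_prev) / (1 - ab_t) * x_t
            + np.sqrt(ab_prev) * (1 - a_t) / (1 - ab_t) * x0)
    return mean, var


def posterior_by_conditioning(x_t, x0, t):
    """The same posterior from the joint Gaussian of (X_{t-1}, X_t) given X_0."""
    a_t, ab_prev = alphas[t - 1], alpha_bar_prev(t)
    m_prev, v_prev = np.sqrt(ab_prev) * x0, 1 - ab_prev         # X_{t-1} | X_0
    m_t, v_t = np.sqrt(a_t) * m_prev, a_t * v_prev + (1 - a_t)  # X_t | X_0
    cov = np.sqrt(a_t) * v_prev                                 # Cov(X_{t-1}, X_t | X_0)
    mean = m_prev + cov / v_t * (x_t - m_t)                     # Gaussian conditioning
    var = v_prev - cov ** 2 / v_t
    return mean, var


for t in (1, 2, 50, 1000):
    print(t, posterior_closed_form(0.7, -1.3, t), posterior_by_conditioning(0.7, -1.3, t))
```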

The Neural Network's Role

What the Network Learns

The backward process requires knowing \( X_0 \) (the original clean data), but at generation time we don't have \( X_0 \): it is exactly what we are trying to create. The solution is to train a neural network \( \varepsilon_\theta(X_t, t) \) to predict the noise \( \bar{\varepsilon}_t \) that was added to \( X_0 \) to create \( X_t \).

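Once we have the predicted noise \( \hat{\varepsilon} \), inverting the direct computation formula gives an estimate \( \hat{X}_0 = \left(X_t - \sqrt{1 - \bar{\alpha}_t}\,\hat{\varepsilon}\right) / \sqrt{\bar{\alpha}_t} \), which we substitute for \( X_0 \) in the backward step mean \( \mu_q \). By iterating from \( X_T \) (pure noise) back to \( X_0 \), we generate new data.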

Diffusion Model: Generation Process
  1. Start with pure random noise \( X_T \sim \mathcal{N}(0, I) \).
  2. For each step \( t = T, T-1, \ldots, 1 \):
    • Use the neural network to predict the noise: \( \hat{\varepsilon} = \varepsilon_\theta(X_t, t) \).
    • Estimate \( X_0 \) from the predicted noise.
    • Compute the backward step mean \( \mu_q \) and variance \( \sigma_q^2 \).
    • Sample \( X_{t-1} \) from \( \mathcal{N}(\mu_q, \sigma_q^2 I) \).
  3. The final \( X_0 \) is the generated data.
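The generation loop can be sketched in Python. Since no trained network is available here, the sketch plugs in a hypothetical oracle, `make_oracle`: if every training example were the same constant \( c \), the noise behind \( X_t \) would be exactly \( (X_t - \sqrt{\bar{\alpha}_t}\,c)/\sqrt{1 - \bar{\alpha}_t} \), so the sampler should return \( c \). The function names and the linear schedule are illustrative assumptions; a real model would pass its trained \( \varepsilon_\theta \) as `eps_model`.

```python
import numpy as np


def sample(eps_model, shape, betas, rng):
    """Generate X_0 by iterating the backward step from X_T ~ N(0, I)."""
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)
    x = rng.standard_normal(shape)                     # 1. pure noise X_T
    for t in range(len(betas), 0, -1):                 # 2. t = T, T-1, ..., 1
        a_t, ab_t = alphas[t - 1], alpha_bars[t - 1]
        ab_prev = alpha_bars[t - 2] if t > 1 else 1.0  # alpha_bar_0 = 1
        eps_hat = eps_model(x, t)                      # predicted noise
        x0_hat = (x - np.sqrt(1 - ab_t) * eps_hat) / np.sqrt(ab_t)  # estimate of X_0
        mean = (np.sqrt(a_t) * (1 - ab_prev) / (1 - ab_t) * x
                + np.sqrt(ab_prev) * (1 - a_t) / (1 - ab_t) * x0_hat)
        var = (1 - a_t) * (1 - ab_prev) / (1 - ab_t)
        x = mean + np.sqrt(var) * rng.standard_normal(shape)  # X_{t-1} ~ N(mu_q, var I)
    return x                                           # 3. generated X_0


def make_oracle(c, betas):
    """Hypothetical stand-in for eps_theta: exact noise if all data equal the constant c."""
    alpha_bars = np.cumprod(1.0 - betas)

    def eps_model(x, t):
        ab_t = alpha_bars[t - 1]
        return (x - np.sqrt(ab_t) * c) / np.sqrt(1.0 - ab_t)

    return eps_model


betas = np.linspace(1e-4, 0.02, 1000)  # hypothetical linear schedule
x0 = sample(make_oracle(1.5, betas), shape=(4,), betas=betas, rng=np.random.default_rng(0))
print(x0)  # every entry should be (numerically) 1.5
```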