This blog is a highly detailed walkthrough of how Latent Diffusion Models, like Stable Diffusion, work. It’s catered towards folks who want to learn in complete detail how image generation works; we’ll be covering both the intuition and the math behind the models.
Note: The prerequisite for understanding this is knowing how transformers work. If you don’t know how transformers work, I would highly suggest reading up on them first here.
What is diffusion?
The first thing one has to understand, and the thing that will help you grasp diffusion models in the most intuitive way possible, is: wtf is diffusion in the first place?
Remember in 6th grade biology when you studied the concept of diffusion, where molecules in a solution just sort of disperse throughout?
It was particles diffusing in a medium. I want you to hold that picture in your mind as I explain how Stable Diffusion actually works.
Stable Diffusion: The Big Picture
Stable Diffusion is a type of Latent Diffusion Model (LDM). There are 2 parts to this: the latent part and the diffusion part (such machine learning, much wow). At a high level:
- The latent part refers to a Variational Autoencoder (VAE): A separate model whose sole job is to compress input images from our dataset into a compact, abstract latent representation for our diffusion model to work on (and to decode latents back into full images at the end of inference)
- The diffusion part refers to the Denoising Diffusion Probabilistic Model (DDPM): Quite the mouthful, but this is where the magic happens, where we convert noise to an image.
We’ll get into the specific workings of each model in the following sections.
The counter-intuitive part.
Intuitively, one would imagine that if some kind of "diffusion" were applied to an image, all the pixels would sort of ”meld together” and form pure noise.
But what if the process of going from noise to well formed image is STILL diffusion?!
And the math checks out because NEURAL NETWORKS LOL
From nothing to something
In this blog, we'll discuss how every little detail inside a latent diffusion model works - how we go from nothing to something.
For context, right now we are dealing with just the process of converting a pure noise image to an image that is not noise. The familiar text-to-image conditioning comes at a later stage.
The Anatomy of an LDM: VAEs and DDPMs
The anatomy of a Latent Diffusion Model looks like this:
The "Latent" Part: Variational Autoencoder (VAE)
We won't be focusing too much on the Variational Autoencoder part, but here's all you need to know.
The VAE is only responsible for taking a really large image, $x$, and converting it to a latent space representation. This converts it to an abstract representation of the image, denoted by $z$. The specific abstractions are learned.
We do this because the diffusion part (DDPM) is quite compute intensive, so working with smaller latent representations makes it much more computationally efficient.
In training, the VAE converts the dataset images into latent representations.
- The DDPM TRAINS on these latent representations of our dataset.
- Inference starts in the latent space with a noisy latent representation.
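To give a sense of the shapes involved, here's a tiny sketch - the 8x spatial downsampling and the 4-channel latent mirror what Stable Diffusion's VAE produces, but the function names are placeholders, not a real API:

```python
import torch

# Placeholder encoder/decoder: a real VAE learns these mappings; here we only show the shapes.
def vae_encode(image):                 # image: (B, 3, 512, 512)
    b = image.shape[0]
    return torch.zeros(b, 4, 64, 64)   # latent z: 8x smaller spatially, 4 abstract channels

def vae_decode(latent):                # latent: (B, 4, 64, 64)
    b = latent.shape[0]
    return torch.zeros(b, 3, 512, 512) # reconstructed image

x = torch.randn(1, 3, 512, 512)        # a "really large" image
z = vae_encode(x)                      # the compact representation the DDPM actually works on
print(x.numel() / z.numel())           # ~48x fewer values to diffuse over
```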
Example of VAE:
The "Diffusion" Part: Denoising Diffusion Probabilistic Model (DDPM)
Here's where the fun begins.
If you wanna make sure the process of destroying an image and the process of reconstructing it from noise remain mathematically consistent with each other, your noise needs to be Gaussian noise, i.e. the noise added to each pixel value is drawn from the standard normal distribution.
The reason behind keeping it Gaussian noise will become clearer once we understand the intuition of how things work.
Here's where we get into the maths of the DDPM. This is the big man who generates our images.
Understanding how the training process works makes understanding inference trivial.
DDPM Deep Dive
The training process consists of 2 phases: a forward phase (destructive) and a reverse phase (constructive). The forward phase is predetermined, whereas the reverse phase has the core neural network kicking in.
Forward Process
During the forward phase, the model works to progressively destroy latent representations (derived from its dataset images) by converting them to pure Gaussian noise. The below example shows the LATENT space representation of our original image getting destroyed (not the original image itself!)
The main thing to keep in mind is that we maintain a variance schedule $\beta_1, \dots, \beta_T$. This is a predefined list of size $T$ (usually $T = 1000$); it starts small, like $\beta_1 = 0.0001$, and eventually scales to $\beta_T = 0.02$.
This is how much variance we add at each step of the destruction process. Each latent representation $z_0$ goes through this destruction process over $T$ steps (1000 noising steps).
For each latent representation $z_0$ we do:

$$q(z_t \mid z_{t-1}) = \mathcal{N}\left(z_t;\ \sqrt{1-\beta_t}\,z_{t-1},\ \beta_t \mathbf{I}\right)$$

If you’re unfamiliar with the notation, this is read as: “The probability distribution of $z_t$, the intermediary noising stage, depends on $z_{t-1}$ and is given by a normal distribution whose mean is $\sqrt{1-\beta_t}\,z_{t-1}$ and whose variance is $\beta_t$ (the variance from the noise schedule we defined earlier; the $\mathbf{I}$ just means that the noise we add to each pixel has the same variance).”

i.e.:

$$z_t = \sqrt{1-\beta_t}\,z_{t-1} + \sqrt{\beta_t}\,\epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I})$$

To understand what this means: the latent representation at timestep $t$ is given by the latent representation at $t-1$ scaled down by a factor of $\sqrt{1-\beta_t}$, with added noise of variance $\beta_t$. Here, $q$ represents the probability distribution of the forward (noising) process.
This is like reducing the "amplitude" of the latent, making it duller, while the variance term adds a controlled amount of noise. This establishes a Markov chain (each element of the process - in our case a latent $z_t$ - depends only on the previous one, $z_{t-1}$), which is what later allows us to meaningfully apply a reverse process.
The scaling factor of $\sqrt{1-\beta_t}$ is taken since this is a multiplicative process; because we multiply by it at every step, the square root ensures that we don't diminish the latent too quickly.
Most critically,

$$\mathrm{Var}(z_t) = (1-\beta_t)\,\mathrm{Var}(z_{t-1}) + \beta_t$$

and if $\mathrm{Var}(z_{t-1}) = 1$, then the total variance at any given timestep = 1.
This is a super important part because this is what allows us to ensure that the process of generating an image from noise is the reverse of what we do in the forward process.
Had we not done this, we would get exploding (too large) or vanishing (too small) variances, and training a denoising model would be near impossible.
However, having to calculate all previous latents before step $t$ by sequentially applying the formula from $z_0$ onwards is unnecessary: defining $\alpha_t = 1 - \beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$, we can also write it directly as:

$$q(z_t \mid z_0) = \mathcal{N}\left(z_t;\ \sqrt{\bar{\alpha}_t}\,z_0,\ (1-\bar{\alpha}_t)\mathbf{I}\right)$$
Which is trivial to derive, since you're just recursively substituting the expression for $z_{t-1}$ into $z_t$, repeating until you reach $z_0$, and multiplying the scaling factors together.
So far we have found $z_t$ given data we already have ($z_0$ and the schedule) in a predictable way.
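To make this concrete, here's a small self-contained sketch (PyTorch; the variable names are my own, not from any particular codebase) that builds the schedule described above, runs the step-by-step noising, and checks that the closed-form shortcut gives statistically the same result:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # variance schedule beta_1 ... beta_T
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)    # cumulative product: alpha_bar_t

torch.manual_seed(0)
z0 = torch.randn(4096)                       # toy "latent" with unit variance

# Sequential forward process: z_t = sqrt(1 - beta_t) * z_{t-1} + sqrt(beta_t) * eps
z = z0.clone()
for t in range(T):
    z = torch.sqrt(1 - betas[t]) * z + torch.sqrt(betas[t]) * torch.randn_like(z)

# Closed-form shortcut: z_T = sqrt(alpha_bar_T) * z0 + sqrt(1 - alpha_bar_T) * eps
z_direct = torch.sqrt(alpha_bars[-1]) * z0 + torch.sqrt(1 - alpha_bars[-1]) * torch.randn_like(z0)

# Both end up as (approximately) unit-variance Gaussian noise -
# the variance-preserving property discussed above.
print(z.var().item(), z_direct.var().item())  # both roughly 1.0
```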
Reverse Process
This part gets math intensive.
If you wanted to model the reverse process, you would iteratively find what the latent looked like at $t-1$, given a particular $z_t$ and $z_0$ - i.e. the current noisy latent, and what it actually was.
In an ideal world, if you already knew what the latent $z_0$ originally was before it got destroyed by noise, and knew what it currently looks like with noise at timestep $t$ ($z_t$), your best educated guess for $z_{t-1}$ using $z_t$ and $t$, if you somehow knew $z_0$, would be:

$$q(z_{t-1} \mid z_t, z_0) = \mathcal{N}\left(z_{t-1};\ \tilde{\mu}_t(z_t, z_0),\ \tilde{\beta}_t \mathbf{I}\right)$$
Where $\tilde{\mu}_t(z_t, z_0)$ is a weighted average between $z_t$ and $z_0$. This term represents how far along we are in our variance schedule, and what mean we should take.
It is defined by:

$$\tilde{\mu}_t(z_t, z_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,z_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,z_t$$
Where the coefficient of $z_0$ defines what proportion of $z_0$ to take and the coefficient of $z_t$ defines what proportion of $z_t$ to take; adding them forms the weighted average, to which we add our variance.
And

$$\tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t$$
Where $1-\bar{\alpha}_{t-1}$ is the cumulated noise till $t-1$, and $1-\bar{\alpha}_t$ is the cumulated noise till $t$. This acts kinda like some "wiggle room".
I.e. when we are closer to $t = T$, the data is so noisy that we don’t know much about the underlying structure yet. When we sample $z_{t-1}$, we need to allow a lot of randomness (large $\tilde{\beta}_t$) to explore possible less-noisy states.
This large uncertainty is added as noise. At lower $t$, we’ve already removed a lot of noise, and our guess is more reliable. We don’t need as much randomness when sampling $z_{t-1}$, so $\tilde{\beta}_t$ is reduced.
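If it helps to see that weighted average spelled out, here's a minimal sketch of the posterior mean and variance (again PyTorch, with my own naming; a real implementation would precompute these coefficients):

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # same schedule as in the earlier sketch
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

def posterior_mean_variance(z_t, z_0, t):
    """Mean and variance of q(z_{t-1} | z_t, z_0); t is a 0-indexed timestep here."""
    a_bar_t = alpha_bars[t]
    a_bar_prev = alpha_bars[t - 1] if t > 0 else torch.tensor(1.0)
    beta_t, alpha_t = betas[t], alphas[t]

    coef_z0 = torch.sqrt(a_bar_prev) * beta_t / (1 - a_bar_t)         # weight on the clean latent
    coef_zt = torch.sqrt(alpha_t) * (1 - a_bar_prev) / (1 - a_bar_t)  # weight on the noisy latent
    mean = coef_z0 * z_0 + coef_zt * z_t                              # the weighted average
    variance = (1 - a_bar_prev) / (1 - a_bar_t) * beta_t              # the "wiggle room"
    return mean, variance

# Toy usage at a mid-schedule timestep.
m, v = posterior_mean_variance(torch.randn(4), torch.randn(4), t=500)
```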
But in practice we will never have access to $z_0$. It is not possible to model all $(z_t, z_0)$ pairs during inference and know every possible $z_0$.
So what to predict with our neural net?
This is where our machine learning kicks in. Since in practice we can never use the equation above to find the exact value, our goal is to somehow get a neural network to predict $z_0$. We cannot directly predict $z_0$, but we can apply a little trick.
We know

$$z_t = \sqrt{\bar{\alpha}_t}\,z_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$$

Rearranging, we get

$$z_0 = \frac{z_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon}{\sqrt{\bar{\alpha}_t}}$$
Hence we train our model to predict $\epsilon_\theta(z_t, t)$ - an approximation of the total CUMULATIVE NOISE $\epsilon$ added up till $t$.
$\epsilon$ is the true noise we added, and $\epsilon_\theta(z_t, t)$ is our neural network's prediction of that noise.
Plugging this PREDICTED $\epsilon_\theta$ into our rearranged equation for $z_0$, we get:

$$\hat{z}_0 = \frac{z_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(z_t, t)}{\sqrt{\bar{\alpha}_t}}$$
This is an estimate $\hat{z}_0$ of the original clean latent $z_0$, derived from the current noisy latent $z_t$ and the network's prediction of the noise at timestep $t$.
Recall our original ideal reverse process equation

$$\tilde{\mu}_t(z_t, z_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,z_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,z_t$$

We plug in our predicted $\hat{z}_0$ in place of $z_0$:

$$\mu_\theta(z_t, t) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,\hat{z}_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,z_t$$
This can be simplified, though it is a bit of an involved algebraic process. It’s not required to know, but if you’re interested, here’s a link to the full derivation.
This equation simplifies to:

$$\mu_\theta(z_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(z_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(z_t, t)\right)$$
Since $\epsilon_\theta$ is the cumulative noise from $z_0$ to $z_t$, the coefficient $\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}$ shrinks it down from the total noise till $t$ to only the noise that was added to $z_{t-1}$ to get to $z_t$. It’s like saying, “You predicted all the noise, but I only need the slice tied to step $t$.”
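Putting the last two formulas together, a single reverse (denoising) step might look like the sketch below, where `eps_pred` stands in for whatever the network outputs (names and structure are illustrative, not taken from a specific implementation):

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # same schedule as before
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

def reverse_step(z_t, eps_pred, t):
    """Sample z_{t-1} from z_t and the network's noise prediction; t is 0-indexed."""
    # Simplified posterior mean, using the predicted cumulative noise.
    mean = (z_t - betas[t] / torch.sqrt(1 - alpha_bars[t]) * eps_pred) / torch.sqrt(alphas[t])
    if t == 0:
        return mean                                   # final step: no extra noise added
    variance = (1 - alpha_bars[t - 1]) / (1 - alpha_bars[t]) * betas[t]
    return mean + torch.sqrt(variance) * torch.randn_like(z_t)

# Toy usage with a made-up prediction (a real one would come from the U-Net).
z_prev = reverse_step(torch.randn(4, 32, 32), torch.randn(4, 32, 32), t=999)
```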
And thus finally, since we know what our model has to predict,
We define the loss function as:

$$L = \mathbb{E}_{z_0,\, \epsilon,\, t}\left[\left\lVert \epsilon - \epsilon_\theta(z_t, t) \right\rVert^2\right]$$

→ Find the difference between the actual noise and the predicted noise. The greater the difference, the higher the loss.
Which begs the question:
How do we train a model to look at a latent representation $z_t$, and a given timestep $t$ (how far along the variance schedule we are), and make a prediction of how much total noise has been added so far?
Enter the U-Net
What is a U-net?
The U-Net is an architecture that was originally used in medical imaging for segmentation tasks. If you have the CT scan of a patient's brain and need a machine learning model to look at it and recognise some oddities - say, a big white patch indicating a tumor - you would use this U-Net architecture.
In our case, rather than trying to recognise tumors or other medical anomalies, we repurpose the existing architecture with minor modifications to recognise patches of noise instead. Thus the model can say "oh, this latent representation ($z_t$) looks particularly noisy given we are at timestep $t$; here's my best guess as to how much noise has been added so far".
And since during the training process we have the exact data on how much noise has been added so far, we are able to compute the loss. Repeating this process for each timestep, for each latent representation, eventually allows us to train the U-Net.
How does the U-Net work for Stable Diffusion?
- Input to the U-Net:
- A noisy latent image representation ($z_t$) at a particular timestep $t$: Our blurry, staticky version of what will eventually become a clear image.
- The timestep $t$ itself (usually encoded and added into the U-Net). This tells the U-Net "how noisy" the input should be expected to be.
- Training: It's trained to predict the noise ($\epsilon$) that was added to a clean $z_0$ to create a $z_t$.
- Output ($\epsilon_\theta(z_t, t)$): It's the U-Net's approximation of the true noise $\epsilon$.
How Does the U-Net Predict Noise?
This is where the U-Net architecture shines, even though its original use was for segmentation:
- Encoder (Downsampling Path):
- The input noisy latent $z_t$ goes through a series of convolutional layers, activation functions, and downsampling operations (here, max pooling).
- What it's doing: It's trying to learn features from the noisy input at different scales. At early layers, it might pick up on fine-grained noise patterns. As it goes deeper and the feature maps get smaller, it's forced to learn more abstract representations of how the underlying (but obscured) image structures are corrupted by noise. It's learning "what noise looks like when superimposed on various image features."
- Bottleneck:
- The most compressed representation. The model now has high-level features from the noisy image.
- Important note: Right now, we are still discussing unconditional image generation - i.e. simply sampling from the learned probability distribution. When dealing with conditional image generation, i.e. with text prompts, this is where we will inject our text prompt. We’ll explore this in depth soon.
- Decoder (Upsampling Path):
- This path takes the features from the bottleneck (and the skip connections) and progressively upsamples them, using convolutional layers and activation functions.
- Skip Connections are Key: These connections bring feature information directly from the encoder layers to their corresponding decoder layers - essentially by concatenating them and convolving over these concatenations. This is vital because:
- The encoder captured details about the noise patterns at different scales.
- The decoder needs this fine-grained information to accurately reconstruct a "noise map" ($\epsilon_\theta$) that has the correct spatial structure - the "locations of noise", so to speak.
- What it's doing: The decoder is essentially learning to "paint" the noise. Given the learned features (which represent "this is what a noisy edge looks like," or "this is what noisy flat texture looks like") and, in the conditional case, the guidance from the text prompt, it constructs the $\epsilon_\theta$ tensor. It's reconstructing the noise pattern, not the image itself.
- Training is How it Learns:
- During training, the U-Net is shown countless pairs of:
- A known noisy latent $z_t$ (created by taking a clean $z_0$ and adding a known amount of true Gaussian noise $\epsilon$).
- The actual Gaussian noise $\epsilon$ that was added.
- The U-Net predicts $\epsilon_\theta(z_t, t)$.
- A loss function (like Mean Squared Error) compares the U-Net's predicted noise $\epsilon_\theta$ with the true noise $\epsilon$.
- The U-Net's weights are adjusted through backpropagation to make its prediction closer to the true $\epsilon$.
- Over many examples, it learns the complex relationship between how a latent image looks at a certain noise level $t$, and what the corresponding noise that produced it must have been.
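Tying the training recipe together, here's a compact sketch of one training step (PyTorch; `unet` is a stand-in callable rather than a real U-Net, and all names are illustrative):

```python
import torch
import torch.nn.functional as F

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

# Stand-in for a real U-Net: any module mapping (z_t, t) -> predicted noise of the same shape.
unet = lambda z_t, t: torch.zeros_like(z_t)

def training_step(z0):
    """One DDPM training step on a batch of clean latents z0 of shape (B, C, H, W)."""
    b = z0.shape[0]
    t = torch.randint(0, T, (b,))                                 # random timestep per sample
    a_bar = alpha_bars[t].view(b, 1, 1, 1)

    eps = torch.randn_like(z0)                                    # the true noise we add
    z_t = torch.sqrt(a_bar) * z0 + torch.sqrt(1 - a_bar) * eps    # closed-form forward process

    eps_pred = unet(z_t, t)                                       # network predicts the noise
    return F.mse_loss(eps_pred, eps)                              # ||eps - eps_pred||^2

loss = training_step(torch.randn(2, 4, 32, 32))
```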
However, right now, this setup generates images without any control. It’s just random samples from the training data distribution. To make it text-conditioned, we need to guide the denoising process with text.
To do this, our approach of calculating $\epsilon_\theta$ using only $z_t$ and $t$ is not sufficient; we must also condition it on a text embedding $c$, which comes from our prompt.
Conditioned Output: How text is turned into images.
The embedding $c$ is generated by passing a text prompt through a pre-trained encoder (e.g. BERT, CLIP, etc.). This transforms the prompt into a tensor of shape $(B, L, d_{\text{text}})$ ($B$ = batch size, $L$ = token count, $d_{\text{text}}$ = embedding dimension).
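As a concrete (but optional) example, here's roughly how you'd get such an embedding with the Hugging Face transformers CLIP text encoder - the checkpoint name and the resulting dimensions are just one common choice, not something the method requires:

```python
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

prompt = ["A photo of an astronaut on the moon"]
tokens = tokenizer(prompt, padding="max_length", max_length=77,
                   truncation=True, return_tensors="pt")

with torch.no_grad():
    c = text_encoder(input_ids=tokens.input_ids).last_hidden_state

print(c.shape)  # (B, L, d_text) - e.g. (1, 77, 512) for this particular checkpoint
```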
This works by applying cross-attention within the U-Net's bottleneck. Recall the attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$$
Where $Q$ comes from the latent space input to the U-Net, and $K$ and $V$ come from the text embeddings. The U-Net features are flattened from $(B, C, H, W)$ to $(B, H \cdot W, C)$.
We do this mainly because attention over a flattened sequence of $H \cdot W$ tokens already has $O\big((H \cdot W)^2\big)$ time complexity. Attending over the unflattened 4D tensor would make the compute balloon even further, which is wildly compute intensive. Hence we flatten the U-Net features.
Queries (Q):
Input: U-Net features of shape $(B, H \cdot W, C)$.
Projection: Linear layer $W_Q$ of shape $(C, d)$.
Matrix Multiplication: $(B, H \cdot W, C) \times (C, d)$.
Output: $Q$ of shape $(B, H \cdot W, d)$.
Keys (K):
Input: Text embeddings of shape $(B, L, d_{\text{text}})$.
Projection: Linear layer $W_K$ of shape $(d_{\text{text}}, d)$.
Matrix Multiplication: $(B, L, d_{\text{text}}) \times (d_{\text{text}}, d)$.
Output: $K$ of shape $(B, L, d)$.
Values (V):
Input: Text embeddings of shape $(B, L, d_{\text{text}})$.
Projection: Linear layer $W_V$ of shape $(d_{\text{text}}, d)$.
Matrix Multiplication: $(B, L, d_{\text{text}}) \times (d_{\text{text}}, d)$.
Output: $V$ of shape $(B, L, d)$.
Scale and softmax: we compute the attention weights using the standard formula:

$$A = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d}}\right)$$

which has shape $(B, H \cdot W, L)$.
Multiply by V:

$$\text{Attention}(Q, K, V) = A V$$

Matrix Multiplication: $(B, H \cdot W, L) \times (B, L, d)$.
This weights the text embeddings based on the attention scores.
Output: $(B, H \cdot W, d)$, representing text-conditioned features for each spatial position.
Integrating the Output
The attention output is then:
- Reshaped back to $(B, C, H, W)$ after a final linear projection,
- Added to or concatenated with the original UNet features, depending on your specific implementation.
Then continue on with normal U-Net proceedings.
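Here's what that whole cross-attention block might look like as a minimal PyTorch sketch (single-head, no optimizations; the layer names, dimensions, and the residual-add integration are illustrative choices, not lifted from a particular Stable Diffusion implementation):

```python
import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    """Single-head cross-attention: spatial U-Net features attend to text embeddings."""
    def __init__(self, feature_dim, text_dim, attn_dim):
        super().__init__()
        self.to_q = nn.Linear(feature_dim, attn_dim)   # W_Q: (C, d)
        self.to_k = nn.Linear(text_dim, attn_dim)      # W_K: (d_text, d)
        self.to_v = nn.Linear(text_dim, attn_dim)      # W_V: (d_text, d)
        self.to_out = nn.Linear(attn_dim, feature_dim) # project back to C

    def forward(self, features, text_emb):
        # features: (B, C, H, W), text_emb: (B, L, d_text)
        b, c, h, w = features.shape
        x = features.flatten(2).transpose(1, 2)        # flatten to (B, H*W, C)

        q = self.to_q(x)                               # (B, H*W, d)
        k = self.to_k(text_emb)                        # (B, L, d)
        v = self.to_v(text_emb)                        # (B, L, d)

        attn = torch.softmax(q @ k.transpose(1, 2) / q.shape[-1] ** 0.5, dim=-1)  # (B, H*W, L)
        out = attn @ v                                 # (B, H*W, d)

        out = self.to_out(out)                         # (B, H*W, C)
        out = out.transpose(1, 2).reshape(b, c, h, w)  # back to (B, C, H, W)
        return features + out                          # residual add: one common integration choice

# Toy usage: 32x32 bottleneck features attending to a 77-token text embedding.
attn_block = CrossAttention(feature_dim=64, text_dim=512, attn_dim=64)
y = attn_block(torch.randn(1, 64, 32, 32), torch.randn(1, 77, 512))
```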
Example Denoising Process:
Suppose our prompt is: “A photo of an astronaut on the moon”
The process denoises:
The output of the DDPM ($z_0$) is:
Final VAE Decoder output after operating on latent output:
yay image generation :)
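For completeness, here's roughly what the end-to-end inference loop looks like, with `unet` and `vae_decode` as stand-in placeholders so the sketch runs on its own (a real pipeline would plug in the trained models):

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

# Placeholders: a real text-conditioned U-Net and a real VAE decoder go here.
unet = lambda z, t, c: torch.zeros_like(z)
vae_decode = lambda z: torch.zeros(1, 3, 256, 256)
c = torch.randn(1, 77, 512)                        # text embedding from the encoder

z = torch.randn(1, 4, 32, 32)                      # start from pure Gaussian noise in latent space
for t in reversed(range(T)):                       # t = T-1 ... 0 (0-indexed here)
    eps_pred = unet(z, t, c)                       # predict the cumulative noise
    mean = (z - betas[t] / torch.sqrt(1 - alpha_bars[t]) * eps_pred) / torch.sqrt(alphas[t])
    if t > 0:
        var = (1 - alpha_bars[t - 1]) / (1 - alpha_bars[t]) * betas[t]
        z = mean + torch.sqrt(var) * torch.randn_like(z)
    else:
        z = mean                                   # last step: no extra noise

image = vae_decode(z)                              # decode the clean latent into pixels
```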
Wrap-up and Review
And that's it! That about wraps up everything you need to know about how LDMs work! Check out these flowcharts if you wanna review your data flows.
Improvements
We’ve only trained the model to learn the average in this case, many improvements can be made:
- Predict variance: Current models typically fix variance or ignore it. Predicting both mean and variance could improve image quality. Refer to: Improved Denoising Diffusion Probabilistic Models
- Introduce deterministic behaviours: Since DDPMs are stochastic by nature, introducing determinism can significantly speed up sampling. Refer to: Denoising Diffusion Implicit Models
- Latent Consistency Models: Ditch the Markov chain; predict the image from noise in one shot. Refer to: Consistency Models. Author’s note: Naklecha has one of the most interesting blogs on this - I recommend checking out https://www.aaaaaaaaaa.org/lcm for this. (yes, this article is inspired by him :))
Final Words
Hopefully you enjoyed reading this as much as I did creating it. If you've made it here, it makes me very happy to know I was of help to you :)
If you're interested in more of this, or wanna support my work, follow me on Twitter!