Lesson 3 - Conditioning in SDEs
How to condition your SDE on specific constraints
Optional reading for this lesson
Slides
Video (soon)
1. Conditioning - Doob’s h-transform
We saw in the last lecture that we can use the reverse SDE to generate samples from a score-based model. If we train this score-based model well enough, these samples should be indistinguishable from samples drawn from the true data distribution. However, in practical applications such as image generation or molecule design, we do not want to generate just any sample from the data distribution; we want to generate samples that satisfy certain conditions. For example, we might want to generate an image of a cat, or a molecule that has a certain property. How can we do this?
1.1 Formalising conditioning
Let us again formalise this problem in the SDE framework. We have a forward SDE that describes the evolution of the data distribution to the noise distribution:

$$dX_t = f(X_t, t)\, dt + g(t)\, dW_t, \qquad X_0 \sim p_{\mathrm{data}}, \quad t \in [0, T].$$
We can then use the corresponding reverse SDE to transform samples from the noise distribution back into samples from the data distribution.
However, we want to generate samples from the data distribution that satisfy certain conditions. Generally, we want to satisfy so-called endpoint or hitting point conditions, meaning that we want the SDE to hit a certain point at its terminal time (for the denoising process, the end of generation). In other words, given a hard constraint in the form of a point $x_T$ or a constraint set $B$, we want to find a new SDE with $X_T = x_T$ or $X_T \in B$. When we do such a conditioning, we will see that this conditioned process is again an SDE, but with a different drift term. To indicate the difference, we will use the notation $f^h$ for the conditioned drift term in case we condition a forward SDE and $\bar{f}^h$ for the conditioned drift term in case we condition a reverse SDE. We will also call the random variable in the conditioned SDE $X^h_t$ instead of $X_t$.
Definition: Given a forward SDE $dX_t = f(X_t, t)\, dt + g(t)\, dW_t$ and a hard constraint $X_T \in B$, the conditioned forward SDE is given by

$$dX^h_t = \left[ f(X^h_t, t) + g(t)^2\, \nabla_x \log h(X^h_t, t) \right] dt + g(t)\, dW_t,$$

where

$$h(x, t) = \mathbb{P}\left( X_T \in B \mid X_t = x \right)$$

is called the h-transform.
This result is known as Doob’s h-transform and is relevant in generative modelling since different conditioning methods can be seen as different approximations to this quantity.
1.2 Example for pinned Brownian motion
Let us look at a simple example to understand this better in the context of generative modelling. We will look at the example of pinned Brownian motion, where we want to generate a Brownian motion that hits a certain point at time $T$, i.e. $X_T = x_T$. Let us call this point $a$. Our Brownian motion starts from an arbitrary distribution $p_0$ and evolves according to the SDE $dX_t = dW_t$. Applying Doob’s h-transform, we get

$$h(x, t) = p(X_T = a \mid X_t = x) = \mathcal{N}\!\left(a;\, x,\, T - t\right), \qquad \nabla_x \log h(x, t) = \frac{a - x}{T - t},$$

so the conditioned SDE reads

$$dX^h_t = \frac{a - X^h_t}{T - t}\, dt + dW_t,$$

which is exactly a Brownian bridge that ends in $a$ at time $T$.
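To see the effect of the conditioned drift numerically, here is a small sketch that simulates the pinned Brownian motion with Euler-Maruyama. The endpoint $a = 0$, the time horizon, the step size and the number of paths are arbitrary example values, not part of the lecture.

```python
import numpy as np

# Euler-Maruyama simulation of the conditioned SDE
#   dX_t = (a - X_t) / (T - t) dt + dW_t   (pinned Brownian motion / Brownian bridge)
# All numbers below (endpoint, horizon, step size) are arbitrary example values.
rng = np.random.default_rng(0)
T, n_steps, n_paths, a = 1.0, 1000, 5, 0.0
dt = T / n_steps

x = rng.normal(size=n_paths)          # arbitrary starting distribution p_0
for k in range(n_steps - 1):          # stop one step early: the drift blows up at t = T
    t = k * dt
    drift = (a - x) / (T - t)         # h-transform drift: nabla log h = (a - x) / (T - t)
    x = x + drift * dt + np.sqrt(dt) * rng.normal(size=n_paths)

print(x)  # all paths should end close to the pin a = 0
```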
2. Conditioning in score-based models
With this background, we can look at how conditioning is done in score-based models and how this fits in the framework of Doob’s h-transform.
Similar to how diffusion models in practice were established before people made the connection to SDEs, conditioning in score-based models was also used before the connection to Doob’s h-transform was made. However, the connection provides some cool insights into how conditioning works in score-based models and how we can improve it. For more background information and some applications, you can read the paper we recently published on this topic.
2.1 Conditioning at Inference Time
Let us consider the time reversed SDE and the general hard constraint $X_0 \in B$ (placed at time $0$, since we consider the endpoint of a reversed SDE running from $T$ to $0$). We can then use Doob’s h-transform to get the conditioned SDE:

$$dX_t = \left[ f(X_t, t) - g(t)^2\, \nabla_{x_t} \log p_t\!\left( X_t \mid X_0 \in B \right) \right] dt + g(t)\, d\bar{W}_t.$$
The h-transform can now be decomposed into two terms via Bayes’ rule (see this video, minute 44:00):

$$\nabla_{x_t} \log p_t\!\left( x_t \mid X_0 \in B \right) = \nabla_{x_t} \log p\!\left( X_0 \in B \mid x_t \right) + \nabla_{x_t} \log p_t(x_t).$$
Here, the first term (conditional score) ensures that the event is hit at the specified boundary time, while the second term (prior score) ensures it is still the time-reversal of the correct forward process.
The hard constraint is often also expressed as $y = \mathcal{A}(x_0)$, where $\mathcal{A}$ is a known measurement operator and $y$ an observation. In this case, the conditioned SDE can be written as

$$dX_t = \left[ f(X_t, t) - g(t)^2 \left( \nabla_{x_t} \log p\!\left( y \mid X_t \right) + \nabla_{x_t} \log p_t(X_t) \right) \right] dt + g(t)\, d\bar{W}_t.$$
It is just a notational difference, and depending on the task you consider one formulation may be more natural than the other. For example, in computer vision tasks such as image reconstruction, the measurement operator is often a blurring or other noising operator, so the second formulation is more natural. In other tasks, such as molecule design, the first formulation is more natural since our constraints are often based on substructures of the molecule (we will discuss this later in more detail).
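To make the measurement-operator formulation concrete, here is a small illustrative sketch of an inpainting-style constraint. The mask, the image size and the NumPy interface are made-up example choices, not something from the lecture.

```python
import numpy as np

# Hypothetical example: inpainting as a hard constraint y = A(x_0).
# A keeps only the observed pixels; mask and image size are arbitrary.
rng = np.random.default_rng(0)
x0 = rng.normal(size=(32, 32))        # a "clean" sample from the data distribution
mask = rng.random((32, 32)) < 0.25    # observe roughly 25% of the pixels

def A(x):
    """Measurement operator: return only the observed pixels."""
    return x[mask]

y = A(x0)                             # the observation we condition on
# The hard constraint is then: generate x_0 such that A(x_0) == y.
```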
So can we just sample from Doob’s h-transform to get samples that satisfy our constraints? Unfortunately, this is not possible in general. To see why, let us re-express the h-transform:

$$h(x_t, t) = p\!\left( X_0 \in B \mid x_t \right) = \int \mathbb{1}_B(x_0)\, p\!\left( x_0 \mid x_t \right) dx_0,$$

where $\mathbb{1}_B$ is the indicator function that is 1 if the constraint is satisfied and 0 otherwise. In the case of score-based models, $p(x_0 \mid x_t)$ is exactly the transition density of our reversed SDE as seen before. While we can sample from this distribution, we cannot easily evaluate its density at a given point, making the integral above hard to evaluate.
One way around this was proposed by Song et al. 2022, Finzi et al. 2023, Rozet & Louppe 2023 and others: use a Gaussian approximation to the transition density term:

$$p\!\left( x_0 \mid x_t \right) \approx \mathcal{N}\!\left( x_0;\, \hat{x}_0(x_t),\, \sigma_t^2 I \right).$$
Here, $\hat{x}_0(x_t)$ is the mean of the Gaussian approximation and is just the denoised sample $\mathbb{E}[x_0 \mid x_t]$, i.e. we compute it with the already trained score network. The variance $\sigma_t^2$ of the Gaussian approximation is a hyperparameter that has to be tuned. With this Gaussian in place of the true transition density, the integral above can be evaluated in closed form, which allows us to approximate the h-transform and sample from the conditioned SDE. However, this approximation is not always accurate and can lead to poor conditioning results.
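As a side note, one common way to obtain this denoised sample from the pre-trained score network is Tweedie's formula. The sketch below assumes a variance-exploding SDE with $p(x_t \mid x_0) = \mathcal{N}(x_0, \sigma_t^2 I)$ and a hypothetical `score_net(x_t, t)` interface; for other SDEs the formula changes accordingly.

```python
import torch

def x0_hat(score_net, x_t, t, sigma_t):
    """Denoised sample E[x_0 | x_t] via Tweedie's formula.

    Assumes a variance-exploding SDE with p(x_t | x_0) = N(x_0, sigma_t^2 I);
    `score_net(x_t, t)` is assumed to approximate the score grad log p_t(x_t).
    """
    return x_t + sigma_t ** 2 * score_net(x_t, t)
```

In other words, no extra network is needed: the already trained unconditional score gives us the mean of the Gaussian approximation for free.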
With this approximation plugged into our conditioned SDE, we get the following approximate conditioned SDE (written here for a constraint of the form $y = \mathcal{A}(x_0)$):

$$dX_t = \left[ f(X_t, t) - g(t)^2 \left( \nabla_{x_t} \log p_t(X_t) - \frac{1}{2\sigma_t^2}\, \nabla_{x_t} \left\| y - \mathcal{A}\!\left( \hat{x}_0(X_t) \right) \right\|^2 \right) \right] dt + g(t)\, d\bar{W}_t.$$
Approaches using this approximation are generally called reconstruction guidance. Why is this? Let us look at the algorithm that implements this approximation:
The steps that differ from the standard unconditional sampling algorithm are: compute the denoised sample $\hat{x}_0(x_t)$, form the reconstruction error $\| y - \mathcal{A}(\hat{x}_0(x_t)) \|^2$, and add its gradient (scaled by a guidance weight) to the score. Guiding the process with the gradient of a reconstruction error is what gives the method its name.
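The algorithm itself is easiest to read as code, so here is a minimal sketch of one way to implement reconstruction guidance with an Euler-Maruyama-style sampler for a variance-exploding SDE. The interface `score_net(x_t, sigma)`, the measurement operator `A`, the noise schedule `sigmas` and the guidance scale `lam` are illustrative assumptions, not the exact algorithm from the lecture slides.

```python
import torch

def reconstruction_guidance_sampler(score_net, A, y, sigmas, lam=1.0, shape=(1, 3, 32, 32)):
    """Reverse-SDE sampler with reconstruction guidance (VE-SDE assumption).

    `sigmas` is a decreasing list of noise levels (sigmas[0] largest), `A` is the
    measurement operator, `y` the observation and `lam` a tunable guidance scale.
    """
    x = sigmas[0] * torch.randn(shape)                 # start from the noise distribution
    for i in range(len(sigmas) - 1):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        dt = sigma ** 2 - sigma_next ** 2              # step size in "variance time"

        x = x.detach().requires_grad_(True)
        score = score_net(x, sigma)                    # unconditional (prior) score

        # --- these lines differ from unconditional sampling ---
        x0 = x + sigma ** 2 * score                    # denoised sample via Tweedie's formula
        err = ((y - A(x0)) ** 2).sum()                 # reconstruction error ||y - A(x0_hat)||^2
        guidance = -torch.autograd.grad(err, x)[0]     # approx. grad log p(y | x_t), up to a positive scale
        # -------------------------------------------------------

        with torch.no_grad():
            drift = (score + lam * guidance) * dt      # conditioned reverse-SDE drift
            x = x + drift + (dt ** 0.5) * torch.randn_like(x)
    return x.detach()
```

The guidance scale `lam` absorbs the factor $1/(2\sigma_t^2)$ from the Gaussian approximation and is exactly the hyperparameter whose tuning is discussed next.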
Conditioning at inference time has the big advantage that one can simply use a pre-trained unconditional model for sampling and does not need to train a separate model for conditioning. However, in some settings it has been shown that choosing the guidance scale, and thereby balancing sample quality against fulfilment of the constraint, is not trivial, which can lead to subpar results.
2.2 Conditioning at Training Time
Instead of enforcing the constraint at inference time, we can also learn the whole conditional score, including the h-transform, directly at training time. Since during training we amortise over all the constraints for each specific sample, this approach is often called amortised training. More specifically, our network approximating the conditional score is amortised over measurement operators $\mathcal{A}$ and observations $y$ instead of learning a separate network for each condition.
Proposition: Given the objective

$$\mathcal{L}(\theta) = \mathbb{E}_{t,\, x_0,\, y,\, x_t \sim p(x_t \mid x_0)} \left[ \left\| s_\theta(x_t, t, y) - \nabla_{x_t} \log p\!\left( x_t \mid x_0 \right) \right\|^2 \right],$$

the minimiser is given by the conditional score

$$s_{\theta^*}(x_t, t, y) = \nabla_{x_t} \log p_t\!\left( x_t \mid y \right).$$
We can prove this by invoking the mean squared error property of the conditional expectation that we discussed before in the context of the closed-form transition densities for score-based models. This property tells us that the minimiser of the objective is given by

$$s_{\theta^*}(x_t, t, y) = \mathbb{E}\!\left[ \nabla_{x_t} \log p\!\left( x_t \mid x_0 \right) \,\middle|\, x_t, y \right] = \nabla_{x_t} \log p_t\!\left( x_t \mid y \right).$$
So we can see that the minimiser of the objective is indeed the conditional score. This means that we can train the conditional score by minimising the objective above. Let’s look at a practical algorithm that does this:
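A minimal sketch of such a training step is given below, assuming a variance-exploding forward SDE and a hypothetical amortised network interface `score_net(x_t, sigma, y)`. The noise-level range, the loss weighting and the `sample_measurement` helper are illustrative choices, not the lecture's exact algorithm.

```python
import torch

def amortised_training_step(score_net, optimiser, x0_batch, sample_measurement):
    """One conditional denoising score matching step (VE-SDE assumption).

    `sample_measurement(x0)` returns an observation y = A(x0) for a randomly
    drawn measurement operator A; `score_net(x_t, sigma, y)` is the amortised
    conditional score network. Both interfaces are illustrative assumptions.
    """
    # Sample a noise level and perturb the clean data: x_t ~ N(x_0, sigma^2 I).
    sigma = torch.exp(torch.empty(x0_batch.shape[0]).uniform_(-4, 1))  # arbitrary range
    sigma = sigma.view(-1, *([1] * (x0_batch.dim() - 1)))
    noise = torch.randn_like(x0_batch)
    x_t = x0_batch + sigma * noise

    # Sample a condition for each clean example (amortisation over A and y).
    y = sample_measurement(x0_batch)

    # Conditional score matching target: grad log p(x_t | x_0) = -(x_t - x_0) / sigma^2.
    target = -noise / sigma
    loss = (((score_net(x_t, sigma, y) - target) ** 2) * sigma ** 2).mean()  # sigma^2 weighting

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    return loss.item()
```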
With these results in mind, we can look at some of the conditioning methods that have been proposed for score-based models before the connection to SDEs and Doob’s h-transform was made and see how they fit into the picture.
A common one that people refer to is classifier guidance, where the guidance term $\nabla_{x_t} \log p(y \mid x_t)$ is approximated by the gradient of a separately trained, noise-conditional classifier; this is an instance of conditioning at inference time.
Another one that often comes up is classifier-free guidance, where a conditional score network is trained with the condition randomly dropped out and the conditional and unconditional scores are combined at sampling time; this sits closer to the amortised training approach (see the sketch below).
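As a rough sketch of the sampling-time score combination used in classifier-free guidance (the network interface, the null condition `y_null` and the guidance weight `w` are illustrative assumptions):

```python
def cfg_score(score_net, x_t, t, y, y_null, w=1.0):
    """Classifier-free guidance: extrapolate between conditional and unconditional scores.

    `score_net(x_t, t, y)` is an amortised conditional score network trained with the
    condition randomly replaced by a null token `y_null` (condition dropout).
    w = 0 recovers the plain conditional score; larger w strengthens the guidance.
    """
    s_cond = score_net(x_t, t, y)          # conditional score
    s_uncond = score_net(x_t, t, y_null)   # unconditional score (null condition)
    return (1 + w) * s_cond - w * s_uncond
```

Here the guidance weight plays a role similar to the guidance scale in reconstruction guidance: it trades off constraint satisfaction against sample quality.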
2.3 Soft constraints
Credits
Much of the logic in the lecture is based on the Øksendal book as well as the Särkkä & Solin book. Title image from Wikipedia.