Lesson 3 - Conditioning in SDEs


How to condition your SDE on specific constraints

Optional reading for this lesson

Slides

Video (soon)

1. Conditioning - Doob’s h-transform

We saw in the last lecture that we can use the reverse SDE to generate samples from a score-based model. If we train this score-based model well enough, these samples should be indistinguishable from the true data distribution. However, in practical applications such as image generation or molecule design, we do not want to generate just any sample from the data distribution; we want to generate samples that satisfy certain conditions. For example, we might want to generate an image of a cat, or a molecule that has a certain property. How can we do this?

1.1 Formalising conditioning

Let us again formalise this problem in the SDE framework. We have a forward SDE that describes the evolution of the data distribution to the noise distribution:

$$dX_t = f^+(X_t, t)\,dt + \sigma(X_t, t)\,dW_t$$

We can then use the reverse SDE to generate samples, starting from the noise distribution:

$$\begin{align} dX_t &= f^-(X_t, t)\,dt + \sigma(X_t, t)\,d\bar{W}_t \\ &= [-f^+(X_t, t) + \sigma^2(X_t, t)\,\nabla_x \log p_t(X_t)]\,dt + \sigma(X_t, t)\,d\bar{W}_t \end{align}$$
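As a quick sanity check, both SDEs can be simulated with the Euler–Maruyama scheme. The sketch below is a minimal illustration, assuming $f^+ = 0$ (a purely Brownian forward process) and a constant $\sigma$; it noises a point mass at the origin into the Gaussian $\mathcal{N}(0, \sigma^2 T)$.

```python
import numpy as np

rng = np.random.default_rng(0)

def euler_maruyama(x0, drift, sigma, T=1.0, n_steps=500):
    """Simulate dX_t = drift(X_t, t) dt + sigma dW_t with Euler-Maruyama."""
    dt = T / n_steps
    x = np.array(x0, dtype=float)
    for i in range(n_steps):
        t = i * dt
        x = x + drift(x, t) * dt + sigma * np.sqrt(dt) * rng.normal(size=x.shape)
    return x

# Forward noising process dX_t = sigma dW_t (f^+ = 0), started from a point mass.
sigma, T = 1.0, 1.0
x0 = np.zeros(10_000)
xT = euler_maruyama(x0, drift=lambda x, t: np.zeros_like(x), sigma=sigma, T=T)

# After time T the marginal is approximately N(0, sigma^2 * T).
print(np.mean(xT), np.var(xT))
```

Plugging a learned score into the drift of the reverse SDE turns the same integrator into a sampler.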

However, we want to generate samples from the data distribution that satisfy certain conditions. Generally, we want to satisfy so-called endpoint or hitting point conditions, meaning that we want the SDE to hit a certain point $x_T$ at time $T$ (the end of the denoising process). In other words, given a hard constraint in the form of a point $x_T = z$ or a constraint set $x_T \in B$, we want to find a new SDE with $p(x_T = z) = 1$ or $p(x_T \in B) = 1$. When we do such a conditioning, we will see that the conditioned process is again an SDE, but with a different drift term. To indicate the difference, we will use the notation $f^+_{cond}$ for the conditioned drift term in case we condition a forward SDE and $f^-_{cond}$ in case we condition a reverse SDE. We will also call the random variable in the conditioned SDE $Z_t$ instead of $X_t$.

Definition: Given a forward SDE $dX_t = f^+(X_t, t)\,dt + \sigma\,dW_t$ and a hard constraint $x_T = z$, the conditioned forward SDE is given by

$$\begin{align} dZ_t &= f^+_{cond}(Z_t, t)\,dt + \sigma\,dW_t \\ &= [f^+(Z_t, t) + \sigma^2 \nabla_{Z_t} \log p^h_{T \mid t}(z \mid Z_t)]\,dt + \sigma\,dW_t \end{align}$$

where $Z_T \sim \delta_z$ and $p^h_{T \mid t}(z \mid Z_t) = p_{T \mid t}(x_T = z \mid Z_t)$ is called the h-transform.

This result is known as Doob's h-transform and is relevant in generative modelling since different conditioning methods can be seen as different approximations to this quantity.

1.2 Example for pinned Brownian motion

Let us look at a simple example to understand this better in the context of generative modelling. We will look at pinned Brownian motion, where we want to generate a Brownian motion that hits a certain point $z$ at time $T$, i.e. $Z_T \sim \delta_z$. Let us fix this point to be $z = 0$. Our Brownian motion starts from an arbitrary distribution $X_0 \sim p_{data}(x)$ and evolves according to the SDE $dX_t = \sigma\,dW_t$. Since the transition density of Brownian motion is Gaussian, the h-transform is $p^h_{T \mid t}(z \mid Z_t) = \mathcal{N}(z \mid Z_t, \sigma^2 (T - t))$, and hence $\sigma^2 \nabla_{Z_t} \log p^h_{T \mid t}(z \mid Z_t) = \frac{z - Z_t}{T - t}$. Applying Doob's h-transform with $z = 0$, we get

$$\begin{align} dZ_t &= f^+_{cond}(Z_t, t)\,dt + \sigma\,dW_t \\ &= -\frac{Z_t}{T-t}\,dt + \sigma\,dW_t \end{align}$$
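The pinned process is easy to simulate: the drift $-Z_t/(T-t)$ pulls every path towards $0$ as $t \to T$. A minimal Euler–Maruyama sketch (step size and start distribution chosen for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# Pinned Brownian motion: dZ_t = -Z_t / (T - t) dt + sigma dW_t, pinned to z = 0.
sigma, T, n_steps = 1.0, 1.0, 1000
dt = T / n_steps

z = rng.normal(loc=2.0, scale=1.0, size=10_000)  # arbitrary start distribution
for i in range(n_steps):
    t = i * dt
    z = z - z / (T - t) * dt + sigma * np.sqrt(dt) * rng.normal(size=z.shape)

# All paths should end (numerically) at the pinned point z = 0.
print(np.max(np.abs(z)))
```

Note that the drift blows up as $t \to T$; on a fixed grid the last Euler step maps $Z$ to $\sigma\sqrt{dt}\,\xi$, which is why the discretisation still pins the paths.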

2. Conditioning in score-based models

With this background, we can look at how conditioning is done in score-based models and how this fits in the framework of Doob’s h-transform.

Similar to how diffusion models in practice were established before people made the connection to SDEs, conditioning in score-based models was also used before the connection to Doob's h-transform was made. However, the connection provides some cool insights into how conditioning works in score-based models and how we can improve it. For more background information and some applications, you can read the paper we recently published on this topic.

2.1 Conditioning at Inference Time

Let us consider the time-reversed SDE $dX_t = f^-(X_t)\,dt + \sigma_t\,d\tilde{W}_t$ and the general hard constraint $X_0 \in B$ ($X_0$ since we consider the endpoint of a reversed SDE running from $T$ to $0$). We can then use Doob's h-transform to get the conditioned SDE:

$$\begin{align} dZ_t &= f^-_{cond}(Z_t)\,dt + \sigma_t\,d\tilde{W}_t \\ &= [f^-(Z_t) + \sigma_t^2 \nabla_{Z_t} \log p^h_{0 \mid t}(X_0 \in B \mid Z_t)]\,dt + \sigma_t\,d\tilde{W}_t \end{align}$$

The h-transform can now be decomposed into two terms via Bayes' rule (see this video, minute 44:00):

$$\nabla_{Z_t} \ln p^h_{0 \mid t}(X_0 \in B \mid Z_t) = \nabla_{Z_t} \ln p_{t \mid 0}(Z_t \mid X_0 \in B) - \nabla_{Z_t} \ln p_{t}(Z_t)$$

Here, the first term (conditional score) ensures that the event is hit at the specified boundary time, while the second term (prior score) ensures it is still the time-reversal of the correct forward process.
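Since the decomposition is just Bayes' rule, it can be checked numerically. The sketch below is an illustrative 1-D setup (all choices are assumptions for the example): $X_0 \sim \mathcal{N}(0, s_0^2)$, a Brownian kernel $p_{t \mid 0}(x_t \mid x_0) = \mathcal{N}(x_t; x_0, \sigma^2 t)$ and the constraint $B = [a, b]$. It compares $\nabla \ln p^h$ computed from the closed-form Gaussian posterior against conditional score minus prior score, both via finite differences:

```python
import math
import numpy as np

# Illustrative 1-D setting: X_0 ~ N(0, s0^2), p_{t|0}(x_t|x_0) = N(x_t; x_0, sigma^2 t).
s0, sigma, t = 1.0, 1.0, 0.5
a, b = 0.5, 2.0  # constraint set B = [a, b]

def Phi(u):  # standard normal CDF
    return 0.5 * (1.0 + math.erf(u / math.sqrt(2.0)))

def log_h(x):
    # h(x) = P(X_0 in B | X_t = x), using the conjugate Gaussian posterior.
    v = s0**2 * sigma**2 * t / (s0**2 + sigma**2 * t)   # posterior variance
    mu = x * s0**2 / (s0**2 + sigma**2 * t)             # posterior mean
    return math.log(Phi((b - mu) / math.sqrt(v)) - Phi((a - mu) / math.sqrt(v)))

x0s = np.linspace(-10, 10, 20001)
p0 = np.exp(-x0s**2 / (2 * s0**2)) / (s0 * math.sqrt(2 * math.pi))
mask = (x0s >= a) & (x0s <= b)

def log_pt(x, conditioned):
    # log p_t(x) or log p_t(x | X_0 in B) by quadrature over x_0.
    k = np.exp(-(x - x0s) ** 2 / (2 * sigma**2 * t)) / math.sqrt(2 * math.pi * sigma**2 * t)
    w = p0 * mask if conditioned else p0
    return math.log(float(np.sum(k * w)) / float(np.sum(w)))

def ddx(f, x, eps=1e-4):  # central finite difference
    return (f(x + eps) - f(x - eps)) / (2 * eps)

x = 1.3
lhs = ddx(log_h, x)                                             # grad log h-transform
rhs = ddx(lambda u: log_pt(u, True), x) - ddx(lambda u: log_pt(u, False), x)
print(lhs, rhs)  # the two sides should agree
```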

The hard constraint $X_0 \in B$ is often also expressed as $\mathcal{A}(X_0) = y$, where $\mathcal{A}$ is a known measurement operator and $y$ an observation. In this case, the conditioned SDE can be written as

$$\begin{align} dZ_t &= f^-_{cond}(Z_t)\,dt + \sigma_t\,d\tilde{W}_t \\ &= [f^-(Z_t) + \sigma_t^2 \nabla_{Z_t} \log p^h_{0 \mid t}(\mathcal{A}(X_0) = y \mid Z_t)]\,dt + \sigma_t\,d\tilde{W}_t \end{align}$$

It is just a notational difference, and depending on the task you consider one formulation may be more natural than the other. For example, in computer vision tasks such as image reconstruction, the measurement operator is often a blurring or other noising operator, so the second formulation is more natural. In other tasks, such as molecule design, the first formulation is more natural since our constraints are often based on substructures of the molecule (we will discuss this later in more detail).

So can we just sample from Doob's h-transform to get samples that satisfy our constraints? Unfortunately, this is not possible. To see why, let us re-express the h-transform:

$$p^h_{0 \mid t}(\mathcal{A}(X_0) = y \mid Z_t) = \int \mathbb{1}_{\mathcal{A}(x_0) = y}\, p_{0 \mid t}(x_0 \mid Z_t)\, dx_0$$

where $\mathbb{1}_{\mathcal{A}(x_0) = y}$ is the indicator function that is 1 if the constraint is satisfied and 0 otherwise. In the case of score-based models, $p_{0 \mid t}(x_0 \mid Z_t)$ is exactly the transition density of our reversed SDE as seen before. While we can sample from this distribution, we cannot easily evaluate its density at a given point, making the integral above hard to compute.

One way around this was proposed by Song 2022, Finzi 2023, Rozet & Louppe 2023 and others: use a Gaussian approximation to the transition density term:

$$p_{0 \mid t}(x_0 \mid Z_t) \approx \mathcal{N}(x_0 \mid \mathbb{E}[X_0 \mid X_t = Z_t], \gamma_t)$$

Here, $\mathbb{E}[X_0 \mid X_t = Z_t]$ is the mean of the Gaussian approximation and is just the denoised sample, i.e. we can reuse the already trained score network to compute it. The variance $\gamma_t$ is a hyperparameter that has to be tuned. This allows us to evaluate the integral above and sample from (an approximation of) the h-transform. However, this approximation is not always accurate and can lead to poor conditioning results.

With this approximation plugged into our conditioned SDE, we get the following approximate conditioned SDE:

$$\begin{align} dZ_t &= f^-_{cond}(Z_t)\,dt + \sigma_t\,d\tilde{W}_t \\ &= \Big[f^-(Z_t) - \sigma_t^2 \nabla_{Z_t} \big\lVert y - \mathcal{A}\,\mathbb{E}[X_0 \mid X_t = Z_t] \big\rVert^2_{\gamma_t}\Big]\,dt + \sigma_t\,d\tilde{W}_t \end{align}$$

Approaches using this approximation are generally called reconstruction guidance. Why is this? Let us look at the algorithm that implements this approximation:

reconstruction_guidance

The lines in grey are the ones that differ from the standard unconditional sampling algorithm.
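To make the algorithm concrete, here is a hedged 1-D sketch of reconstruction guidance where everything is Gaussian, so the denoiser $\mathbb{E}[X_0 \mid X_t]$ and the score are available in closed form and stand in for a trained network. All names, the identity measurement $\mathcal{A} = I$, and the choice of $\gamma_t$ as the posterior variance are assumptions for this toy example, not the lecture's prescription.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup: p_data = N(0, s0^2), forward SDE dX_t = sigma dW_t, measurement
# A = identity with observation y. Everything is Gaussian, so E[X_0 | X_t] and
# the score are exact stand-ins for a trained network.
s0, sigma, T, n_steps, y = 1.0, 1.0, 1.0, 1000, 1.5
dt = T / n_steps

x = rng.normal(scale=np.sqrt(s0**2 + sigma**2 * T), size=5000)  # start from p_T
for i in range(n_steps):
    tau = T - i * dt                               # current forward time
    var_t = s0**2 + sigma**2 * tau
    score = -x / var_t                             # grad log p_t(x), exact here
    x0_hat = x * s0**2 / var_t                     # E[X_0 | X_t = x]
    gamma = s0**2 * sigma**2 * tau / var_t         # posterior variance as gamma_t
    guide = (y - x0_hat) / gamma * (s0**2 / var_t) # grad log N(y; x0_hat, gamma)
    x = x + sigma**2 * (score + guide) * dt + sigma * np.sqrt(dt) * rng.normal(size=x.shape)

# Guided samples should concentrate on the observation y.
print(np.mean(x), np.std(x))
```

In this linear-Gaussian toy the Gaussian approximation is exact, so the guided samples concentrate on $y$; with a real network, $\gamma_t$ has to be tuned as discussed above.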

Conditioning at inference time has the big advantage that one can just use a pre-trained unconditional model for sampling and does not need to train a separate model for conditioning. However, in some settings it has been shown that choosing the guidance scale, and thereby balancing between sample quality and fulfilment of the constraint, is not trivial, leading to subpar results.

2.2 Conditioning at Training Time

Instead of enforcing the constraint during inference time, we can also learn the whole conditional score including the h-transform directly at training time. Since during training we amortise over all the constraints for each specific sample, this approach is often called amortised training. More specifically, our network approximating the conditional score is amortised over measurement operator A\mathcal{A} and observations yy instead of learning a separate network for each condition.

Proposition: Given the objective

$$f^* = \argmin_{f}\; \mathbb{E}_{Y \sim p(y \mid \mathcal{A}, X_0),\; \mathcal{A} \sim p,\; X_0 \sim p_{data}}\Big[ \int_0^T \big\lVert f(t, X_t, Y, \mathcal{A}) - \nabla_{X_t} \ln p_{t \mid 0}(X_t \mid X_0) \big\rVert^2 dt \Big]$$

the minimiser is given by the conditional score

$$f^*(t, X_t, y, A) = \nabla_{X_t} \ln p_{t}(X_t \mid Y = y, \mathcal{A} = A)$$

We can prove this by invoking the mean squared error property of the conditional expectation we discussed before in the context of the closed-form transition densities for score-based models. This property tells us that the minimiser of the objective is given by (using that $X_t$ is conditionally independent of $(Y, \mathcal{A})$ given $X_0$):

$$\begin{align} f^*(x_t, y, A) &= \mathbb{E}_{X_0 \mid X_t = x_t, Y = y, \mathcal{A} = A}\big[\nabla_{x_t} \ln p_{t \mid 0}(x_t \mid X_0)\big] &&\mid \text{expectation def.} \\ &= \int \nabla_{x_t} \ln p_{t \mid 0}(x_t \mid x_0)\, p_{0 \mid t}(x_0 \mid x_t, y, A)\, dx_0 &&\mid \text{Bayes} \\ &= \int \nabla_{x_t} \ln p_{t \mid 0}(x_t \mid x_0)\, \frac{p_{t \mid 0}(x_t \mid x_0)\, p_{0}(x_0 \mid y, A)}{p_{t}(x_t \mid y, A)}\, dx_0 &&\mid \text{log deriv. trick} \\ &= \int \frac{\nabla_{x_t} p_{t \mid 0}(x_t \mid x_0)}{p_{t \mid 0}(x_t \mid x_0)}\, \frac{p_{t \mid 0}(x_t \mid x_0)\, p_{0}(x_0 \mid y, A)}{p_{t}(x_t \mid y, A)}\, dx_0 &&\mid \text{cancel terms, pull out denominator} \\ &= \frac{1}{p_{t}(x_t \mid y, A)} \int \nabla_{x_t} p_{t \mid 0}(x_t \mid x_0)\, p_{0}(x_0 \mid y, A)\, dx_0 &&\mid \text{dom. conv. theorem} \\ &= \frac{1}{p_{t}(x_t \mid y, A)}\, \nabla_{x_t} \int p_{t \mid 0}(x_t \mid x_0)\, p_{0}(x_0 \mid y, A)\, dx_0 &&\mid \text{marginalise over } x_0 \\ &= \frac{1}{p_{t}(x_t \mid y, A)}\, \nabla_{x_t}\, p_{t}(x_t \mid y, A) &&\mid \text{log deriv. trick} \\ &= \nabla_{x_t} \ln p_{t}(x_t \mid y, A) \end{align}$$

So we can see that the minimiser of the objective is indeed the conditional score. This means that we can train the conditional score by minimising the objective above. Let’s look at a practical algorithm that does this:

amortised_training
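A minimal numerical sketch of the amortised objective (an illustrative setup, not from the lecture): with $X_0 \sim \mathcal{N}(0,1)$, a noisy identity measurement $y = X_0 + \varepsilon$ and a Brownian kernel at a single fixed $t$, the conditional score is linear in $(x_t, y)$, so a linear model can play the role of the amortised network.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative amortised training: X_0 ~ N(0, 1), observation y = X_0 + N(0, 1),
# kernel p_{t|0}(x_t | x_0) = N(x_t; x_0, sigma^2 t) at a fixed t. The regression
# target is the transition score grad log p_{t|0}(x_t | x_0) = -(x_t - x_0) / (sigma^2 t).
sigma2_t = 0.5
w = np.zeros(3)       # weights of a linear "network" with features [x_t, y, 1]
lr = 0.05

for step in range(5000):
    x0 = rng.normal(size=256)
    y = x0 + rng.normal(size=256)                  # amortise over observations y
    xt = x0 + np.sqrt(sigma2_t) * rng.normal(size=256)
    target = -(xt - x0) / sigma2_t                 # grad log p_{t|0}(x_t | x_0)
    feats = np.stack([xt, y, np.ones_like(xt)])    # shape (3, batch)
    pred = w @ feats
    grad = 2 * (pred - target) @ feats.T / 256     # MSE gradient
    w -= lr * grad

# Analytically, x_t | y ~ N(y/2, 1) here, so the conditional score is
# -(x_t - y/2), i.e. the weights should approach (-1, 0.5, 0).
print(w)
```

The minimiser of the regression against the *transition* score is indeed the *conditional* score, exactly as the proposition states.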

With these results in mind, we can look at some of the conditioning methods that have been proposed for score-based models before the connection to SDEs and Doob’s h-transform was made and see how they fit into the picture.

A common one that people refer to is classifier guidance.

Another one that often comes up is classifier-free guidance.

classifier_free_guidance
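The combination rule behind classifier-free guidance fits in a few lines: the guided score interpolates (and, for guidance scale $g > 1$, extrapolates) between the unconditional and conditional scores. A small sketch with Gaussian toy scores:

```python
import numpy as np

# Classifier-free guidance combines an unconditional and a conditional score:
#   s_guided = s_uncond + g * (s_cond - s_uncond),
# where g = 1 recovers the conditional score and g > 1 over-emphasises the condition.
def cfg_score(s_uncond, s_cond, g):
    return s_uncond + g * (s_cond - s_uncond)

# Toy check with Gaussian marginals p_t = N(0, 1) and p_t(. | y) = N(1, 1):
# the guided score is then exactly the score of N(g, 1).
x = np.linspace(-3, 3, 7)
s_uncond = -x            # grad log N(x; 0, 1)
s_cond = -(x - 1.0)      # grad log N(x; 1, 1)
g = 2.0
print(cfg_score(s_uncond, s_cond, g))   # equals -(x - g)
```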

2.3 Soft constraints

Credits

Much of the logic in the lecture is based on the Øksendal book as well as the Särkkä & Solin book. Title image from Wikipedia.