Lesson 2 - SDEs and how to manipulate them

How to reverse and condition SDEs

Optional reading for this lesson

Slides

Video (soon)

In this lecture we continue our discussion of SDEs and how to manipulate them. Specifically, we will look at how we can describe the evolution of the probability density of an SDE via the Fokker-Planck Equation, how we can reverse SDEs via Nelson's duality formula, and how we can use these ideas to build generative models such as score-based or flow matching models.

1. The Fokker-Planck Equation

In the first lecture we saw that while the solution of an ODE is just a deterministic function, the solution of an SDE is a stochastic process. We can describe this process through its individual sample paths, but as a stochastic process it also has a probability distribution and associated statistics. In this first part, we will therefore introduce the Fokker-Planck Equation (FPE), which describes the evolution of the probability density of an SDE, and then derive the FPE for some common SDEs.

The FPE has many names depending on the field you work in. In physics it is often called the Smoluchowski equation, in probability theory it is called the Kolmogorov forward equation and in statistics it is called the Fokker-Planck-Kolmogorov equation. In this lecture we will use the term Fokker-Planck Equation (FPE).

Definition: The Fokker-Planck Equation (FPE) describes the evolution of the probability density of an SDE. For a general SDE of the form dX_t = \mu(X_t, t)dt + \sigma(X_t, t)dW_t, it is given by

\frac{\partial}{\partial t} p(x,t) = - \frac{\partial}{\partial x}[\mu(x,t)p(x,t)] + \frac{\partial^2}{\partial x^2}[D(x,t)p(x,t)]

where p(x,t) is the probability density of the SDE at time t and position x, and D(x,t) = \frac{\sigma^2(x,t)}{2} is defined as the diffusion coefficient.

To emphasise that there is a separate probability density for each time step, we will sometimes write the probability density p(x,t) as p_t(x) instead, where x is the value of the SDE at time t. The FPE describes how this probability density evolves over time.

We will first derive the FPE for the special case of Brownian motion, then sketch how the same argument gives the general formula, and finally look at the FPE for some common SDEs.

For completeness, in d dimensions the FPE for a general SDE of the form dX_t = \mu(X_t, t)dt + \sigma(X_t, t)dW_t is given by \partial_t p(x, t) = - \sum_{i=1}^d \partial_{x_i}[\mu_i(x,t)p(x,t)] + \frac{1}{2}\sum_{i,j=1}^d \partial_{x_i}\partial_{x_j}[(\sigma \sigma^T)_{ij}(x,t)p(x,t)]

1.1 FPE for Brownian Motion

Let’s start with the special case of Brownian motion, a stochastic process we discussed in the first lecture. For consistency of notation, let us denote the stochastic process as X_t, i.e. we set X_t = B_t where B_t is the Brownian motion.

Our SDE for Brownian motion is therefore (surprise, surprise) given by dX_t = dB_t.

Let us look at a twice continuously differentiable, arbitrary function f(X_t) with compact support and apply Ito’s lemma to it. We get

\begin{align} df(X_t) &= \partial_t f(X_t) dt + \partial_x f(X_t) dX_t + \frac{1}{2} \partial_{xx} f(X_t) (dX_t)^2 \hspace{10px} \mid \partial_t f = 0, \ dX_t = dB_t, \ (dB_t)^2 = dt \\ &= \partial_x f(X_t) dB_t + \frac{1}{2} \partial_{xx} f(X_t) dt \\ \mathbb{E}[df(X_t)] &= \mathbb{E}[\partial_x f(X_t) dB_t + \frac{1}{2} \partial_{xx} f(X_t) dt] \hspace{10px} \mid \mathbb{E}[dB_t] = 0\\ &= \frac{1}{2} \mathbb{E}[\partial_{xx} f(X_t)] dt \hspace{10px} \mid \text{use dom. convergence theorem}\\ \frac{d}{dt} \mathbb{E}[f(X_t)] &= \frac{1}{2} \mathbb{E}[\partial_{xx} f(X_t)] \hspace{10px} \mid \text{rewrite } \mathbb{E} \text{ and write } \partial_{xx} f(x) = f_{xx}(x)\\ \frac{d}{dt} \int_{-\infty}^{\infty} f(x) p(x,t)\, dx &= \frac{1}{2} \int_{-\infty}^{\infty} f_{xx}(x) p(x,t)\, dx \hspace{10px} \mid \text{integrate RHS by parts}\\ &= \frac{1}{2} \Big[f_{x}(x)p(x,t)\Big|_{x=-\infty}^{x=\infty} - \int_{-\infty}^{\infty} f_x(x)\frac{\partial p(x,t)}{\partial x}dx\Big] \hspace{10px} \mid \text{compact support: first term = 0}\\ &= - \frac{1}{2} \int_{-\infty}^{\infty} f_x(x)\frac{\partial p(x,t)}{\partial x}dx \hspace{10px} \mid \text{integrate RHS by parts again}\\ &= - \frac{1}{2} \Big[f(x)\frac{\partial p(x,t)}{\partial x}\Big|_{x=-\infty}^{x=\infty} - \int_{-\infty}^{\infty} f(x)\frac{\partial^2 p(x,t)}{\partial x^2}dx\Big] \hspace{10px} \mid \text{compact support: first term = 0}\\ &= \frac{1}{2} \int_{-\infty}^{\infty} f(x)\frac{\partial^2 p(x,t)}{\partial x^2}dx \hspace{10px} \mid \text{pull derivative inside integral on LHS}\\ \int_{-\infty}^{\infty} f(x) \frac{\partial p(x,t)}{\partial t}\, dx &= \frac{1}{2} \int_{-\infty}^{\infty} f(x)\frac{\partial^2 p(x,t)}{\partial x^2}dx \hspace{10px} \mid \text{regroup terms}\\ \int_{-\infty}^{\infty} f(x) \Big[\frac{\partial p(x,t)}{\partial t} - \frac{1}{2} \frac{\partial^2 p(x,t)}{\partial x^2}\Big] dx &= 0 \hspace{10px} \mid f \text{ is arbitrary}\\ \frac{\partial p(x,t)}{\partial t} - \frac{1}{2} \frac{\partial^2 p(x,t)}{\partial x^2} &= 0 \end{align}

This is the FPE for Brownian motion and is known as the diffusion equation. It is exactly the heat equation with diffusion coefficient \frac{1}{2}, and it can also be obtained from the general FPE formula by setting \mu(x,t) = 0 and \sigma(x,t) = 1.
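As a quick sanity check, we can verify numerically that the Brownian-motion density p_t(x) = N(0, t) (for X_0 = 0) satisfies this diffusion equation. The sketch below uses simple finite differences; the grid and step size h are arbitrary choices:

```python
import numpy as np

def p(x, t):
    # density of standard Brownian motion started at 0: N(0, t)
    return np.exp(-x**2 / (2 * t)) / np.sqrt(2 * np.pi * t)

def fpe_residual(x, t, h=1e-3):
    # finite-difference estimate of d/dt p - 0.5 * d^2/dx^2 p
    dp_dt = (p(x, t + h) - p(x, t - h)) / (2 * h)
    d2p_dx2 = (p(x + h, t) - 2 * p(x, t) + p(x - h, t)) / h**2
    return dp_dt - 0.5 * d2p_dx2

xs = np.linspace(-3.0, 3.0, 61)
residual = np.max(np.abs(fpe_residual(xs, t=1.0)))
print(residual)  # zero up to discretisation error
```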

1.2 FPE for General SDEs

Now let’s consider how we derive the FPE for a general SDE. The procedure is very similar to the derivation for Brownian motion, but we have to be a bit more careful since now we have a drift term \mu(x,t) and a non-constant diffusion term \sigma(x,t). Applying Ito’s lemma to a test function as before now produces the terms \mu(x,t)\partial_x f and \frac{1}{2}\sigma^2(x,t)\partial_{xx} f; integrating the first by parts once and the second by parts twice yields exactly the general formula from the definition above.

1.3 FPE for GBM and OU

We can now use the general FPE formula to derive the FPE for some common SDEs. We will not derive them here from scratch but start from the formula for general SDEs from the last section.

For the zero-centered Ornstein-Uhlenbeck process dX_t = -\theta X_t dt + \sigma dB_t, we have \mu(x,t) = - \theta x and \sigma(x,t) = \sigma. Plugging this into the general FPE formula, we get \frac{\partial p(x,t)}{\partial t} = \theta \frac{\partial}{\partial x} [x\, p(x,t)] + \frac{1}{2} \sigma^2 \frac{\partial^2 p(x,t)}{\partial x^2}
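To see the SDE and its FPE agree, a small simulation helps: for X_0 = 0 this FPE has the Gaussian solution p_t = N(0, \sigma^2/(2\theta)(1 - e^{-2\theta t})), so the empirical variance of simulated paths should match that formula. A minimal sketch (all parameter values are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(0)
theta, sigma, T, n_steps, n_paths = 1.0, 0.8, 2.0, 400, 20000
dt = T / n_steps

# Euler-Maruyama simulation of dX = -theta X dt + sigma dB with X_0 = 0
x = np.zeros(n_paths)
for _ in range(n_steps):
    x += -theta * x * dt + sigma * np.sqrt(dt) * rng.standard_normal(n_paths)

# Gaussian solution of the OU FPE for X_0 = 0:
# p_t = N(0, sigma^2 / (2 theta) * (1 - exp(-2 theta t)))
analytic_var = sigma**2 / (2 * theta) * (1 - np.exp(-2 * theta * T))
print(x.var(), analytic_var)  # should agree closely
```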

For the Geometric Brownian Motion dX_t = \mu X_t dt + \sigma X_t dB_t, we have \mu(x,t) = \mu x and \sigma(x,t) = \sigma x. Plugging this into the general FPE formula, we get \frac{\partial p(x,t)}{\partial t} = - \mu \frac{\partial}{\partial x} [x\, p(x,t)] + \frac{1}{2} \sigma^2 \frac{\partial^2}{\partial x^2}[x^2 p(x,t)]
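This FPE is solved by a lognormal density whose mean is \mathbb{E}[X_t] = X_0 e^{\mu t}, which we can again check against a simulation of the SDE (parameter values are illustrative):

```python
import numpy as np

rng = np.random.default_rng(3)
mu, sigma, T, n_steps, n_paths = 0.1, 0.2, 1.0, 500, 20000
dt = T / n_steps

# Euler-Maruyama for dX = mu X dt + sigma X dB with X_0 = 1
x = np.ones(n_paths)
for _ in range(n_steps):
    x += mu * x * dt + sigma * x * np.sqrt(dt) * rng.standard_normal(n_paths)

# the FPE solution for X_0 = 1 is lognormal with mean exp(mu * T)
print(x.mean(), np.exp(mu * T))
```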

2. Time Reversal - Nelson, Anderson and co

In the first lecture we saw that the solution of an SDE is a stochastic process. As a stochastic process it has a probability distribution and statistics, and we have just seen that the Fokker-Planck Equation (FPE) describes the evolution of this probability density.

This is all very nice, but the key insight of diffusion models is that we can reverse an SDE. The results leading to this conclusion are all decades old (see Anderson 1982, Haussmann & Pardoux 1986, Nelson 1967 or Foellmer 1987), but have only recently been applied in the area of generative modelling. In this section we will look at how we can reverse an SDE and what this means for the FPE.

2.1 A discrete time “heuristic” sketch

Let’s say we have an SDE we want to reverse and know its FPE. We can view the probability densities evolving over time as a joint distribution over the values at all times. If we consider two time points, t and t+\delta, we can factor the joint distribution into conditional probabilities in two ways via the chain rule of probability: p_{t, t+\delta}(x, y) = p_{t+\delta \mid t}(y \mid x) p_t(x) = p_{t \mid t+\delta}(x \mid y) p_{t+\delta}(y)

where we denote a random variable at time t as x and a random variable at time t+\delta as y.

This is exactly true when p is the exact solution of the FPE and therefore the different conditional probabilities are the exact transition densities of our SDE.

However, let us now do a small hack: we discretise time via the Euler-Maruyama method and approximate the transition densities by the Euler-Maruyama transition densities. As a reminder, an Euler-Maruyama discretisation step for our typical SDE dX_t = \mu(X_t, t) dt + \sigma(X_t, t) dW_t after partitioning time into n discrete intervals looks like X_{n+1} = X_{n} + \mu(X_{n}, t_n) \delta t + \sigma(X_{n}, t_n) \delta W_{n}
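The Euler-Maruyama scheme is straightforward to implement. The sketch below simulates one path of a general SDE; the helper name, step count and example coefficients are illustrative choices, not fixed by the text:

```python
import numpy as np

def euler_maruyama(mu, sigma, x0, T, n_steps, rng):
    # simulate one path of dX_t = mu(X_t, t) dt + sigma(X_t, t) dW_t
    dt = T / n_steps
    xs = [x0]
    x = x0
    for n in range(n_steps):
        t = n * dt
        dW = np.sqrt(dt) * rng.standard_normal()   # Brownian increment
        x = x + mu(x, t) * dt + sigma(x, t) * dW
        xs.append(x)
    return np.array(xs)

# illustrative example: geometric Brownian motion with mu = 0.1, sigma = 0.2
rng = np.random.default_rng(42)
path = euler_maruyama(lambda x, t: 0.1 * x, lambda x, t: 0.2 * x, 1.0, 1.0, 500, rng)
```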

If we assume this discretisation, the transition density equality we wrote down above is no longer exact and only holds in the limit \delta \rightarrow 0. In this approximation, the forward transition density is a Gaussian that depends on the drift and diffusion terms in the following way: p_{t+\delta \mid t}(y \mid x) = \mathcal{N}(y \mid x + f^+(x) \delta, \delta \sigma^2(x))

where we denote the drift term at time t as f^+(x) = \mu(x, t) and ignore the time dependency of the drift and diffusion terms for simplicity.

We now want to get an expression for the reverse transition density p_{t \mid t+\delta}(x \mid y). Rearranging the chain rule decomposition above, we get p_{t \mid t+\delta}(x \mid y) = p_{t+\delta \mid t}(y \mid x) \frac{p_{t}(x)}{p_{t+\delta}(y)}

We can now use Taylor’s theorem to expand the marginal density p_t around y. We expand not p_t directly, but \exp(\log p_t), since this is easier to work with. This results in p_{t \mid t+\delta}(x \mid y) = p_{t+\delta \mid t}(y \mid x) \frac{p_t(y)e^{(x-y)^T \nabla_y \ln p_t(y) + \mathcal{O}(\delta^2)}}{p_{t+\delta}(y)}

If we now make a Lipschitz assumption of the form \mid \ln p_t(x) - \ln p_s(x) \mid = \mathcal{O}(\mid t-s \mid^2), we can absorb the ratio of the marginal densities into the \mathcal{O}(\delta^2) term and get p_{t \mid t+\delta}(x \mid y) = p_{t+\delta \mid t}(y \mid x)\, e^{(x-y)^T \nabla_y \ln p_t(y) + \mathcal{O}(\delta^2)}

We can now insert our Euler-Maruyama transition density for p_{t+\delta \mid t}(y \mid x), regroup the terms and complete the square to yield

\begin{align} p_{t \mid t+\delta}(x \mid y) &= \mathcal{N}(y \mid x + f^+(x) \delta, \delta \sigma^2(x))\, e^{(x-y)^T \nabla_y \ln p_t(y) + \mathcal{O}(\delta^2)} \\ &= \frac{e^{-\frac{\left\lVert y - x - f^+(x)\delta \right\rVert^2}{2\sigma^2\delta} + (x-y)^T \nabla_y \ln p_t(y) + \mathcal{O}(\delta^2)}}{(2\pi\delta)^{d/2}\sigma^d} \\ &= \frac{e^{-\frac{\left\lVert x - (y - f^+(y)\delta + \sigma^2 \nabla_y \ln p_t(y)\, \delta) \right\rVert^2}{2\sigma^2\delta} + \mathcal{O}(\delta^2)}}{(2\pi\delta)^{d/2}\sigma^d} \end{align}

This is the reverse transition density for the Euler-Maruyama discretisation of the following SDE: dX_t = [-f^+(X_t, T-t) + \sigma^2 \nabla_{X_t}\ln p_{T-t}(X_t)] dt + \sigma dW_t

We can see that this SDE is very similar to the original SDE, but with the drift term f^+ replaced by -f^+ and an additional term \sigma^2 \nabla_{X_t}\ln p_{T-t}(X_t) added. We can define this sum as the drift term f^- of the reverse SDE, which gives us a formula that relates the forward and reverse drifts to each other: f^-(x, t) + f^+(x, T-t) = \sigma^2 \nabla_{x}\ln p_{T-t}(x)
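For Brownian motion started from X_0 \sim N(0, v_0), all marginals are Gaussian, p_t = N(0, v_0 + t), so the score \nabla \ln p_t(x) = -x/(v_0 + t) is known in closed form and we can test this reverse SDE directly: running it from samples of p_T should recover the variance v_0 of p_0. A minimal sketch, with arbitrary values for v_0 and T:

```python
import numpy as np

rng = np.random.default_rng(0)
v0, T, n_steps, n_paths = 0.25, 1.0, 500, 20000
dt = T / n_steps

# forward SDE: dX = dB with X_0 ~ N(0, v0), hence p_t = N(0, v0 + t)
# reverse SDE (convention of this section): dX = [-f^+ + score] dt + dW
# with f^+ = 0 and score(x, forward time t) = -x / (v0 + t)
x = np.sqrt(v0 + T) * rng.standard_normal(n_paths)   # samples from p_T
for n in range(n_steps):
    t_fwd = T - n * dt                               # forward time T - tau
    drift = -x / (v0 + t_fwd)
    x += drift * dt + np.sqrt(dt) * rng.standard_normal(n_paths)

print(x.var())  # should be close to v0 = 0.25
```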

2.2 Anderson’s Derivation

In the previous section we saw a “heuristic” derivation of the reverse transition density. In this section we will look at the more rigorous derivation that Anderson took in his paper from 1982. We will not go through the full derivation that spans sections 3 and 4 of the paper, but will focus on the “simpler” derivation in section 5 that he came up with first, even though it does not include all the results presented in the paper (for us it is enough though).

Let us start again with our general SDE dX_t = \mu(X_t, t)dt + \sigma(X_t, t)dW_t and the FPE for this SDE: \frac{\partial}{\partial t} p_t(x_t) = - \frac{\partial}{\partial x_t}[\mu(x_t)p_t(x_t)] + \frac{1}{2} \frac{\partial^2}{\partial x_t^2}[\sigma^2(x_t)p_t(x_t)]

The FPE is also known as the forward Kolmogorov equation and describes the evolution of the probability density of the SDE.

However, there is also a backward Kolmogorov equation, which is given by - \frac{\partial}{\partial t} p_s(x_s \mid x_t) = \mu(x_t) \frac{\partial}{\partial x_t}[p_s(x_s \mid x_t)] + \frac{1}{2} \sigma^2(x_t) \frac{\partial^2}{\partial x_t^2}[p_s(x_s \mid x_t)]

where s \geq t, i.e. s = t + \delta from our previous section. Note that we condition on the value of the SDE at time t and not on the value at time s, and that we also differentiate with respect to x_t and not x_s. This is a bit confusing at first, but we can think of it in the following way: the equation describes how the probability density of the SDE at time s changes if we change the value x_t at time t. However, we do not yet know the drift and diffusion terms of the corresponding reverse SDE. Anderson used a chain rule factorisation similar to the one in our heuristic sketch to derive these terms:

\begin{align} p_s(x_s, x_t) &= p_s(x_s \mid x_t) p_t(x_t) \hspace{10px} \mid \text{take derivative and multiply by } (-1)\\ -\frac{\partial}{\partial t} p_s(x_s, x_t) &= -\frac{\partial}{\partial t} [p_s(x_s \mid x_t) p_t(x_t)] \hspace{10px} \mid \text{product rule}\\ -\frac{\partial}{\partial t} p_s(x_s, x_t) &= -\frac{\partial}{\partial t} [p_s(x_s \mid x_t)] p_t(x_t) - p_s(x_s \mid x_t) \frac{\partial}{\partial t} [p_t(x_t)] \end{align}

We can now insert the backward Kolmogorov equation into the first term on the RHS and the forward Kolmogorov equation into the second term on the RHS to get

\begin{align} \textcolor{green}{ -\frac{\partial}{\partial t} [p_s(x_s \mid x_t)]} p_t(x_t) - p_s(x_s \mid x_t) \textcolor{orange}{\frac{\partial}{\partial t} [p_t(x_t)]} \\ = \textcolor{green} {\Big( \mu(x_t) \frac{\partial}{\partial x_t}[p_s(x_s \mid x_t)] + \frac{1}{2} \sigma^2(x_t) \frac{\partial^2}{\partial x_t^2}[p_s(x_s \mid x_t)] \Big)} p_t(x_t) \\ + p_s(x_s \mid x_t) \textcolor{orange}{ \Big( \frac{\partial}{\partial x_t} [\mu(x_t) p_t(x_t)] - \frac{1}{2} \frac{\partial^2}{\partial x_t^2} [\sigma^2(x_t) p_t(x_t)] \Big)} \end{align}

We can now evaluate the derivatives for the KFE and KBE to insert them back into this equation.

For the KBE, we get

\begin{align} \frac{\partial}{\partial x_t} [p_s(x_s \mid x_t)] &= \frac{\partial}{\partial x_t} \Big[ \frac{p_s(x_s, x_t)}{p_t(x_t)} \Big] \hspace{10px} \mid \text{quotient rule}\\ &= \frac{\frac{\partial}{\partial x_t} [p(x_s, x_t)] p_t(x_t) - p(x_s, x_t) \frac{\partial}{\partial x_t} [p_t(x_t)]}{p_t^2(x_t)} \hspace{10px} \mid \text{split numerator}\\ &= \frac{\frac{\partial}{\partial x_t} [p(x_s, x_t)]}{p_t(x_t)} - \frac{p(x_s, x_t) \frac{\partial}{\partial x_t} [p_t(x_t)]}{p_t^2(x_t)} \end{align}

Now we can do the same for the two derivatives of products in the KFE terms:

\begin{align} \frac{\partial}{\partial x_t} [\mu(x_t)p_t(x_t)] &= \frac{\partial}{\partial x_t} [\mu(x_t)] p_t(x_t) + \mu(x_t) \frac{\partial}{\partial x_t} [p_t(x_t)]\\ \frac{\partial^2}{\partial x_t^2} [\sigma^2(x_t)p_t(x_t)] &= \frac{\partial^2}{\partial x_t^2} [\sigma^2(x_t)] p_t(x_t) + 2 \frac{\partial}{\partial x_t} [\sigma^2(x_t)] \frac{\partial}{\partial x_t} [p_t(x_t)] + \sigma^2(x_t) \frac{\partial^2}{\partial x_t^2} [p_t(x_t)] \end{align}

Substituting these derivatives back into the equation above, we get

\begin{align} - \frac{\partial}{\partial t} [p_s(x_s, x_t)] = -\frac{\partial}{\partial t} [p_s(x_s \mid x_t)] p_t(x_t) - p_s(x_s \mid x_t)\frac{\partial}{\partial t} [p_t(x_t)] \\ = \Big( \mu(x_t) \textcolor{green} {\frac{\partial}{\partial x_t}[p_s(x_s \mid x_t)]} + \frac{1}{2} \sigma^2(x_t) \frac{\partial^2}{\partial x_t^2}[p_s(x_s \mid x_t)] \Big) p_t(x_t) \\ + p_s(x_s \mid x_t) \Big( \textcolor{orange}{\frac{\partial}{\partial x_t} [\mu(x_t) p_t(x_t)]} - \frac{1}{2} \textcolor{lime}{\frac{\partial^2}{\partial x_t^2} [\sigma^2(x_t) p_t(x_t)]} \Big) \hspace{10px} \mid \text{plug in derivatives}\\ = \mu(x_t) \textcolor{green} {\Big( \frac{\frac{\partial}{\partial x_t} [p(x_s, x_t)]}{p_t(x_t)} - \frac{p(x_s, x_t) \frac{\partial}{\partial x_t} [p_t(x_t)]}{p_t^2(x_t)} \Big)} p_t(x_t) + \frac{1}{2} \sigma^2(x_t) \frac{\partial^2}{\partial x_t^2}[p_s(x_s \mid x_t)] p_t(x_t) \\ + \textcolor{orange}{p_s(x_s \mid x_t) \frac{\partial}{\partial x_t} [\mu(x_t)] p_t(x_t) + p_s(x_s \mid x_t) \mu(x_t) \frac{\partial}{\partial x_t} [p_t(x_t)]} - \frac{1}{2} p_s(x_s \mid x_t) \textcolor{lime}{\frac{\partial^2}{\partial x_t^2} [\sigma^2(x_t) p_t(x_t)]} \hspace{10px} \mid \text{fraction}\\ = \mu(x_t) \textcolor{green} {\Big( \frac{\partial}{\partial x_t} [p(x_s, x_t)] - \frac{p(x_s, x_t) \frac{\partial}{\partial x_t} [p_t(x_t)]}{p_t(x_t)} \Big)} + \frac{1}{2} \sigma^2(x_t) \frac{\partial^2}{\partial x_t^2}[p_s(x_s \mid x_t)] p_t(x_t) \\ + \textcolor{orange}{p_s(x_s \mid x_t) \frac{\partial}{\partial x_t} [\mu(x_t)] p_t(x_t) + p_s(x_s \mid x_t) \mu(x_t) \frac{\partial}{\partial x_t} [p_t(x_t)]} - \frac{1}{2} p_s(x_s \mid x_t) \textcolor{lime}{\frac{\partial^2}{\partial x_t^2} [\sigma^2(x_t) p_t(x_t)]} \hspace{10px} \mid \text{Bayes}\\ = \mu(x_t) \textcolor{green} {\Big( \frac{\partial}{\partial x_t} [p(x_s, x_t)] - p_s(x_s \mid x_t) \frac{\partial}{\partial x_t} [p_t(x_t)] \Big)} + \frac{1}{2} \sigma^2(x_t) \frac{\partial^2}{\partial x_t^2}[p_s(x_s \mid x_t)] p_t(x_t) \\ + \textcolor{orange}{p(x_s, x_t) \frac{\partial}{\partial x_t} [\mu(x_t)] + p_s(x_s \mid x_t) \mu(x_t) \frac{\partial}{\partial x_t} [p_t(x_t)]} - \frac{1}{2} p_s(x_s \mid x_t) \textcolor{lime}{\frac{\partial^2}{\partial x_t^2} [\sigma^2(x_t) p_t(x_t)]} \hspace{10px} \mid \text{cancel terms}\\ = \textcolor{green} {\mu(x_t) \frac{\partial}{\partial x_t} [p(x_s, x_t)]} \textcolor{orange}{+ p(x_s, x_t) \frac{\partial}{\partial x_t} [\mu(x_t)]} + \frac{1}{2} \sigma^2(x_t) \frac{\partial^2}{\partial x_t^2}[p_s(x_s \mid x_t)] p_t(x_t) \\ - \frac{1}{2} p_s(x_s \mid x_t) \textcolor{lime}{\frac{\partial^2}{\partial x_t^2} [\sigma^2(x_t) p_t(x_t)]} \hspace{10px} \mid \text{product rule}\\ = \textcolor{green} {\frac{\partial}{\partial x_t}[\mu(x_t)p(x_s,x_t)]} + \frac{1}{2} \sigma^2(x_t) \frac{\partial^2}{\partial x_t^2}[p_s(x_s \mid x_t)] p_t(x_t) \\ - \frac{1}{2} p_s(x_s \mid x_t) \textcolor{lime}{\frac{\partial^2}{\partial x_t^2} [\sigma^2(x_t) p_t(x_t)]} \end{align}

Among all this algebra, remember that what we want is to obtain an FPE that tells us the drift and diffusion terms of the equivalent reverse SDE. We already have a first derivative \frac{\partial}{\partial x_t} that looks like the corresponding term in the FPE, but we still need to rework the rest in order to get a proper second order derivative term that matches the FPE.

To do this, we can see that the last two terms can be generated by the product rule of the following equation:

\begin{align} & \frac{1}{2} \partial_{x_t}^2 \left[ p(x_s, x_t) \sigma^2(x_t) \right] \\ = & \frac{1}{2} \partial_{x_t}^2 \left[ p(x_s \mid x_t) p(x_t) \sigma^2(x_t) \right] \\ = & \frac{1}{2} \partial_{x_t}^2 p(x_s \mid x_t) p(x_t) \sigma^2(x_t) + \partial_{x_t} \left[ p(x_t) \sigma^2(x_t) \right] \partial_{x_t} p(x_s\mid x_t) + \frac{1}{2} \partial_{x_t}^2 \left[ p(x_t) \sigma^2(x_t) \right] p(x_s\mid x_t) \\ = & \underbrace{\frac{1}{2} \sigma^2(x_t) \partial_{x_t}^2 p(x_s \mid x_t) p(x_t)}_{(1)} + \partial_{x_t} \left[ p(x_t) \sigma^2(x_t) \right] \partial_{x_t} p(x_s\mid x_t) + \underbrace{\frac{1}{2} p(x_s\mid x_t) \partial_{x_t}^2 \left[ p(x_t) \sigma^2(x_t) \right]}_{(2)} \end{align}

Term (1) and (2) are exactly the terms we need for the FPE. After completing the square with the center term, we can use exactly this equation to simplify our current expression:

\begin{align} -\partial_t p(x_s, x_t) = & \partial_{x_t} \left[ \mu(x_t) \ p(x_s, x_t) \right] \\ & + \frac{1}{2} \ \sigma^2(x_t) \ \partial_{x_t}^2 p(x_s \mid x_t) \ p(x_t) - \frac{1}{2} p(x_s\mid x_t) \partial_{x_t}^2 \left[ \sigma^2(x_t) \ p(x_t) \right] \\ = & \partial_{x_t} \left[ \mu(x_t) \ p(x_s, x_t) \right] \\ & + \frac{1}{2} \ \sigma^2(x_t) \ p(x_t) \ \partial_{x_t}^2 p(x_s \mid x_t) \underbrace{ - \frac{1}{2} p(x_s\mid x_t) \partial_{x_t}^2 \left[ \sigma^2(x_t) \ p(x_t) \right] }_{-\frac{1}{2} X = -X + \frac{1}{2} X} \\ & \underbrace{\pm \partial_{x_t} p(x_s \mid x_t) \partial_{x_t} \left[ p(x_t) \sigma^2(x_t) \right]}_{\text{complete the square}} \\ = & \partial_{x_t} \left[ \mu(x_t) \ p(x_s, x_t) \right] \textcolor{red}{+ \frac{1}{2} \ \sigma^2(x_t) \ \partial_{x_t}^2 p(x_s \mid x_t) \ p(x_t)} \\ & \underbrace{ - p(x_s\mid x_t) \partial_{x_t}^2 \left[ \sigma^2(x_t) \ p(x_t) \right] + \textcolor{red}{\frac{1}{2} p(x_s\mid x_t) \partial_{x_t}^2 \left[ \sigma^2(x_t) \ p(x_t) \right]} }_{-\frac{1}{2} X = -X + \frac{1}{2} X} \\ & \textcolor{red}{\pm \partial_{x_t} p(x_s \mid x_t) \partial_{x_t} \left[ p(x_t) \sigma^2(x_t) \right]} \\ = & \partial_{x_t} \left[ \mu(x_t) \ p(x_s, x_t) \right] + \textcolor{red}{\frac{1}{2} \partial_{x_t}^2 \left[ p( x_s \mid x_t) p(x_t) \sigma^2(x_t) \right]} \\ & \underbrace{- p(x_s\mid x_t) \partial_{x_t}^2 \left[ \sigma^2(x_t) \ p(x_t) \right] - \partial_{x_t} p(x_s \mid x_t) \partial_{x_t} \left[ p(x_t) \sigma^2(x_t) \right]}_{ - \partial_{x_t} \left[ p(x_s\mid x_t) \partial_{x_t} \left[ \sigma^2(x_t) \ p(x_t) \right] \right] \text{ (product rule) } } \\ = & \partial_{x_t} \left[ \mu(x_t) \ p(x_s, x_t) \right] + \frac{1}{2} \partial_{x_t}^2 \left[ p( x_s , x_t) \sigma^2(x_t) \right] \\ & - \partial_{x_t} \left[ p(x_s\mid x_t) \partial_{x_t} \left[ \sigma^2(x_t) \ p(x_t) \right] \right]. \end{align}

To make this look like a FPE now, we just need to combine the conditional and joint probability terms in the first order derivatives:

\begin{align} -\partial_t p(x_s, x_t) = & \partial_{x_t} \left[ \mu(x_t) \ p(x_s, x_t) - p(x_s\mid x_t) \partial_{x_t} \left[ \sigma^2(x_t) \ p(x_t) \right] \right] \\ & + \frac{1}{2} \partial_{x_t}^2 \left[ p( x_s , x_t) \sigma^2(x_t) \right] \\ = & \partial_{x_t} \Big[ p(x_s, x_t) \left( \mu(x_t) - \frac{1}{p(x_t)} \partial_{x_t} \left[ \sigma^2(x_t) \ p(x_t) \right] \right) \Big] \\ & + \frac{1}{2} \partial_{x_t}^2 \left[ p( x_s , x_t) \sigma^2(x_t) \right] \\ = & - \partial_{x_t} \Big[ p(x_s, x_t) \left( -\mu(x_t) + \frac{1}{p(x_t)} \partial_{x_t} \left[ \sigma^2(x_t) \ p(x_t) \right] \right) \Big] \\ & + \frac{1}{2} \partial_{x_t}^2 \left[ p( x_s , x_t) \sigma^2(x_t) \right] \end{align}

We can see that this equation looks like a FPE/FKE. However, it still involves the joint density p(x_s, x_t), but we want it to only consist of the conditional and marginal densities. To do this, we can marginalise over x_s via Leibniz’ rule to get -\partial_t p(x_t) = - \partial_{x_t} \Big[ p(x_t) \left( -\mu(x_t) + \frac{1}{p(x_t)} \partial_{x_t} \left[ \sigma^2(x_t) \ p(x_t) \right] \right) \Big] + \frac{1}{2} \partial_{x_t}^2 \left[ p( x_t) \sigma^2(x_t) \right]

We can now use \tau = T-t as time reversal to get a FPE for the reverse SDE: -\partial_t p(x_t) = \partial_\tau p(x_{T-\tau}) = -\partial_{x_{T-\tau}} \Big[ p(x_{T-\tau}) \left( -\mu(x_{T-\tau}) + \frac{1}{p(x_{T-\tau})} \partial_{x_{T-\tau}} \left[ \sigma^2(x_{T-\tau}) \ p(x_{T-\tau}) \right] \right) \Big] + \frac{1}{2} \partial_{x_{T-\tau}}^2 \left[ p( x_{T-\tau}) \sigma^2(x_{T-\tau}) \right]

We can now use this FPE to get the corresponding reverse SDE that can be solved backward in time: dX_t = \big[ - \mu(x_{T-\tau}) + \frac{1}{p(x_{T-\tau})} \partial_{x_{T-\tau}} \left[ \sigma^2(x_{T-\tau}) \ p(x_{T-\tau}) \right] \big] d\tau + \sigma(x_{T-\tau}) d\tilde{W}_t

where \tilde{W}_t is a reverse Brownian motion that flows backward in time.

We can again simplify this by assuming a scalar diffusion term \sigma with no dependency on x_{T-\tau} to get

\begin{align} dX_t &= \big[ - \mu(x_{T-\tau}) + \frac{\sigma^2}{p(x_{T-\tau})} \partial_{x_{T-\tau}} p(x_{T-\tau}) \big] d\tau + \sigma d\tilde{W}_t \hspace{10px} \mid \text{log-derivative trick}\\ dX_t &= \big[ - \mu(x_{T-\tau}) + \sigma^2 \partial_{x_{T-\tau}} \log p(x_{T-\tau}) \big] d\tau + \sigma d\tilde{W}_t \end{align}

We can see that, similarly to our heuristic derivation, the reverse SDE contains the score term \partial_{x_{T-\tau}} \log p(x_{T-\tau}), which we can also write as \partial_{x_{t}} \log p(x_{t}) since \tau = T-t.

Here we see a small but important difference between the formulation of the reverse-time SDE in our heuristic treatment and in Anderson's derivation. It corresponds to different conventions for representing time reversals.

time_reversals

  1. We can choose to “turn around” at time T and run our SDE in a “forward” manner, but with the time t running backwards. This is the convention we used in our heuristic derivation and the one that is used in e.g. De Bortoli 2021. In this case, the reverse SDE is given by
dX_t = f^-(X_t, t)dt + \sigma dW_t
  2. An alternative is not to turn around at time T, but to “run backwards” from time T, i.e. to formulate the SDE using a reversed Brownian motion d\tilde{W}_t. This is the convention that Anderson used in his derivation and the one that is used in e.g. Song et al. 2021. In this case, the reverse SDE is given by
dX_t = f^-(\tilde{X}_t, t)dt + \sigma d\tilde{W}_t

where \tilde{X}_t and \tilde{W}_t are the reversed process and Brownian motion, respectively.
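We can put the reverse-time SDE to work in a toy generative setting: take the "data" distribution to be the two points {-1, +1}, noise it with an OU process dX = -X dt + \sqrt{2} dB (whose stationary distribution is N(0,1)), and integrate the reverse SDE using the analytically known score of the resulting Gaussian mixture. All concrete values below (\theta = 1, \sigma^2 = 2, T = 3, the step size, the early-stopping cutoff) are illustrative choices:

```python
import numpy as np

def score(x, t):
    # analytic score of p_t for the OU process dX = -X dt + sqrt(2) dB
    # started from the two-point "data" distribution {+1, -1}
    m, s2 = np.exp(-t), 1.0 - np.exp(-2.0 * t)
    a = np.exp(-(x - m) ** 2 / (2.0 * s2))   # component centred at +m
    b = np.exp(-(x + m) ** 2 / (2.0 * s2))   # component centred at -m
    return (-(x - m) / s2 * a - (x + m) / s2 * b) / (a + b)

rng = np.random.default_rng(0)
T, dt, eps = 3.0, 0.005, 0.01                # stop slightly before t = 0
n_steps = int((T - eps) / dt)
x = rng.standard_normal(20000)               # approx. p_T (close to N(0, 1))
for n in range(n_steps):
    t = T - n * dt                           # forward time for this reverse step
    drift = x + 2.0 * score(x, t)            # -f^+(x) + sigma^2 * score
    x = x + drift * dt + np.sqrt(2.0 * dt) * rng.standard_normal(x.shape)

print(np.mean(np.abs(x)))                    # samples concentrate near +-1
```

Stopping at a small time eps > 0 avoids the blow-up of the score as the mixture components collapse onto the data points.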

3. Score-based modelling via SDEs

How do we now use this time-reversal formalism for generative modelling? To understand that, let us briefly discuss what score matching is.

3.1 From score matching to denoising score matching

When we train a generative model, we want to learn a probability density p_{\theta}(x) that matches the true data distribution p_{data}(x). We can do this by maximising the log-likelihood of the data under the model: \max_{\theta} \mathbb{E}_{x \sim p_{data}} \log p_{\theta}(x)

However, this requires p_{\theta}(x) to be normalised, which is often intractable. There are many ways around this limitation, for example restricted model architectures in autoregressive models or approximating the normalisation via variational inference in VAEs. Here, we use the score function \nabla_x \log p_{\theta}(x) to learn the model. The score function is the gradient of the log-probability density and can be seen as a vector field that points in the direction of the steepest ascent of the density. We see that it allows us to learn the model without having to compute the normalisation constant Z_{\theta}:

\begin{align} s_{\theta}(x) &= \nabla_x \log p_{\theta}(x) \\ &= \nabla_x \log\big( \frac{f_{\theta}(x)}{Z_{\theta}}\big) \\ &= \nabla_x \log f_{\theta}(x) - \underbrace{\nabla_x \log Z_{\theta}}_{=0} \\ &= \nabla_x \log f_{\theta}(x) \end{align}
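A one-line numerical check makes this concrete: the score of an unnormalised density c \cdot \exp(-x^2/2) does not depend on the constant c. The function names below are illustrative:

```python
import numpy as np

def log_unnormalised(x, c):
    # log of c * exp(-x^2 / 2); c plays the role of an unknown 1/Z
    return np.log(c) - x**2 / 2

def numerical_score(x, c, h=1e-5):
    # central finite difference of the log-density
    return (log_unnormalised(x + h, c) - log_unnormalised(x - h, c)) / (2 * h)

# the score is independent of the normalisation constant
print(numerical_score(0.7, c=1.0), numerical_score(0.7, c=123.0))  # both ≈ -0.7
```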

We can then use this score function to learn the model by minimising the following loss function: \mathbb{E}_{p(\mathbf{x})}[\| \nabla_\mathbf{x} \log p(\mathbf{x}) - \mathbf{s}_\theta(\mathbf{x}) \|_2^2]

There is a problem, though: we do not know the ground truth score function \nabla_x \log p_{data}(x), so we cannot directly minimise this loss.

However, we can use score matching to minimise this loss without requiring access to the ground truth score function. There are different variants of score matching that make it more efficient or practical. We will not go into detail about the derivation and variants of score matching; if you want to understand more about the background, however, the papers by Aapo Hyvärinen are an excellent resource.

Here, we will focus on denoising score matching, where the idea is to add a bit of noise to each data point to smooth the data distribution and allow better score estimates. If we think of the noise addition in multiple (in fact infinitely many) steps, we can think of this noising process as an SDE (in practice an OU process is often used). We can then use the time-reversal results from the previous section to get the corresponding reverse-time SDE, which we can use to generate samples.

As you remember, the term that appeared in the reverse SDE was exactly our score function, xlogp(x)\nabla_x \log p(x). To learn this, we use a training objective where we weight the Fisher divergence terms of all the different time points with a weighting function λ(t)\lambda(t): EtU(0,T)Ept(x)[λ(t)xlogpt(x,t)sθ(x,t)22]\mathbb{E}_{t \in \mathcal{U}(0,T)}\mathbb{E}_{p_t(\mathbf{x})}[ \lambda(t) \left\lVert \nabla_\mathbf{x} \log p_t(\mathbf{x},t) - \mathbf{s}_\theta(\mathbf{x},t) \right\rVert_2^2]

where U(0,T)\mathcal{U}(0,T) is a uniform distribution over the interval [0,T][0,T].
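A minimal sketch of how this objective can be evaluated in practice, here for an OU noising process, a toy standard-normal "data" distribution, and the common weighting choice \lambda(t) = \sigma_t^2; the "models" are just hypothetical function handles, and t is sampled away from 0 to avoid the degenerate variance v(0) = 0:

```python
import numpy as np

rng = np.random.default_rng(1)

# Noising via an OU process: X_t | X_0 ~ N(a(t) X_0, v(t)) with
# a(t) = exp(-t) and v(t) = 1 - exp(-2t). "Data" here is X_0 ~ N(0, 1).
def a(t): return np.exp(-t)
def v(t): return 1.0 - np.exp(-2.0 * t)

def dsm_loss(model, n=100_000, t_min=0.1, T=1.0):
    t = rng.uniform(t_min, T, n)          # t ~ U(t_min, T)
    x0 = rng.standard_normal(n)           # x0 ~ p_data
    eps = rng.standard_normal(n)
    xt = a(t) * x0 + np.sqrt(v(t)) * eps  # jump directly to time t
    cond_score = -eps / np.sqrt(v(t))     # = grad_x log p_{t|0}(xt | x0)
    lam = v(t)                            # weighting choice lambda(t) = sigma_t^2
    return np.mean(lam * (cond_score - model(xt, t)) ** 2)

# With data N(0,1), Var[X_t] = a(t)^2 + v(t) = 1 for all t, so the true
# marginal score is s(x,t) = -x. It should beat a useless all-zeros model:
loss_true = dsm_loss(lambda x, t: -x)
loss_zero = dsm_loss(lambda x, t: np.zeros_like(x))
print(loss_true, loss_zero)
```

Note that the minimum of the denoising loss is not zero: even the true score leaves a residual that measures the spread of the conditional scores around the marginal score.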

This loss allows us to train practical score-based models (closely related to denoising diffusion probabilistic models, DDPMs) that can be used to generate samples. It also allows us to make an important connection between our training objective and maximum likelihood training / KL divergence minimisation: \operatorname{KL}(p_0(\mathbf{x})\|p_\theta(\mathbf{x})) \leq \frac{T}{2}\mathbb{E}_{t \in \mathcal{U}(0, T)}\mathbb{E}_{p_t(\mathbf{x})}[\lambda(t) \| \nabla_\mathbf{x} \log p_t(\mathbf{x}) - \mathbf{s}_\theta(\mathbf{x}, t) \|_2^2] + \operatorname{KL}(p_T \mathrel\| \pi)

song_ddpm

3.2 Closed-form transition densities for fast training

It would be rather inconvenient if we had to simulate the noising of the sample via an Euler-Maruyama approximation for every training point and then backpropagate through this. Instead, since we restrict ourselves to certain SDEs that have analytically known Gaussian transition densities, we can use these closed-form transition densities to directly “jump” from data samples to an arbitrary time point t. This makes training score-based models much more efficient.
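For the OU process, for instance, the transition density is Gaussian with known mean and variance, so a single draw replaces an entire simulated trajectory. A small sketch comparing the closed-form jump with an Euler-Maruyama simulation (the start point, time, and step sizes are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(2)

# The OU noising SDE dX = -X dt + sqrt(2) dW has the Gaussian transition density
# X_t | X_0 = x0  ~  N(x0 * exp(-t), 1 - exp(-2t)).
x0, t = 2.0, 1.0
mean_cf = x0 * np.exp(-t)
var_cf = 1.0 - np.exp(-2.0 * t)

# Closed-form "jump": one sampling step, regardless of how large t is.
n = 100_000
jump = mean_cf + np.sqrt(var_cf) * rng.standard_normal(n)

# Euler-Maruyama: many small steps to reach the same time t.
dt, steps = 0.01, 100
x = np.full(n, x0)
for _ in range(steps):
    x = x - x * dt + np.sqrt(2.0 * dt) * rng.standard_normal(n)

# Both sets of samples follow (approximately) the same distribution.
print(jump.mean(), x.mean())
print(jump.var(), x.var())
```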

lipman_diffusion

Image from this talk by Yaron Lipman. We can formulate score matching in terms of conditional transition densities for which we have closed-form solutions. We can then use these closed-form solutions to directly “jump” from data samples to an arbitrary time point tt.

“Wait”, you may say at this point, “the score matching loss involving the conditional densities in this picture is different from the one we talked about before! How do I know they are actually the same?” You are right, this is not directly obvious! There is no reason to believe a priori that minimising \mathbb{E} \left[\left\lVert \nabla_x \ln p_{t \mid 0}(X_t \mid X_0)- s_{\theta}(t, X_t)\right\rVert^2 \right] is the same as minimising \mathbb{E}_{t \in \mathcal{U}(0,T)}\mathbb{E}_{p_t(\mathbf{x})}[\left\lVert \nabla_\mathbf{x} \log p_t(\mathbf{x},t) - \mathbf{s}_\theta(\mathbf{x},t) \right\rVert_2^2].

However, we can show that this is indeed the case. For this, we make use of the fact that the conditional expectation is the minimiser of the mean squared error:

The optimal predictor of X as a function of Y (Hilbert projection)

\begin{align} f^{*}(Y) = \argmin_{f \text{ measurable}} \mathbb{E}[(X- f(Y))^2] \end{align}

is given by the conditional expectation

f(Y)=E(XY)f^{*}(Y) = \mathbb{E}(X \mid Y)

For the case of martingales, we can show that the optimal predictor of the future as a function of the past is given by the past itself: \begin{align} f^{*}(X_t)&=\argmin_{f \text{ measurable}} \mathbb{E}[(X_{t+\delta}- f(X_t))^2] \\ &= \mathbb{E}(X_{t+\delta} \mid X_t) \\ &= X_t \end{align}

We can then use this to show that minimising the MSE between our model estimate and the conditional score function has, as its optimal solution, exactly the marginal score function: \begin{align} s^{*}(t,x)&=\argmin_{s \text{ measurable}} \mathbb{E} \left[\int_0^T \left\lVert \nabla_x \ln p_{t \mid 0}(X_t \mid X_0)- s(t, X_t)\right\rVert^2 \mathrm{d}t\right] \hspace{10px} \mid \text{conditional expectation}\\ &= \mathbb{E}_{X_0 \mid X_t} [ \nabla_x \ln p_{t \mid 0}(X_t \mid X_0) \mid X_t =x] \hspace{10px} \mid \text{def. of expectation}\\ &= \int p_{0 \mid t}(x_0 \mid x) \nabla_x \ln p_{t \mid 0}(x \mid x_0) \mathrm{d}x_0 \hspace{10px} \mid \text{Bayes}\\ &= \int \frac{p_{t \mid 0}(x \mid x_0)p_0(x_0)}{ p_t(x)} \nabla_x \ln p_{t \mid 0}(x \mid x_0) \mathrm{d}x_0 \hspace{10px} \mid \text{log-deriv. trick}\\ &= \frac{1}{p_t(x)}\int p_0(x_0) \nabla_x p_{t \mid 0}(x \mid x_0) \mathrm{d}x_0 \hspace{10px} \mid \text{dom. convergence theorem}\\ &= \frac{1}{p_t(x)} \nabla_x \int p_0(x_0) p_{t \mid 0}(x \mid x_0) \mathrm{d}x_0 \hspace{10px} \mid \text{marginalise } x_0\\ &= \frac{1}{p_t(x)} \nabla_x p_t(x) \hspace{10px} \mid \text{log-deriv. trick}\\ &= \nabla_x \ln p_t(x) \end{align}
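We can also check this identity numerically in a Gaussian toy case: the conditional expectation of the conditional score, given X_t = x, should equal the marginal score at x. The sketch below estimates that conditional expectation by self-normalised importance sampling with the prior as proposal (the time and query point are arbitrary choices):

```python
import numpy as np

rng = np.random.default_rng(3)

# Gaussian toy case: x0 ~ N(0, 1) and kernel p_{t|0}(x | x0) = N(a*x0, v)
# for an OU process, i.e. a = exp(-t), v = 1 - exp(-2t).
t = 0.5
a, v = np.exp(-t), 1.0 - np.exp(-2.0 * t)
x = 1.2  # query point

# Marginal p_t = N(0, a^2 + v) = N(0, 1), so the marginal score at x is -x.
marginal_score = -x / (a**2 + v)

# Monte Carlo estimate of E[ grad_x log p_{t|0}(x | X_0) | X_t = x ]:
# sample x0 from the prior and weight by the kernel density at x.
x0 = rng.standard_normal(500_000)
w = np.exp(-(x - a * x0) ** 2 / (2 * v))  # unnormalised posterior weights
cond_score = -(x - a * x0) / v            # grad_x log p_{t|0}(x | x0)
estimate = np.sum(w * cond_score) / np.sum(w)

# The averaged conditional score matches the marginal score up to MC error.
print(estimate, marginal_score)
```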

4. Probability flow ODE

You saw that we used Langevin sampling-like techniques to simulate our reverse SDE. In Song et al. 2021, it was shown that you do not necessarily need to do that; instead of modelling a stochastic process, we can replace this backward diffusion process with its corresponding deterministic process (i.e. replace an SDE with an ODE) and still have the same marginal probability densities {pt(x)}t=0T\{p_t(x)\}_{t=0}^T. This is known as the probability flow ODE.

Definition: Every stochastic process described by an SDE has a corresponding deterministic process described by an ODE that has the same marginal probability densities {pt(x)}t=0T\{p_t(x)\}_{t=0}^T. This process is called the probability flow ODE. For a general SDE of the form dXt=μ(Xt,t)dt+σ(Xt,t)dWtdX_t = \mu(X_t, t)dt + \sigma(X_t,t)dW_t, the corresponding ODE is given by

dXt=[μ(Xt,t)12x[σ(Xt,t)σ(Xt,t)T]12σ(Xt,t)σ(Xt,t)Txlogpt(x)]dtdX_t = \big[ \mu(X_t, t) -\frac{1}{2} \nabla_x [\sigma(X_t, t)\sigma(X_t,t)^T] - \frac{1}{2} \sigma(X_t, t)\sigma(X_t,t)^T \nabla_x \log p_t(x) \big]dt

For the example of the backward SDE dXt=[μ(Xt,t)σ2(t)xlogpt(Xt)]dt+σ(t)dWˉtdX_t = [\mu(X_t, t) - \sigma^2(t) \nabla_x \log p_t(X_t)]dt + \sigma(t)d\bar{W}_t (where we assume a scalar diffusion term with no dependency on XtX_t), the corresponding ODE is given by dXt=[μ(Xt,t)12σ2(t)xlogpt(Xt)]dtdX_t = [\mu(X_t, t) - \frac{1}{2}\sigma^2(t) \nabla_x \log p_t(X_t)]dt

The ODE looks very similar to the SDE, but with the stochastic term d\bar{W}_t removed and an additional factor of \frac{1}{2} in front of the score term. This is not a coincidence, but a direct consequence of the fact that the SDE and the ODE have the same marginal probability densities. Let us look at the derivation.
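Before the derivation, we can check the claim empirically for an OU process with Gaussian initial data, where the score is available in closed form: simulating the forward SDE and the probability flow ODE from the same initial distribution should give matching marginal statistics (all parameters below are arbitrary illustration choices):

```python
import numpy as np

rng = np.random.default_rng(4)

# OU SDE dX = -X dt + sqrt(2) dW started from X_0 ~ N(m0, s0^2).
m0, s0 = 1.0, 0.5

def marginal(t):  # analytic mean and variance of p_t
    a = np.exp(-t)
    return m0 * a, s0**2 * a**2 + 1.0 - a**2

def score(x, t):
    m, v = marginal(t)
    return -(x - m) / v

T, steps, n = 1.0, 1000, 20000
dt = T / steps
x_sde = m0 + s0 * rng.standard_normal(n)
x_ode = m0 + s0 * rng.standard_normal(n)
for i in range(steps):
    t = i * dt
    # Euler-Maruyama step of the SDE:
    x_sde = x_sde - x_sde * dt + np.sqrt(2.0 * dt) * rng.standard_normal(n)
    # Probability flow ODE: dX = [mu - (1/2) sigma^2 * score] dt with sigma^2 = 2
    x_ode = x_ode + (-x_ode - score(x_ode, t)) * dt

# Both marginals at time T should match the analytic mean and variance.
m_T, v_T = marginal(T)
print(x_sde.mean(), x_ode.mean(), m_T)
print(x_sde.var(), x_ode.var(), v_T)
```

Note that each individual ODE sample path is deterministic given its start point; only the marginal distributions agree, not the path-wise behaviour.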

Let’s start again with our general SDE dx = \mu(x, t)dt + \sigma(x, t)dW_t, where we assume \sigma(x, t) to be matrix-valued and dependent on x. We know that the Fokker-Planck Equation of this general SDE is given by \frac{\partial}{\partial t} p_t(x) = - \frac{\partial}{\partial x}[\mu(x,t)p_t(x)] + \frac{\partial}{\partial x}\frac{\partial}{\partial x}\big[\frac{1}{2}\sigma(x,t)\sigma(x,t)^T p_t(x)\big]

where pt(x)p_t(x) is the probability density of the SDE at time tt.

We can now use the product rule to split the second term on the RHS into two terms: tpt(x)=x[μ(x,t)pt(x)]+12x[x[σ(x,t)σ(x,t)T]pt(x)]+12x[σ(x,t)σ(x,t)Tpt(x)x]\frac{\partial}{\partial t} p_t(x) = - \frac{\partial}{\partial x}[\mu(x,t)p_t(x)] + \frac{1}{2} \frac{\partial}{\partial x}\Big[\nabla_x[\sigma(x,t)\sigma(x,t)^T]p_t(x)\Big] + \frac{1}{2} \frac{\partial}{\partial x}\Big[\sigma(x,t)\sigma(x,t)^T\frac{\partial p_t(x)}{\partial x}\Big]

Using the log-derivative trick to rewrite the last term, we get tpt(x)=x[μ(x,t)pt(x)]+12x[x[σ(x,t)σ(x,t)T]pt(x)]+12x[σ(x,t)σ(x,t)Tpt(x)xlogpt(x)]\frac{\partial}{\partial t} p_t(x) = - \frac{\partial}{\partial x}[\mu(x,t)p_t(x)] + \frac{1}{2} \frac{\partial}{\partial x}\Big[\nabla_x[\sigma(x,t)\sigma(x,t)^T]p_t(x)\Big] + \frac{1}{2} \frac{\partial}{\partial x}\Big[\sigma(x,t)\sigma(x,t)^T p_t(x) \nabla_x \log p_t(x)\Big]

Factoring out p_t(x) and pulling the \frac{\partial}{\partial x} outside the whole expression, we get \frac{\partial}{\partial t} p_t(x) = - \frac{\partial}{\partial x} \Big[ p_t(x) \Big( \mu(x,t)- \frac{1}{2} \nabla_x [\sigma(x,t)\sigma(x,t)^T] - \frac{1}{2}\sigma(x,t)\sigma(x,t)^T \nabla_x [\log p_t(x)]\Big)\Big]

If we now define the term in brackets multiplying p_t(x) as \tilde{\mu}(x,t), we get \frac{\partial}{\partial t} p_t(x) = - \frac{\partial}{\partial x} \big[\tilde{\mu}(x,t)p_t(x)\big]

Looking at this equation, it looks again like a FPE, but for an SDE with a diffusion term σ~(x,t)=0\tilde{\sigma}(x,t) = 0. In that case, the FPE is also known as the Liouville equation and describes the evolution of the probability density of a deterministic process (i.e. an ODE) of the form dx=μ~(x,t)dtdx = \tilde{\mu}(x, t)dt

Via this reasoning, we showed that the SDE and the probability flow ODE have the same marginal probability densities.

If we assume that the diffusion term does not depend on xx, the derivative of the diffusion term x[σ(x,t)σ(x,t)T]\nabla_x [\sigma(x,t)\sigma(x,t)^T] is zero. If we assume in addition that the diffusion term is scalar, the term 12σ(x,t)σ(x,t)T\frac{1}{2}\sigma(x,t)\sigma(x,t)^T simplifies to 12σ2(t)\frac{1}{2}\sigma^2(t) and we get the probability flow ODE discussed before.

5. Flow Matching

We just saw that we can replace the stochastic process of the reverse SDE with a deterministic process that has the same marginal probability densities. An interesting question follows: if we can replace the reverse process with a deterministic process, can we also replace the forward process with a deterministic process and just train the whole model deterministically? This is exactly what Flow Matching does. To understand where it comes from, we first need to look at a class of models called continuous normalising flows.

5.1 Continuous normalising flows (CNFs)


cnf

Image from this talk by Yaron Lipman.

5.2 Simulation-free training of CNFs

5.3 Comparison to score-based models

flow_matching_vs_diffusion

Image from this talk by Yaron Lipman.

5.4 Stochastic Interpolants - Unifying stochastic and deterministic generative models

You saw that despite the setup for CNFs and score-based models being quite different at the start, their final solutions are quite similar. Recently, people have tried to formalise this connection more explicitly and unify the two approaches. One example of such an effort is the work on Stochastic Interpolants by Michael Albergo and others (here a talk where they present the work).

stochastic_interpolants

Image from the stochastic interpolant paper.

Credits

Much of the logic in the lecture is based on the Øksendal book as well as the Särkkä & Solin book. Title image from Wikipedia.