生成扩散模型漫谈(五):一般框架之SDE篇
By 苏剑林 | 2022-08-03 | 206599位读者 |在写生成扩散模型的第一篇文章时,就有读者在评论区推荐了宋飏博士的论文《Score-Based Generative Modeling through Stochastic Differential Equations》,可以说该论文构建了一个相当一般化的生成扩散模型理论框架,将DDPM、SDE、ODE等诸多结果联系了起来。诚然,这是一篇好论文,但并不是一篇适合初学者的论文,里边直接用到了随机微分方程(SDE)、Fokker-Planck方程、得分匹配等大量结果,上手难度还是颇大的。
不过,在经过了前四篇文章的积累后,现在我们可以尝试去学习一下这篇论文了。在接下来的文章中,笔者将尝试从尽可能少的理论基础出发,尽量复现原论文中的推导结果。
随机微分 #
在DDPM中,扩散过程被划分为了固定的$T$步,还是用《生成扩散模型漫谈(一):DDPM = 拆楼 + 建楼》的类比来说,就是“拆楼”和“建楼”都被事先划分为了$T$步,这个划分有着相当大的人为性。事实上,真实的“拆”、“建”过程应该是没有刻意划分的步骤的,我们可以将它们理解为一个在时间上连续的变换过程,可以用随机微分方程(Stochastic Differential Equation,SDE)来描述。
为此,我们用下述SDE描述前向过程(“拆楼”):
\begin{equation}d\boldsymbol{x} = \boldsymbol{f}_t(\boldsymbol{x}) dt + g_t d\boldsymbol{w}\label{eq:sde-forward}\end{equation}
相信很多读者都对SDE很陌生,笔者也只是在硕士阶段刚好接触过一段时间,略懂皮毛。不过不懂不要紧,我们只需要将它看成是下述离散形式在$\Delta t\to 0$时的极限:
\begin{equation}\boldsymbol{x}_{t+\Delta t} - \boldsymbol{x}_t = \boldsymbol{f}_t(\boldsymbol{x}_t) \Delta t + g_t \sqrt{\Delta t}\boldsymbol{\varepsilon},\quad \boldsymbol{\varepsilon}\sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{I})\label{eq:sde-discrete}\end{equation}
再直白一点,如果假设拆楼需要$1$天,那么拆楼就是$\boldsymbol{x}$从$t=0$到$t=1$的变化过程,每一小步的变化我们可以用上述方程描述。至于时间间隔$\Delta t$,我们并没有做特殊限制,只是越小的$\Delta t$意味着是对原始SDE越好的近似,如果取$\Delta t=0.001$,那就对应于原来的$T=1000$,如果是$\Delta t = 0.01$则对应于$T=100$,等等。也就是说,在连续时间的SDE视角之下,不同的$T$是SDE不同的离散化程度的体现,它们会自动地导致相似的结果,我们不需要事先指定$T$,而是根据实际情况下的精确度来取适当的$T$进行数值计算。
所以,引入SDE形式来描述扩散模型的本质好处是“将理论分析和代码实现分离开来”,我们可以借助连续性SDE的数学工具对它做分析,而实践的时候,则只需要用任意适当的离散化方案对SDE进行数值计算。
对于式$\eqref{eq:sde-discrete}$,读者可能比较有疑惑的是为什么右端第一项是$\mathcal{O}(\Delta t)$的,而第二项是$\mathcal{O}(\sqrt{\Delta t})$的?也就是说为什么随机项的阶要比确定项的阶要高?这个还真不是那么容易解释,也是SDE比较让人迷惑的地方之一。简单来说,就是$\boldsymbol{\varepsilon}$一直服从标准正态分布,如果随机项的权重也是$\mathcal{O}(\Delta t)$,那么由于标准正态分布的均值为$\boldsymbol{0}$、协方差为$ \boldsymbol{I}$,临近的随机效应会相互抵消掉,要放大到$\mathcal{O}(\sqrt{\Delta t})$才能在长期结果中体现出随机效应的作用。
逆向方程 #
用概率的语言,式$\eqref{eq:sde-discrete}$意味着条件概率为
\begin{equation}\begin{aligned}
p(\boldsymbol{x}_{t+\Delta t}|\boldsymbol{x}_t) =&\, \mathcal{N}\left(\boldsymbol{x}_{t+\Delta t};\boldsymbol{x}_t + \boldsymbol{f}_t(\boldsymbol{x}_t) \Delta t, g_t^2\Delta t \,\boldsymbol{I}\right)\\
\propto&\, \exp\left(-\frac{\Vert\boldsymbol{x}_{t+\Delta t} - \boldsymbol{x}_t - \boldsymbol{f}_t(\boldsymbol{x}_t) \Delta t\Vert^2}{2 g_t^2\Delta t}\right)
\end{aligned}\label{eq:sde-proba}\end{equation}
简单起见,这里没有写出无关紧要的归一化因子。按照DDPM的思想,我们最终是想要从“拆楼”的过程中学会“建楼”,即得到$p(\boldsymbol{x}_t|\boldsymbol{x}_{t+\Delta t})$,为此,我们像《生成扩散模型漫谈(三):DDPM = 贝叶斯 + 去噪》一样,用贝叶斯定理:
\begin{equation}\begin{aligned}
p(\boldsymbol{x}_t|\boldsymbol{x}_{t+\Delta t}) =&\, \frac{p(\boldsymbol{x}_{t+\Delta t}|\boldsymbol{x}_t)p(\boldsymbol{x}_t)}{p(\boldsymbol{x}_{t+\Delta t})} = p(\boldsymbol{x}_{t+\Delta t}|\boldsymbol{x}_t) \exp\left(\log p(\boldsymbol{x}_t) - \log p(\boldsymbol{x}_{t+\Delta t})\right)\\
\propto&\, \exp\left(-\frac{\Vert\boldsymbol{x}_{t+\Delta t} - \boldsymbol{x}_t - \boldsymbol{f}_t(\boldsymbol{x}_t) \Delta t\Vert^2}{2 g_t^2\Delta t} + \log p(\boldsymbol{x}_t) - \log p(\boldsymbol{x}_{t+\Delta t})\right)
\end{aligned}\label{eq:bayes-dt}\end{equation}
不难发现,当$\Delta t$足够小时,只有当$\boldsymbol{x}_{t+\Delta t}$与$\boldsymbol{x}_t$足够接近时,$p(\boldsymbol{x}_{t+\Delta t}|\boldsymbol{x}_t)$才会明显不等于0,反过来也只有这种情况下$p(\boldsymbol{x}_t|\boldsymbol{x}_{t+\Delta t})$才会明显不等于0。因此,我们只需要对$\boldsymbol{x}_{t+\Delta t}$与$\boldsymbol{x}_t$足够接近时的情形做近似分析,为此,我们可以用泰勒展开:
\begin{equation}\log p(\boldsymbol{x}_{t+\Delta t})\approx \log p(\boldsymbol{x}_t) + (\boldsymbol{x}_{t+\Delta t} - \boldsymbol{x}_t)\cdot \nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) + \Delta t \frac{\partial}{\partial t}\log p(\boldsymbol{x}_t)\end{equation}
注意不要忽略了$\frac{\partial}{\partial t}$项,因为$p(\boldsymbol{x}_t)$实际上是“$t$时刻随机变量等于$\boldsymbol{x}_t$的概率密度”,而$p(\boldsymbol{x}_{t+\Delta t})$实际上是“$t+\Delta t$时刻随机变量等于$\boldsymbol{x}_{t+\Delta t}$的概率密度”,也就是说$p(\boldsymbol{x}_t)$实际上同时是$t$和$\boldsymbol{x}_t$的函数,所以要多一项$t$的偏导数。代入到式$\eqref{eq:bayes-dt}$后,配方得到
\begin{equation}p(\boldsymbol{x}_t|\boldsymbol{x}_{t+\Delta t}) \propto \exp\left(-\frac{\Vert\boldsymbol{x}_{t+\Delta t} - \boldsymbol{x}_t - \left[\boldsymbol{f}_t(\boldsymbol{x}_t) - g_t^2\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) \right]\Delta t\Vert^2}{2 g_t^2\Delta t} + \mathcal{O}(\Delta t)\right)\end{equation}
当$\Delta t\to 0$时,$\mathcal{O}(\Delta t)\to 0$不起作用,因此
\begin{equation}\begin{aligned}
p(\boldsymbol{x}_t|\boldsymbol{x}_{t+\Delta t}) \propto&\, \exp\left(-\frac{\Vert\boldsymbol{x}_{t+\Delta t} - \boldsymbol{x}_t - \left[\boldsymbol{f}_t(\boldsymbol{x}_t) - g_t^2\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) \right]\Delta t\Vert^2}{2 g_t^2\Delta t}\right) \\
\approx&\,\exp\left(-\frac{\Vert \boldsymbol{x}_t - \boldsymbol{x}_{t+\Delta t} + \left[\boldsymbol{f}_{t+\Delta t}(\boldsymbol{x}_{t+\Delta t}) - g_{t+\Delta t}^2\nabla_{\boldsymbol{x}_{t+\Delta t}}\log p(\boldsymbol{x}_{t+\Delta t}) \right]\Delta t\Vert^2}{2 g_{t+\Delta t}^2\Delta t}\right)
\end{aligned}\end{equation}
即$p(\boldsymbol{x}_t|\boldsymbol{x}_{t+\Delta t})$近似一个均值为$\boldsymbol{x}_{t+\Delta t} - \left[\boldsymbol{f}_{t+\Delta t}(\boldsymbol{x}_{t+\Delta t}) - g_{t+\Delta t}^2\nabla_{\boldsymbol{x}_{t+\Delta t}}\log p(\boldsymbol{x}_{t+\Delta t}) \right]\Delta t$、协方差为$g_{t+\Delta t}^2\Delta t\,\boldsymbol{I}$的正态分布,取$\Delta t\to 0$的极限,那么对应于SDE:
\begin{equation}d\boldsymbol{x} = \left[\boldsymbol{f}_t(\boldsymbol{x}) - g_t^2\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x}) \right] dt + g_t d\boldsymbol{w}\label{eq:reverse-sde}\end{equation}
这就是反向过程对应的SDE,最早出现在《Reverse-Time Diffusion Equation Models》中。这里我们特意在$p$处标注了下标$t$,以突出这是$t$时刻的分布。
得分匹配 #
现在我们已经得到了逆向的SDE为$\eqref{eq:reverse-sde}$,如果进一步知道$\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})$,那么就可以通过离散化格式
\begin{equation}\boldsymbol{x}_t - \boldsymbol{x}_{t+\Delta t} = - \left[\boldsymbol{f}_{t+\Delta t}(\boldsymbol{x}_{t+\Delta t}) - g_{t+\Delta t}^2\nabla_{\boldsymbol{x}_{t+\Delta t}}\log p(\boldsymbol{x}_{t+\Delta t}) \right]\Delta t - g_{t+\Delta t} \sqrt{\Delta t}\boldsymbol{\varepsilon}\label{eq:reverse-sde-discrete}\end{equation}
来逐步完成“建楼”的生成过程【其中$\boldsymbol{\varepsilon}\sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{I})$】,从而完成一个生成扩散模型的构建。
那么如何得到$\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})$呢?$t$时刻的$p_t(\boldsymbol{x})$就是前面的$p(\boldsymbol{x}_t)$,它的含义就是$t$时刻的边缘分布。在实际使用时,我们一般会设计能找到$p(\boldsymbol{x}_t|\boldsymbol{x}_0)$解析解的模型,这意味着
\begin{equation}\small p(\boldsymbol{x}_t|\boldsymbol{x}_0) = \lim_{\Delta t\to 0}\int\cdots\iint p(\boldsymbol{x}_t|\boldsymbol{x}_{t-\Delta t})p(\boldsymbol{x}_{t-\Delta t}|\boldsymbol{x}_{t-2\Delta t})\cdots p(\boldsymbol{x}_{\Delta t}|\boldsymbol{x}_0) d\boldsymbol{x}_{t-\Delta t} d\boldsymbol{x}_{t-2\Delta t}\cdots d\boldsymbol{x}_{\Delta t}\end{equation}
是可以直接求出的,比如当$\boldsymbol{f}_t(\boldsymbol{x})$是关于$\boldsymbol{x}$的线性函数时,$p(\boldsymbol{x}_t|\boldsymbol{x}_0)$就可以解析求解。在此前提下,有
\begin{equation}p(\boldsymbol{x}_t) = \int p(\boldsymbol{x}_t|\boldsymbol{x}_0)\tilde{p}(\boldsymbol{x}_0)d\boldsymbol{x}_0=\mathbb{E}_{\boldsymbol{x}_0}\left[p(\boldsymbol{x}_t|\boldsymbol{x}_0)\right]\end{equation}
于是
\begin{equation}\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t) = \frac{\mathbb{E}_{\boldsymbol{x}_0}\left[\nabla_{\boldsymbol{x}_t} p(\boldsymbol{x}_t|\boldsymbol{x}_0)\right]}{\mathbb{E}_{\boldsymbol{x}_0}\left[p(\boldsymbol{x}_t|\boldsymbol{x}_0)\right]} = \frac{\mathbb{E}_{\boldsymbol{x}_0}\left[p(\boldsymbol{x}_t|\boldsymbol{x}_0)\nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{x}_t|\boldsymbol{x}_0)\right]}{\mathbb{E}_{\boldsymbol{x}_0}\left[p(\boldsymbol{x}_t|\boldsymbol{x}_0)\right]}\end{equation}
可以看到最后的式子具有“$\nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{x}_t|\boldsymbol{x}_0)$的加权平均”的形式,由于假设了$p(\boldsymbol{x}_t|\boldsymbol{x}_0)$有解析解,因此上式实际上是能够直接估算的,然而它涉及到对全体训练样本$\boldsymbol{x}_0$的平均,一来计算量大,二来泛化能力也不够好。因此,我们希望用神经网络学一个函数$\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)$,使得它能够直接计算$\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t)$。
很多读者应该对如下结果并不陌生(或者推导一遍也不困难):
\begin{equation}\mathbb{E}[\boldsymbol{x}] = \mathop{\text{argmin}}_{\boldsymbol{\mu}}\mathbb{E}_{\boldsymbol{x}}\left[\Vert \boldsymbol{\mu} - \boldsymbol{x}\Vert^2\right]\end{equation}
即要让$\boldsymbol{\mu}$等于$\boldsymbol{x}$的均值,只需要最小化$\Vert \boldsymbol{\mu} - \boldsymbol{x}\Vert^2$的均值。同理,要让$\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)$等于$\nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{x}_t|\boldsymbol{x}_0)$的加权平均【即$\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t)$】,则只需要最小化$\left\Vert \boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t) - \nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{x}_t|\boldsymbol{x}_0)\right\Vert^2$的加权平均,即
\begin{equation} \frac{\mathbb{E}_{\boldsymbol{x}_0}\left[p(\boldsymbol{x}_t|\boldsymbol{x}_0)\left\Vert \boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t) - \nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{x}_t|\boldsymbol{x}_0)\right\Vert^2\right]}{\mathbb{E}_{\boldsymbol{x}_0}\left[p(\boldsymbol{x}_t|\boldsymbol{x}_0)\right]}\end{equation}
分母的$\mathbb{E}_{\boldsymbol{x}_0}\left[p(\boldsymbol{x}_t|\boldsymbol{x}_0)\right]$只是起到调节Loss权重的作用,简单起见我们可以直接去掉它,这不会影响最优解的结果。最后我们再对$\boldsymbol{x}_t$积分(相当于对于每一个$\boldsymbol{x}_t$都要最小化上述损失),得到最终的损失函数
\begin{equation}\begin{aligned}&\,\int \mathbb{E}_{\boldsymbol{x}_0}\left[p(\boldsymbol{x}_t|\boldsymbol{x}_0)\left\Vert \boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t) - \nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{x}_t|\boldsymbol{x}_0)\right\Vert^2\right] d\boldsymbol{x}_t \\
=&\, \mathbb{E}_{\boldsymbol{x}_0,\boldsymbol{x}_t \sim p(\boldsymbol{x}_t|\boldsymbol{x}_0)\tilde{p}(\boldsymbol{x}_0)}\left[\left\Vert \boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t) - \nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{x}_t|\boldsymbol{x}_0)\right\Vert^2\right]
\end{aligned}\label{eq:score-match}\end{equation}
这就是“(条件)得分匹配”的损失函数,之前我们在《从去噪自编码器到生成模型》推导的去噪自编码器的解析解,也是它的一个特例。得分匹配的最早出处可以追溯到2005年的论文《Estimation of Non-Normalized Statistical Models by Score Matching》,至于条件得分匹配的最早出处,笔者追溯到的是2011年的论文《A Connection Between Score Matching and Denoising Autoencoders》。
不过,虽然该结果跟得分匹配是一样的,但其实在这一节的推导中,我们已经抛开了“得分”的概念了,纯粹是由目标自然地引导出来的答案,笔者认为这样的处理过程更有启发性,希望这一推导能降低大家对得分匹配的理解难度。
结果倒推 #
至此,我们构建了生成扩散模型的一般流程:
1、通过随机微分方程$\eqref{eq:sde-forward}$定义“拆楼”(前向过程);
2、求$p(\boldsymbol{x}_t|\boldsymbol{x}_0)$的表达式;
3、通过损失函数$\eqref{eq:score-match}$训练$\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)$(得分匹配);
4、用$\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)$替换式$\eqref{eq:reverse-sde}$的$\nabla_{\boldsymbol{x}}\log p_t(\boldsymbol{x})$,完成“建楼”(反向过程)。
可能大家看到SDE、微分方程等字眼,天然就觉得“恐慌”,但本质上来说,SDE只是个“幌子”,实际上将对SDE的理解转换到式$\eqref{eq:sde-discrete}$和式$\eqref{eq:sde-proba}$上后,完全就可以抛开SDE的概念了,因此概念上其实是没有太大难度的。
不难发现,定义一个随机微分方程$\eqref{eq:sde-forward}$是很容易的,但是从$\eqref{eq:sde-forward}$求解$p(\boldsymbol{x}_t|\boldsymbol{x}_0)$却是不容易的。原论文的剩余篇幅,主要是对两个有实用性的例子推导和实验。然而,既然求解$p(\boldsymbol{x}_t|\boldsymbol{x}_0)$不容易,那么按照笔者的看法,与其先定义$\eqref{eq:sde-forward}$再求解$p(\boldsymbol{x}_t|\boldsymbol{x}_0)$,倒不如像DDIM一样,先定义$p(\boldsymbol{x}_t|\boldsymbol{x}_0)$,然后再来反推对应的SDE?
例如,我们先定义
\begin{equation} p(\boldsymbol{x}_t|\boldsymbol{x}_0) = \mathcal{N}(\boldsymbol{x}_t; \bar{\alpha}_t \boldsymbol{x}_0,\bar{\beta}_t^2 \boldsymbol{I})\end{equation}
并且不失一般性假设起点是$t=0$,终点是$t=1$,那么$\bar{\alpha}_t,\bar{\beta}_t$要满足的边界就是
\begin{equation} \bar{\alpha}_0 = 1,\quad \bar{\alpha}_1 = 0,\quad \bar{\beta}_0 = 0,\quad \bar{\beta}_1 = 1\end{equation}
当然,上述边界条件理论上足够近似就行,也不一定非要精确相等,比如上一篇文章我们分析过DDPM相当于选择了$\bar{\alpha}_t = e^{-5t^2}$,当$t=1$时结果为$e^{-5}\approx 0$。
有了$p(\boldsymbol{x}_t|\boldsymbol{x}_0)$,我们去反推$\eqref{eq:sde-forward}$,本质上就是要求解$p(\boldsymbol{x}_{t+\Delta t}|\boldsymbol{x}_t)$,它要满足
\begin{equation} p(\boldsymbol{x}_{t+\Delta t}|\boldsymbol{x}_0) = \int p(\boldsymbol{x}_{t+\Delta t}|\boldsymbol{x}_t) p(\boldsymbol{x}_t|\boldsymbol{x}_0) d\boldsymbol{x}_t\end{equation}
我们考虑线性的解,即
\begin{equation}d\boldsymbol{x} = f_t\boldsymbol{x} dt + g_t d\boldsymbol{w}\end{equation}
跟《生成扩散模型漫谈(四):DDIM = 高观点DDPM》一样,我们写出
\begin{array}{c|c|c}
\hline
\text{记号} & \text{含义} & \text{采样}\\
\hline
p(\boldsymbol{x}_{t+\Delta t}|\boldsymbol{x}_0) & \mathcal{N}(\boldsymbol{x}_t;\bar{\alpha}_{t+\Delta t} \boldsymbol{x}_0,\bar{\beta}_{t+\Delta t}^2 \boldsymbol{I}) & \boldsymbol{x}_{t+\Delta t} = \bar{\alpha}_{t+\Delta t} \boldsymbol{x}_0 + \bar{\beta}_{t+\Delta t} \boldsymbol{\varepsilon} \\
\hline
p(\boldsymbol{x}_t|\boldsymbol{x}_0) & \mathcal{N}(\boldsymbol{x}_t;\bar{\alpha}_t \boldsymbol{x}_0,\bar{\beta}_t^2 \boldsymbol{I}) & \boldsymbol{x}_t = \bar{\alpha}_t \boldsymbol{x}_0 + \bar{\beta}_t \boldsymbol{\varepsilon}_1 \\
\hline
p(\boldsymbol{x}_{t+\Delta t}|\boldsymbol{x}_t) & \mathcal{N}(\boldsymbol{x}_{t+\Delta t}; (1 + f_t\Delta t) \boldsymbol{x}_t, g_t^2 \Delta t\, \boldsymbol{I}) & \boldsymbol{x}_{t+\Delta t} = (1 + f_t\Delta t) \boldsymbol{x}_t + g_t\sqrt{\Delta t}\boldsymbol{\varepsilon}_2 \\
\hline
{\begin{array}{c}\int p(\boldsymbol{x}_{t+\Delta t}|\boldsymbol{x}_t) \\
p(\boldsymbol{x}_t|\boldsymbol{x}_0) d\boldsymbol{x}_t\end{array}} & & {\begin{aligned}&\,\boldsymbol{x}_{t+\Delta t} \\
=&\, (1 + f_t\Delta t) \boldsymbol{x}_t + g_t\sqrt{\Delta t} \boldsymbol{\varepsilon}_2 \\
=&\, (1 + f_t\Delta t) (\bar{\alpha}_t \boldsymbol{x}_0 + \bar{\beta}_t \boldsymbol{\varepsilon}_1) + g_t\sqrt{\Delta t} \boldsymbol{\varepsilon}_2 \\
=&\, (1 + f_t\Delta t) \bar{\alpha}_t \boldsymbol{x}_0 + ((1 + f_t\Delta t)\bar{\beta}_t \boldsymbol{\varepsilon}_1 + g_t\sqrt{\Delta t} \boldsymbol{\varepsilon}_2) \\
\end{aligned}} \\
\hline
\end{array}
由此可得
\begin{equation}\begin{aligned}
\bar{\alpha}_{t+\Delta t} =&\, (1 + f_t\Delta t) \bar{\alpha}_t \\
\bar{\beta}_{t+\Delta t}^2 =&\, (1 + f_t\Delta t)^2\bar{\beta}_t^2 + g_t^2\Delta t
\end{aligned}\end{equation}
令$\Delta t\to 0$,分别解得
\begin{equation}
f_t = \frac{d}{dt} \left(\ln \bar{\alpha}_t\right) = \frac{1}{\bar{\alpha}_t}\frac{d\bar{\alpha}_t}{dt}, \quad g_t^2 = \bar{\alpha}_t^2 \frac{d}{dt}\left(\frac{\bar{\beta}_t^2}{\bar{\alpha}_t^2}\right) = 2\bar{\alpha}_t \bar{\beta}_t \frac{d}{dt}\left(\frac{\bar{\beta}_t}{\bar{\alpha}_t}\right)\end{equation}
取$\bar{\alpha}_t\equiv 1$时,结果就是论文中的VE-SDE(Variance Exploding SDE);而如果取$\bar{\alpha}_t^2 + \bar{\beta}_t^2=1$时,结果就是原论文中的VP-SDE(Variance Preserving SDE)。
至于损失函数,此时我们可以算得
\begin{equation}\nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{x}_t|\boldsymbol{x}_0) = -\frac{\boldsymbol{x}_t - \bar{\alpha}_t\boldsymbol{x}_0}{\bar{\beta}_t^2}=-\frac{\boldsymbol{\varepsilon}}{\bar{\beta}_t}\end{equation}
第二个等号是因为$\boldsymbol{x}_t = \bar{\alpha}_t\boldsymbol{x}_0 + \bar{\beta}_t\boldsymbol{\varepsilon}$,为了跟以往的结果对齐,我们设$\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t) = -\frac{\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)}{\bar{\beta}_t}$,此时式$\eqref{eq:score-match}$为
\begin{equation}\frac{1}{\bar{\beta}_t^2}\mathbb{E}_{\boldsymbol{x}_0\sim \tilde{p}(\boldsymbol{x}_0),\boldsymbol{\varepsilon}\sim\mathcal{N}(\boldsymbol{0},\boldsymbol{I})}\left[\left\Vert \boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\bar{\alpha}_t\boldsymbol{x}_0 + \bar{\beta}_t\boldsymbol{\varepsilon}, t) - \boldsymbol{\varepsilon}\right\Vert^2\right]\end{equation}
忽略系数后就是DDPM的损失函数,而用$-\frac{\boldsymbol{\epsilon}_{\boldsymbol{\theta}}(\boldsymbol{x}_{t+\Delta t}, t+\Delta t)}{\bar{\beta}_{t+\Delta t}}$替换掉式$\eqref{eq:reverse-sde-discrete}$的$\nabla_{\boldsymbol{x}_{t+\Delta t}}\log p(\boldsymbol{x}_{t+\Delta t})$后,结果与DDPM的采样过程具有相同的一阶近似(意味着$\Delta t\to 0$时两者等价)。
文章小结 #
本文主要介绍了宋飏博士建立的利用SDE理解扩散模型的一般框架,其中包括以尽可能直观的语言推导了反向SDE、得分匹配等结果,并对方程的求解给出了自己的想法。
转载到请包括本文地址:https://kexue.fm/archives/9209
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Aug. 03, 2022). 《生成扩散模型漫谈(五):一般框架之SDE篇 》[Blog post]. Retrieved from https://kexue.fm/archives/9209
@online{kexuefm-9209,
title={生成扩散模型漫谈(五):一般框架之SDE篇},
author={苏剑林},
year={2022},
month={Aug},
url={\url{https://kexue.fm/archives/9209}},
}
October 11th, 2024
请问一下式21和SDE原文里的式11等价吗?推了一下感觉有点对不上,但又感觉两边都对。难道同一个递推公式能对应不同的SDE?
我看了一下你说的原论文公式是
$$d\boldsymbol{x} = -\frac{1}{2}\beta(t)\boldsymbol{x} dt + \sqrt{\beta(t)} d\boldsymbol{w}$$
你的问题是什么?怎么推出这个公式对应的$\bar{\alpha}_t,\bar{\beta}_t$还是啥?什么叫跟$(21)$对不上?
就是如果按照原文的记号,把您的式21的$\bar{\alpha}_t$换成$\sqrt{\bar{\alpha}_t}$的话,会得到$f(t) = \frac1{2\bar{\alpha}}\frac{d\bar{\alpha}}{dt}$,和原文的$f(x,t)=-\frac12\beta(t)x$似乎并不等价,这是为什么呢?
$(21)$式描述的是$f_t,g_t$与$\bar{\alpha}_t,\bar{\beta}_t$的关系,$d\boldsymbol{x} = -\frac{1}{2}\beta(t)\boldsymbol{x} dt + \sqrt{\beta(t)} d\boldsymbol{w}$说的就是$f_t = -\frac{1}{2}\beta(t),g_t = \sqrt{\beta(t)}$,没牵扯到$\bar{\alpha}_t,\bar{\beta}_t$呀。
但是$\alpha_t$和$\bar{\alpha}_t$之间是有关系的啊,连续情况下应该有$\log\bar{\alpha}(t)=\log\int_0^t \alpha(t)dt$。以及原文里对应的应该是$f(x,t)=\frac12\beta(t)x$而不是$f(t)=-\frac12\beta(t)$吧。
SDE框架里边,没有$\alpha_t,\beta_t$,只有$f_t,g_t$(线性SDE即$(19)$的函数)和$\bar{\alpha}_t,\bar{\beta}_t$(即$p(\boldsymbol{x}_t|\boldsymbol{x}_0)$的参数)。
根据$(21)$有$\ln \bar{\alpha}_t = \int f_t dt$,你是不是想表达$f_t = \ln \alpha_t$,从而有$\ln \bar{\alpha}_t = \int \ln \alpha_t dt$?
那如果换一种方式问,把SDE论文里的DDPM的SDE也就是式11,写出它对应的PF-ODE,会得到$dx/dt = -\frac12 \beta(t) (1+\nabla_x \log p(x))$,这个式子是否是DDIM呢?
@QTB|comment-25636
在ODE篇已经证明过了:https://kexue.fm/archives/9228#%E5%9B%9E%E9%A1%BEDDIM
根据这篇确实推出来二者等价了,谢谢解答!
October 17th, 2024
$p(x_{t})$写成$p_{t}(x_{t})$以区分$t$和$t+\Delta t$时候的概率分布,这样写会不会更好一些
确实是这样,后面的一些文章中已经特意做了区分,不过就本文而言,应该还不至于引起严重的混淆,因此简单起见还是保持省略了。
October 25th, 2024
苏老师您好,
关于 (14) 到 $\eqref{eq:score-match}$ 式,我想到了一种推导方式,可以无需忽略 (14) 分母的 $ \mathbb{E}_{\boldsymbol{x}_0}\left[p(\boldsymbol{x}_t|\boldsymbol{x}_0)\right]$,推导更自然,想和苏老师分享一下
事实上,可以证明一个比 (13) 式的更强的结论:
$$ \begin{aligned}
\forall \boldsymbol{f}(\boldsymbol{x}), g(\boldsymbol{x}), \quad \mathrm{argmin}_\boldsymbol{\mu} \mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x}) \cdot \left\Vert
\boldsymbol{\mu} - \boldsymbol{f}(\boldsymbol{x})
\right\Vert^2 \right] = \frac { \mathbb{E}_{\boldsymbol{x}}[g(\boldsymbol{x}) \cdot \boldsymbol{f}(\boldsymbol{x})] }{ \mathbb{E}_{\boldsymbol{x}}[g(\boldsymbol{x})] }
\end{aligned}$$
证明方式是:
$$ \begin{aligned} & \mathrm{argmin}_\boldsymbol{\mu} \mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x}) \cdot \left\Vert
\boldsymbol{\mu} - \boldsymbol{f}(\boldsymbol{x})
\right\Vert^2 \right] \\
= & \mathrm{argmin}_\boldsymbol{\mu} \mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x}) \cdot \left( \left\Vert
\boldsymbol{\mu} \right\Vert^2 - 2 \boldsymbol{\mu} \cdot \boldsymbol{f}(\boldsymbol{x}) + \left\Vert
\boldsymbol{ \boldsymbol{f}(\boldsymbol{x})} \right\Vert^2 \right) \right] \\
= & \mathrm{argmin}_\boldsymbol{\mu} \mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x}) \cdot \left( \left\Vert
\boldsymbol{\mu} \right\Vert^2 - 2 \boldsymbol{\mu} \cdot \boldsymbol{f}(\boldsymbol{x}) \right) \right] \\
= & \mathrm{argmin}_\boldsymbol{\mu} \left( \left\Vert
\boldsymbol{\mu} \right\Vert^2 \cdot \mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x}) \right] - 2 \boldsymbol{\mu} \cdot \mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x})\cdot\boldsymbol{f}(\boldsymbol{x}) \right] \right) \\
= & \mathrm{argmin}_\boldsymbol{\mu} \left( \left\Vert
\boldsymbol{\mu} \right\Vert^2 - 2 \boldsymbol{\mu} \cdot \frac{\mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x})\cdot\boldsymbol{f}(\boldsymbol{x}) \right]}{\mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x}) \right]} \right) \\
= & \mathrm{argmin}_\boldsymbol{\mu} \left( \left\Vert
\boldsymbol{\mu} \right\Vert^2 - 2 \boldsymbol{\mu} \cdot \frac{\mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x})\cdot\boldsymbol{f}(\boldsymbol{x}) \right]}{\mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x}) \right]} +
\left\Vert \frac{\mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x})\cdot\boldsymbol{f}(\boldsymbol{x}) \right]}{\mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x}) \right]}\right\Vert^2 \right) \\
= & \mathrm{argmin}_\boldsymbol{\mu} \left\Vert
\boldsymbol{\mu} - \frac{\mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x})\cdot\boldsymbol{f}(\boldsymbol{x}) \right]}{\mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x}) \right]}\right\Vert^2 \\
= & \frac{\mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x})\cdot\boldsymbol{f}(\boldsymbol{x}) \right]}{\mathbb{E}_{\boldsymbol{x}}\left[ g(\boldsymbol{x}) \right]}
\end{aligned}$$
有了这个结论后,我们令 $\boldsymbol{x}$ 为 $\boldsymbol{x_0}$, $\boldsymbol{\mu}$ 为 $\boldsymbol{s}_{\boldsymbol\theta}(\boldsymbol{x}_t, t)$, $\boldsymbol{f}(\boldsymbol{x})$ 为 $\nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{x}_t|\boldsymbol{x}_0)$, ${g}(\boldsymbol{x})$ 为 $ p(\boldsymbol{x}_t|\boldsymbol{x}_0)$,即可得到 $\eqref{eq:score-match}$ 的最优值为 $\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t)$,而且无需忽略 (14) 分母的 $ \mathbb{E}_{\boldsymbol{x}_0}\left[p(\boldsymbol{x}_t|\boldsymbol{x}_0)\right]$
最后,再次谢谢苏老师精彩的介绍,收获满满
谢谢!其实你这就是把“要让$\boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t)$等于$\nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{x}_t|\boldsymbol{x}_0)$的加权平均【即$\nabla_{\boldsymbol{x}_t}\log p(\boldsymbol{x}_t)$】,则只需要最小化$\left\Vert \boldsymbol{s}_{\boldsymbol{\theta}}(\boldsymbol{x}_t, t) - \nabla_{\boldsymbol{x}_t} \log p(\boldsymbol{x}_t|\boldsymbol{x}_0)\right\Vert^2$的加权平均”证明了一下,那个分母的scale本身就不重要,就好比你这里的目标函数乘上一个scale也不改变结果,所以其实也没多不自然。
October 28th, 2024
你的表格“采样”列第四行的公式就是原论文公式24,所以按照我理解你的21里解出来的解应该满足:$f_t = \frac{d}{dt} \left(\ln \bar{\alpha}_t\right) =-\frac{1}{2}\beta(t
)$。但是我看不出来这第二个等号如何成立。我用了$\bar{\alpha}_t=\alpha_t\alpha_{t-1}...\alpha_1$,也推不出来。能不能请你再解释清楚一些。
同@苏剑林|comment-25536,SDE框架里边,没有$\alpha_t,\beta_t$,只有$f_t,g_t$。你所举的例子是$f_t = -\frac{1}{2}\beta(t)$,根据$(21)$有$f_t = \frac{d}{dt}\ln\bar{\alpha}_t$,所以$f_t = \frac{d}{dt}\ln\bar{\alpha}_t=-\frac{1}{2}\beta(t)$。
你的问题在于,你总是想将$\beta(t)$映射为你脑子里心心念念但又含糊不清的$\beta(t)$。
October 28th, 2024
我又看了一下,你的$\bar{\alpha}_t$ 实际上是在 ddpm 原论文里 $\sqrt{\bar{\alpha}_t}$。所以如果还用ddpm原论文里的符号,应该是$f_t = \frac{d}{dt} \left(\ln \sqrt{\bar{\alpha}_t}\right) =\frac{1}{2}\frac{d}{dt} \left(\ln \bar{\alpha}_t\right) \rightarrow-\frac{1}{2}\beta(t
)$ 。而离散化后 $\ln \bar{\alpha}_t-\ln \bar{\alpha}_{t-1}=\ln \alpha_t=\ln(1-\beta_t
)\rightarrow -\beta_t$ ?感觉这里用到了 $\ln(1-x)=-x$ ?
如果你想要的是离散化对应,那么$d\boldsymbol{x} = f_t\boldsymbol{x} dt + g_t d\boldsymbol{w}$的离散化结果是
$$\boldsymbol{x}_{t+\Delta t}= (1 + f_t\Delta t)\boldsymbol{x}_t + g_t \sqrt{\Delta t}\boldsymbol{\varepsilon},\quad \boldsymbol{\varepsilon}\sim \mathcal{N}(\boldsymbol{0}, \boldsymbol{I})$$
那这样看来,$1 + f_t\Delta t$才是你心心念念的那个$\alpha_t$,$g_t \sqrt{\Delta t}$则是你心心念念的那个$\beta_t$。
当取$f_t = -\frac{1}{2}\beta(t),g_t = \beta(t)$时,满足
$$(1 + f_t\Delta t)^2 + (g_t\sqrt{\Delta t})^2 = 1 + \mathcal{O}(\Delta t^2)$$
即在一阶近似内满足平方和等于1。
我明白你说的在连续模型当中,平方和等于1的等价于vp sde。其实我一直费解的就是你说的这个离散化对应。是不是离散化对应就意味着令 $\Delta t=1$?
主要就是,原论文中讲 ddpm 一会儿把这个系数写成 $-\frac{1}{2}\beta(t)$,一会儿又写成 $\sqrt{1-\beta(t)}-1$ ,把我搞得比较迷。我完全理解这两个式子取极限的时候趋于相等,但不理解作者为啥要搞出多种形式,为什么不统一形式,是纯属疏漏还是另有深意,这是我困惑点。
我可能明白了,$f_t=-\frac{1}{2}\beta(t)$ 里的 $\beta(t)$ 和 $\beta_t$ 没有关系,前者的 $\beta(t)$ 可以是任意关于 $t$ 的函数。这是个等价关系:$g_t=\sqrt{-2f_t}\Leftrightarrow \bar\alpha_t^2+\bar\beta_t^2=1$ 。也就是说,微分方程里的系数满足 $g_t=\sqrt{-2f_t}$ 等价于加噪转移核 $p(X_t|X_0)$ 的均值和方差前面的系数的平方和为 1。
再具体点说,只要把加噪转移核的均值和方差设计成它们的平方和为 1,就对应 VP SDE ;只要把均值设置为 $X_t+0$ ,方差设置为任意 $\sigma(t)^2I$ ,就对应 VE SDE 。如你说 SDE 只是个“走个形式”。
是的。非要对应的话,$\beta(t)\sqrt{\Delta t}$才是$\beta_t$。
November 2nd, 2024
苏神,公式(9)和知乎上似乎不太一样,右端第二项应该是正的吧,这也符合原文中的表述。我的理解是不是这里的SDE由于是反向的,所以用欧拉离散的时候直接写成$dx = x_{t} - x_{t + \delta{t}}$的形式,并且$dt=-\delta{t}$,而根号里面直接取绝对值就好。不知道对不对。
你说$- g_{t+\Delta t} \sqrt{\Delta t}\boldsymbol{\varepsilon}$这一项吗?这里$\boldsymbol{\varepsilon}$是标准正态分布的采样结果,$-\boldsymbol{\varepsilon}$同样服从标准正态分布,所以这里其实正负结果都等价的。
December 4th, 2024
苏老师,我想向您确认一下我的理解,公式(19)中考虑线性解的意思,是否就是$f_t(x)$直接换成了$f_t(x)=f_t\cdot x$?因为一开始我理解$f_t(\cdot)$是一个函数,这里却直接当作线性系数与$x$相乘了,一时间没转换过来.
在下一篇看到解释了,确实是这样的!
$\boldsymbol{f}_t(\boldsymbol{x})$是一个函数,后面计算的时候,则取$\boldsymbol{f}_t(\boldsymbol{x})=f_t\boldsymbol{x}$这个特例。
December 13th, 2024
请问苏神,(5)代入(4)怎么配方得到6的
就直接代入配方啊,没什么技巧。或者你从$(6)$出发反推了。