生成扩散模型漫谈(十九):作为扩散ODE的GAN
By 苏剑林 | 2023-06-24 | 33817位读者 |在文章《生成扩散模型漫谈(十六):W距离 ≤ 得分匹配》中,我们推导了Wasserstein距离与扩散模型得分匹配损失之间的一个不等式,表明扩散模型的优化目标与WGAN的优化目标在某种程度上具有相似性。而在本文,我们将探讨《MonoFlow: Rethinking Divergence GANs via the Perspective of Wasserstein Gradient Flows》中的研究成果,它进一步展示了GAN与扩散模型之间的联系:GAN实际上可以被视为在另一个时间维度上的扩散ODE!
这些发现表明,尽管GAN和扩散模型表面上是两种截然不同的生成式模型,但它们实际上存在许多相似之处,并在许多方面可以相互借鉴和参考。
思路简介 #
我们知道,GAN所训练的生成器是从噪声$\boldsymbol{z}$到真实样本的一个直接的确定性变换$\boldsymbol{g}_{\boldsymbol{\theta}}(\boldsymbol{z})$,而扩散模型的显著特点是“渐进式生成”,它的生成过程对应于从一系列渐变的分布$p_0(\boldsymbol{x}_0),p_1(\boldsymbol{x}_1),\cdots,p_T(\boldsymbol{x}_T)$中采样(注:在前面十几篇文章中,$\boldsymbol{x}_T$是噪声,$\boldsymbol{x}_0$是目标样本,采样过程是$\boldsymbol{x}_T\to \boldsymbol{x}_0$,但为了便于下面的表述,这里反过来改为$\boldsymbol{x}_0\to \boldsymbol{x}_T$)。看上去确实找不到多少相同之处,那怎么才能将两者联系起来呢?
很明显,如果想要从扩散模型的视角理解GAN,那么就要想办法构造出一系列渐变的分布出来。生成器$\boldsymbol{g}_{\boldsymbol{\theta}}(\boldsymbol{z})$本身就是一个一步到位的变换,不存在渐变,然而我们知道模型的优化是渐变的,可否用参数$\boldsymbol{\theta}$的历史轨迹$\boldsymbol{\theta}_t$来构建这一系列渐变分布呢?具体来说,假设生成器初始化为$\boldsymbol{\theta}_0$,经过$T$步对抗训练后得到最优参数$\boldsymbol{\theta}_T$,训练过程的中间参数为$\boldsymbol{\theta}_1,\boldsymbol{\theta}_2,\cdots,\boldsymbol{\theta}_{T-1}$,那么我们定义$\boldsymbol{x}_t = \boldsymbol{g}_{\boldsymbol{\theta}_t}(\boldsymbol{z})$,不就定义了一系列渐变的$\boldsymbol{x}_0,\boldsymbol{x}_1,\cdots,\boldsymbol{x}_T$,从而也就定义了渐变的分布$p_0(\boldsymbol{x}_0),p_1(\boldsymbol{x}_1),\cdots,p_T(\boldsymbol{x}_T)$了?
如果这个思路可行的话,那么GAN就可以诠释为梯度下降的(虚拟)时间维度上的扩散模型!下面我们就沿着这个思路进行探索。
梯度之流 #
首先,我们需要重温上一篇文章《梯度流:探索通向最小值之路》关于Wasserstein梯度流的结果:它指出方程
\begin{equation}\frac{\partial q_t(\boldsymbol{x})}{\partial t} = - \nabla_{\boldsymbol{x}}\cdot\big(q_t(\boldsymbol{x})\nabla_{\boldsymbol{x}}\log r_t(\boldsymbol{x})\big)\label{eq:w-flow}\end{equation}
在最小化$p(\boldsymbol{x})$和$q_t(\boldsymbol{x})$的KL散度,即$\lim\limits_{t\to\infty} q_t(\boldsymbol{x}) = p(\boldsymbol{x})$,这里$r_t(\boldsymbol{x})=\frac{p(\boldsymbol{x})}{q_t(\boldsymbol{x})}$。如果$p(\boldsymbol{x})$代表真实样本的分布,那么如果能实现从$q_t(\boldsymbol{x})$采样的话,那么逐渐推到$t\to\infty$时,就可以实现从$p(\boldsymbol{x})$采样了。根据《测试函数法推导连续性方程和Fokker-Planck方程》,从$q_t(\boldsymbol{x})$采样可以通过下述ODE实现:
\begin{equation}\frac{d\boldsymbol{x}}{dt} = \nabla_{\boldsymbol{x}}\log r_t(\boldsymbol{x})\label{eq:ode-core}\end{equation}
然而,上式中的$r_t(\boldsymbol{x})$是未知的,所以我们还无法通过上式进行采样,需要先想办法估算$r_t(\boldsymbol{x})$。
判别估计 #
这时候登场的是GAN的判别器。以最早的Vanilla GAN为例,它的训练目标是
\begin{equation}\max_D\, \mathbb{E}_{\boldsymbol{x}\sim p(\boldsymbol{x})}[\log \sigma(D(\boldsymbol{x}))] + \mathbb{E}_{\boldsymbol{x}\sim q(\boldsymbol{x})}[\log (1 - \sigma(D(\boldsymbol{x})))]\label{eq:gan-d}\end{equation}
这里的$D$是判别器,$\sigma(t)=1/(1+e^{-t})$是sigmoid函数,$p(\boldsymbol{x})$是真样本的分布,$q(\boldsymbol{x})$是假样本的分布。可以证明(不清楚的读者可以参考《RSGAN:对抗模型中的“图灵测试”思想》中的“补充证明”一节),上式中判别器$D$的理论最优解是
\begin{equation}D(\boldsymbol{x}) = \log \frac{p(\boldsymbol{x})}{q(\boldsymbol{x})}\end{equation}
更一般化的f-GAN(参考《f-GAN简介:GAN模型的生产车间》、《Designing GANs:又一个GAN生产车间》)结果会稍有不同,但可以证明的是它们判别器的理论最优解都是$\frac{p(\boldsymbol{x})}{q(\boldsymbol{x})}$的函数。也就是说,只要我们可以实现从$p(\boldsymbol{x})$和$q_t(\boldsymbol{x})$中采样,那么通过GAN的判别器训练$\eqref{eq:gan-d}$就可以估算出$r_t(\boldsymbol{x})=\frac{p(\boldsymbol{x})}{q_t(\boldsymbol{x})}$出来。
向前一步 #
这时候可能有读者疑惑:这不就进入“鸡生蛋、蛋生鸡”的循环论证了吗?我们估算$r_t(\boldsymbol{x})$不就是为了利用式$\eqref{eq:ode-core}$实现从$q_t(\boldsymbol{x})$中采样吗?现在你又假设能从$q_t(\boldsymbol{x})$采样才来估算$r_t(\boldsymbol{x})$?不着急,经典的一笔就要来了。
假设我们有生成器$\boldsymbol{g}_{\boldsymbol{\theta}_t}(\boldsymbol{z})$,它的采样生成结果就等于从$q_t(\boldsymbol{x})$采样的结果,即
\begin{equation}\big\{\boldsymbol{g}_{\boldsymbol{\theta}_t}(\boldsymbol{z})\big|\,\boldsymbol{z}\sim \mathcal{N}(\boldsymbol{0},\boldsymbol{I})\big\}\quad = \quad\big\{\boldsymbol{x}_t\big|\,\boldsymbol{x}_t\sim q_t(\boldsymbol{x})\big\}\end{equation}
那么现在我们就可以利用它和式$\eqref{eq:gan-d}$来估算$r_t(\boldsymbol{x})$。注意这只是$t$时刻的$r_t(\boldsymbol{x})$,其他时刻的$r_t(\boldsymbol{x})$我们并不知道,所以无法直接通过式$\eqref{eq:ode-core}$完成最终的采样过程,但是我们可以往前推一小步:
\begin{equation}\boldsymbol{x}_{t+1} = \boldsymbol{x}_t + \epsilon \nabla_{\boldsymbol{x}_t}\log r_t(\boldsymbol{x}_t) = \boldsymbol{x}_t + \epsilon \nabla_{\boldsymbol{x}_t} D(\boldsymbol{x}_t)\label{eq:forward}\end{equation}
这里的$\epsilon$是一个很小的正数,代表步长。那么,现在我们就有了下一步采样的结果,我们希望它继续能等价于下一步的生成器的采样结果,即
\begin{equation}\begin{aligned}
\big\{\boldsymbol{g}_{\boldsymbol{\theta}_{t+1}}(\boldsymbol{z})\big|\,\boldsymbol{z}\sim \mathcal{N}(\boldsymbol{0},\boldsymbol{I})\big\}\quad =& \quad\big\{\boldsymbol{x}_{t+1}\big|\,\boldsymbol{x}_{t+1}\sim q_{t+1}(\boldsymbol{x})\big\} \\[5pt]
\quad =& \quad\big\{\boldsymbol{x}_{t+1}\big|\,\boldsymbol{x}_t + \epsilon \nabla_{\boldsymbol{x}_t} D(\boldsymbol{x}_t),\boldsymbol{x}_t\sim q_t(\boldsymbol{x})\big\}
\end{aligned}\end{equation}
换句话说,我们想要将扩散模型中样本的运动转化为生成器参数的运动!为了达到这个目标,我们通过如下损失去求$\boldsymbol{\theta}_{t+1}$:
\begin{equation}\boldsymbol{\theta}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{\theta}}\mathbb{E}_{\boldsymbol{z}\sim \mathcal{N}(\boldsymbol{0},\boldsymbol{I})}\Big[\big\Vert \boldsymbol{g}_{\boldsymbol{\theta}}(\boldsymbol{z}) - \boldsymbol{g}_{\boldsymbol{\theta}_t}(\boldsymbol{z}) - \epsilon \nabla_{\boldsymbol{g}}D(\boldsymbol{g}_{\boldsymbol{\theta}_t}(\boldsymbol{z}))\big\Vert^2\Big]\label{eq:gan-g0}\end{equation}
也就是说,拿$\boldsymbol{x}_t = \boldsymbol{g}_{\boldsymbol{\theta}_t}(\boldsymbol{z})$往前迭代一步得到$\boldsymbol{x}_{t+1}$,然后希望新的$\boldsymbol{g}_{\boldsymbol{\theta}_{t+1}}(\boldsymbol{z})$能尽量等于$\boldsymbol{x}_{t+1}$。完成这一轮后,再用$\boldsymbol{\theta}_{t+1}$替代原本的$\boldsymbol{\theta}_t$开始新一轮的迭代,也就是式$\eqref{eq:gan-d}$和式$\eqref{eq:gan-g0}$交替执行,是不是就有GAN的味道了?
点睛之笔 #
如果这还不够,我们还可以继续完善一下,将它变得跟GAN更加一致。注意到式$\eqref{eq:gan-g0}$的被期望函数的梯度是:
\begin{equation}\begin{aligned}
&\,\nabla_{\boldsymbol{\theta}}\Vert \boldsymbol{g}_{\boldsymbol{\theta}}(\boldsymbol{z}) - \boldsymbol{g}_{\boldsymbol{\theta}_t}(\boldsymbol{z}) - \epsilon \nabla_{\boldsymbol{g}}D(\boldsymbol{g}_{\boldsymbol{\theta}_t}(\boldsymbol{z}))\Vert^2 \\
=&\,2\big\langle\boldsymbol{g}_{\boldsymbol{\theta}}(\boldsymbol{z}) - \boldsymbol{g}_{\boldsymbol{\theta}_t}(\boldsymbol{z}) - \epsilon \nabla_{\boldsymbol{g}}D(\boldsymbol{g}_{\boldsymbol{\theta}_t}(\boldsymbol{z})), \nabla_{\boldsymbol{\theta}}\boldsymbol{g}_{\boldsymbol{\theta}}(\boldsymbol{z}) \big\rangle \\
\end{aligned}\end{equation}
代入当前值$\boldsymbol{\theta}=\boldsymbol{\theta}_t$,那么结果是
\begin{equation}-2\epsilon\big\langle \nabla_{\boldsymbol{g}}D(\boldsymbol{g}_{\boldsymbol{\theta}_t}(\boldsymbol{z})), \nabla_{\boldsymbol{\theta}_t}\boldsymbol{g}_{\boldsymbol{\theta}_t}(\boldsymbol{z}) \big\rangle = -2\epsilon\nabla_{\boldsymbol{\theta}_t}D(\boldsymbol{g}_{\boldsymbol{\theta}_t}(\boldsymbol{z}))\end{equation}
也就是说,如果用基于梯度的优化器只优化一步的话,那么以式$\eqref{eq:gan-g0}$为损失函数,跟以下式为损失函数,结果是等价的(因为梯度只差一个常数倍):
\begin{equation}\boldsymbol{\theta}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{\theta}}\mathbb{E}_{\boldsymbol{z}\sim \mathcal{N}(\boldsymbol{0},\boldsymbol{I})}[-D(\boldsymbol{g}_{\boldsymbol{\theta}}(\boldsymbol{z}))]\label{eq:gan-g}\end{equation}
这便是常见的生成器损失之一。式$\eqref{eq:gan-d}$和式$\eqref{eq:gan-g}$交替训练,就是一个常见的GAN变体。
特别地,原论文还证明了生成器的损失函数可以一般化为
\begin{equation}\boldsymbol{\theta}_{t+1} = \mathop{\text{argmin}}_{\boldsymbol{\theta}}\mathbb{E}_{\boldsymbol{z}\sim \mathcal{N}(\boldsymbol{0},\boldsymbol{I})}[-h(D(\boldsymbol{g}_{\boldsymbol{\theta}}(\boldsymbol{z})))]\end{equation}
其中$h(\cdot)$是任意单调递增函数,它也对应于Wasserstein梯度流$\eqref{eq:w-flow}$中的$\log r_t(\boldsymbol{x})$可以换成$h(\log r_t(\boldsymbol{x}))$,这应该也是MonoFlow一词的来源(Monotonically increasing function + Wasserstein flow)。这个证明过程就不展开了,大家自行看原论文就好。
意义思考 #
总的来说,将GAN理解为扩散模型的思路是
$$\require{AMScd}\begin{CD}
\cdots @>\quad>> \boldsymbol{g}_{\boldsymbol{\theta}_t}(\boldsymbol{z}) @> 式\eqref{eq:gan-d} >> r_t(\boldsymbol{x}) @> 式\eqref{eq:forward}>> \boldsymbol{x}_{t+1} @> 式\eqref{eq:gan-g0}>> \boldsymbol{g}_{\boldsymbol{\theta}_{t+1}}(\boldsymbol{z})@>\quad>>\cdots
\end{CD}$$
其中,核心的式子是$\eqref{eq:forward}$,它源于Wasserstein梯度流的式$\eqref{eq:w-flow}$和式$\eqref{eq:ode-core}$,这部分我们在上一篇文章《梯度流:探索通向最小值之路》讨论过了。
可能有读者想问:这个视角看上去并没有得到比GAN更多的东西,为什么要费这番大功夫去重新理解GAN呢?首先,在笔者看来,从扩散模型角度理解GAN,或者说将扩散模型和GAN统一起来,它本身就是一件很有趣、很好玩的事情,并不一定需要发挥什么实际作用,有趣、好玩就是它最大的意义。
其次,如同作者在原论文所说,已有的GAN的推导过程跟它实际的训练过程是不一致的,而本文所讨论的扩散视角,则是跟训练过程是一致的。也就是说,以训练过程为标准的话,GAN已有的推导过程是错的,本文的扩散视角才是对的。怎么理解这一点呢?以前面提到的GAN为例,判别器和生成器的目标分别是:
\begin{gather}\max_D\, \mathbb{E}_{\boldsymbol{x}\sim p(\boldsymbol{x})}[\log \sigma(D(\boldsymbol{x}))] + \mathbb{E}_{\boldsymbol{x}\sim q(\boldsymbol{x})}[\log (1 - \sigma(D(\boldsymbol{x})))] \\
\min_q\mathbb{E}_{\boldsymbol{x}\sim q(\boldsymbol{x})}[-D(\boldsymbol{x})]
\end{gather}
通常的证明方式是,证明$D$的最优解是$\log\frac{p(\boldsymbol{x})}{q(\boldsymbol{x})}$,然后代入生成器的损失函数,发现它在最小化$q(\boldsymbol{x}),p(\boldsymbol{x})$的KL散度,所以最优解是$q(\boldsymbol{x})=p(\boldsymbol{x})$。但是,这样的证明对应的训练方式应该是先针对任意的$q(\boldsymbol{x})$,将$\max\limits_D$这一步都求解出来(求出的$D$应该是$q(x)$的函数,或者说应该是生成器参数$\boldsymbol{\theta}$的函数),然后再去执行$\min\limits_q$这一步,而不是实际上用的交替训练。而基于扩散模型的理解,它在设计上就是交替的,所以它跟训练过程更加一致。
总的来说,从扩散模型的角度来理解GAN,不单单是理解GAN的一种新视角,而且还是一种更贴近训练过程的视角。比如,我们可以解释为什么GAN的生成器不能训练太多步,因为只有单步优化时,式$\eqref{eq:gan-g}$和式$\eqref{eq:gan-g0}$才等价,如果GAN要进行更多步的优化,那么应该使用式$\eqref{eq:gan-g0}$为损失函数。事实上,式$\eqref{eq:gan-g0}$就相当于笔者之前在《用变分推断统一理解生成模型(VAE、GAN、AAE、ALI)》所提出的$KL\left(q(x)\Vert q^{o}(x)\right)$项,它保证了生成器的“传承”而不仅仅是“创新”。
文章小结 #
本文介绍了MonoFlow,它展示了我们可以将GAN理解为在另一个时间维度上的扩散ODE,从而建立了一种基于扩散模型理解GAN的新视角。特别地,这是一种比GAN的常规推导更加贴近训练过程的视角。
转载到请包括本文地址:https://kexue.fm/archives/9662
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Jun. 24, 2023). 《生成扩散模型漫谈(十九):作为扩散ODE的GAN 》[Blog post]. Retrieved from https://kexue.fm/archives/9662
@online{kexuefm-9662,
title={生成扩散模型漫谈(十九):作为扩散ODE的GAN},
author={苏剑林},
year={2023},
month={Jun},
url={\url{https://kexue.fm/archives/9662}},
}
June 24th, 2023
[...]Read More [...]
June 25th, 2023
我想了想,既然都从轨迹的角度来看待了,能不能用卡尔曼滤波来加速扩散模型的收敛呢,于是我在Reflow的基础上把目标改成$v(x_1, t_1)=v_{ema}(x_2, t_2)+t_2 \cdot (v - v_{ema}(x_2, t_2))$,没想到还真能大幅加速模型收敛的速度,这里就相当于用了最简单的g-h 滤波器,那么理论上GAN也可以这样。
大佬最近我也在看cm,有没有兴趣留个联系方式交流一下
可以加个Discord:amun_alpha
恭喜!惭愧,这个概念已经超出我的理解范围了(捂脸)~
请问大佬说的是RectifiedFlow里面提出的那个Reflow模型吗?这里ema的权重是怎么确定的呢?
ema就只是模型的滑动平均啊,我这里用的0.95的平滑系数
想问下o_glay大佬,为啥用卡尔曼滤波能加速扩散模型的收敛呢?and为啥既然都从轨迹的角度来看待了就可以用卡尔曼滤波嘞?
June 27th, 2023
看了作者的词向量相关的文章,很感兴趣。我也有一个问题,希望作者赐教。
词向量和句向量,除了可以进行相似度以及距离等计算。
我这里关于词和其它词以及(句子、以及多个词和句子的组合)之间的结构化关系,词向量等怎样表达呢?
一般情况下,词向量没有这种功能。你可能需要参考ON-LSTM:https://kexue.fm/archives/6621
June 27th, 2023
awesome!!
June 28th, 2023
[...]上一篇文章《生成扩散模型漫谈(十九):作为扩散ODE的GAN》中,我们介绍了如何将GAN理解为在另一个时间维度上的扩散ODE,简而言之,GAN实际上就是将扩散模型中样本的运动转化为生成器参数的运动!然而,该文章的推导过程依赖于Wasserstein梯度流等相对复杂和独立的内容,没法很好地跟扩散系列前面的文章连接起来,技术上显得有些“断层”。[...]
August 28th, 2023
公式(1)右边的q好像缺了下标t。
还有就是 “鸡生蛋、蛋生鸡”的循环论证 这里的操作是不是能量模型里面挺常见的?
谢谢指出,已经修正。从能量模型的角度来看, 能量模型的似然函数的梯度确实呈现出“鸡生蛋、蛋生鸡”形式,确实算是常见了。