变分自编码器(五):VAE + BN = 更好的VAE
By 苏剑林 | 2020-05-06 | 210462位读者 |本文我们继续之前的变分自编码器系列,分析一下如何防止NLP中的VAE模型出现“KL散度消失(KL Vanishing)”现象。本文受到参考文献是ACL 2020的论文《A Batch Normalized Inference Network Keeps the KL Vanishing Away》的启发,并自行做了进一步的完善。
值得一提的是,本文最后得到的方案还是颇为简洁的——只需往编码输出加入BN(Batch Normalization),然后加个简单的scale——但确实很有效,因此值得正在研究相关问题的读者一试。同时,相关结论也适用于一般的VAE模型(包括CV的),如果按照笔者的看法,它甚至可以作为VAE模型的“标配”。
最后,要提醒读者这算是一篇VAE的进阶论文,所以请读者对VAE有一定了解后再来阅读本文。
VAE简单回顾 #
这里我们简单回顾一下VAE模型,并且讨论一下VAE在NLP中所遇到的困难。关于VAE的更详细介绍,请读者参考笔者的旧作《变分自编码器(一):原来是这么一回事》、《变分自编码器(二):从贝叶斯观点出发》等。
VAE的训练流程 #
VAE的训练流程大概可以图示为
写成公式就是
$$\begin{equation}\mathcal{L} = \mathbb{E}_{x\sim \tilde{p}(x)} \Big[\mathbb{E}_{z\sim p(z|x)}\big[-\log q(x|z)\big]+KL\big(p(z|x)\big\Vert q(z)\big)\Big]
\end{equation}$$
其中第一项就是重构项,$\mathbb{E}_{z\sim p(z|x)}$是通过重参数来实现;第二项则称为KL散度项,这是它跟普通自编码器的显式差别,如果没有这一项,那么基本上退化为常规的AE。更详细的符号含义可以参考《变分自编码器(二):从贝叶斯观点出发》。
NLP中的VAE #
在NLP中,句子被编码为离散的整数ID,所以$q(x|z)$是一个离散型分布,可以用万能的“条件语言模型”来实现,因此理论上$q(x|z)$可以精确地拟合生成分布,问题就出在$q(x|z)$太强了,训练时重参数操作会来噪声,噪声一大,$z$的利用就变得困难起来,所以它干脆不要$z$了,退化为无条件语言模型(依然很强),$KL(p(z|x)\Vert q(z))$则随之下降到0,这就出现了KL散度消失现象。
这种情况下的VAE模型并没有什么价值:KL散度为0说明编码器输出的是常数向量,而解码器则是一个普通的语言模型。而我们使用VAE通常来说是看中了它无监督构建编码向量的能力,所以要应用VAE的话还是得解决KL散度消失问题。事实上从2016开始,有不少工作在做这个问题,相应地也提出了很多方案,比如退火策略、更换先验分布等,读者Google一下“KL Vanishing”就可以找到很多文献了,这里不一一溯源。
BN的巧与妙 #
本文的方案则是直接针对KL散度项入手,简单有效而且没什么超参数。其思想很简单:
KL散度消失不就是KL散度项变成0吗?我调整一下编码器输出,让KL散度有一个大于零的下界,这样它不就肯定不会消失了吗?
这个简单的思想的直接结果就是:在$\mu$后面加入BN层,如图
推导过程简述 #
为什么会跟BN联系起来呢?我们来看KL散度项的形式:
\begin{equation}\mathbb{E}_{x\sim\tilde{p}(x)}\left[KL\big(p(z|x)\big\Vert q(z)\big)\right] = \frac{1}{b} \sum_{i=1}^b \sum_{j=1}^d \frac{1}{2}\Big(\mu_{i,j}^2 + \sigma_{i,j}^2 - \log \sigma_{i,j}^2 - 1\Big)\end{equation}
上式是采样了$b$个样本进行计算的结果,而编码向量的维度则是$d$维。由于我们总是有$e^x \geq x + 1$,所以$\sigma_{i,j}^2 - \log \sigma_{i,j}^2 - 1 \geq 0$,因此
\begin{equation}\mathbb{E}_{x\sim\tilde{p}(x)}\left[KL\big(p(z|x)\big\Vert q(z)\big)\right] \geq \frac{1}{b} \sum_{i=1}^b \sum_{j=1}^d \frac{1}{2}\mu_{i,j}^2 = \frac{1}{2}\sum_{j=1}^d \left(\frac{1}{b} \sum_{i=1}^b \mu_{i,j}^2\right)\label{eq:kl}\end{equation}
留意到括号里边的量,其实它就是$\mu$在batch内的二阶矩,如果我们往$\mu$加入BN层,那么大体上可以保证$\mu$的均值为$\beta$,方差为$\gamma^2$($\beta,\gamma$是BN里边的可训练参数),这时候
\begin{equation}\mathbb{E}_{x\sim\tilde{p}(x)}\left[KL\big(p(z|x)\big\Vert q(z)\big)\right] \geq \frac{d}{2}\left(\beta^2 + \gamma^2\right)\label{eq:kl-lb}\end{equation}
所以只要控制好$\beta,\gamma$(主要是固定$\gamma$为某个常数),就可以让KL散度项有个正的下界,因此就不会出现KL散度消失现象了。这样一来,KL散度消失现象跟BN就被巧妙地联系起来了,通过BN来“杜绝”了KL散度消失的可能性。
为什么不是LN? #
善于推导的读者可能会想到,按照上述思路,如果只是为了让KL散度项有个正的下界,其实LN(Layer Normalization)也可以,也就是在式$\eqref{eq:kl}$中按$j$那一维归一化。
那为什么用BN而不是LN呢?
这个问题的答案也是BN的巧妙之处。直观来理解,KL散度消失是因为$z\sim p(z|x)$的噪声比较大,解码器无法很好地辨别出$z$中的非噪声成分,所以干脆弃之不用;而当给$\mu(x)$加上BN后,相当于适当地拉开了不同样本的$z$的距离,使得哪怕$z$带了噪声,区分起来也容易一些,所以这时候解码器乐意用$z$的信息,因此能缓解这个问题;相比之下,LN是在样本内进的行归一化,没有拉开样本间差距的作用,所以LN的效果不会有BN那么好。
进一步的结果 #
事实上,原论文的推导到上面基本上就结束了,剩下的都是实验部分,包括通过实验来确定$\gamma$的值。然而,笔者认为目前为止的结论还有一些美中不足的地方,比如没有提供关于加入BN的更深刻理解,倒更像是一个工程的技巧,又比如只是$\mu(x)$加上了BN,$\sigma(x)$没有加上,未免有些不对称之感。
经过笔者的推导,发现上面的结论可以进一步完善。
联系到先验分布 #
对于VAE来说,它希望训练好后的模型的隐变量分布为先验分布$q(z)=\mathcal{N}(z;0,1)$,而后验分布则是$p(z|x)=\mathcal{N}(z; \mu(x),\sigma^2(x))$,所以VAE希望下式成立:
\begin{equation}q(z) = \int \tilde{p}(x)p(z|x)dx=\int \tilde{p}(x)\mathcal{N}(z; \mu(x),\sigma^2(x))dx\end{equation}
两边乘以$z$,并对$z$积分,得到
\begin{equation}0 = \int \tilde{p}(x)\mu(x)dx=\mathbb{E}_{x\sim \tilde{p}(x)}[\mu(x)]\end{equation}
两边乘以$z^2$,并对$z$积分,得到
\begin{equation}1 = \int \tilde{p}(x)\left[\mu^2(x) + \sigma^2(x)\right]dx = \mathbb{E}_{x\sim \tilde{p}(x)}\left[\mu^2(x)\right] + \mathbb{E}_{x\sim \tilde{p}(x)}\left[\sigma^2(x)\right]\end{equation}
如果往$\mu(x),\sigma(x)$都加入BN,那么我们就有
\begin{equation}\begin{aligned}
&0 = \mathbb{E}_{x\sim \tilde{p}(x)}[\mu(x)] = \beta_{\mu}\\
&1 = \mathbb{E}_{x\sim \tilde{p}(x)}\left[\mu^2(x)\right] + \mathbb{E}_{x\sim \tilde{p}(x)}\left[\sigma^2(x)\right] = \beta_{\mu}^2 + \gamma_{\mu}^2 + \beta_{\sigma}^2 + \gamma_{\sigma}^2
\end{aligned}\end{equation}
所以现在我们知道$\beta_{\mu}$一定是0,而如果我们也固定$\beta_{\sigma}=0$,那么我们就有约束关系:
\begin{equation}1 = \gamma_{\mu}^2 + \gamma_{\sigma}^2\label{eq:gamma2}\end{equation}
参考的实现方案 #
经过这样的推导,我们发现可以往$\mu(x),\sigma(x)$都加入BN,并且可以固定$\beta_{\mu}=\beta_{\sigma}=0$,但此时需要满足约束$\eqref{eq:gamma2}$。要注意的是,这部分讨论还仅仅是对VAE的一般分析,并没有涉及到KL散度消失问题,哪怕这些条件都满足了,也无法保证KL项不趋于0。结合式$\eqref{eq:kl-lb}$我们可以知道,保证KL散度不消失的关键是确保$\gamma_{\mu} > 0$,所以,笔者提出的最终策略是:
\begin{equation}\begin{aligned}
&\beta_{\mu}=\beta_{\sigma}=0\\
&\gamma_{\mu} = \sqrt{\tau + (1-\tau)\cdot\text{sigmoid}(\theta)}\\
&\gamma_{\sigma} = \sqrt{(1-\tau)\cdot\text{sigmoid}(-\theta)}
\end{aligned}\end{equation}
其中$\tau\in(0,1)$是一个常数,笔者在自己的实验中取了$\tau=0.5$,而$\theta$是可训练参数,上式利用了恒等式$\text{sigmoid}(-\theta) = 1-\text{sigmoid}(\theta)$。
关键代码参考(Keras):
class Scaler(Layer):
"""特殊的scale层
"""
def __init__(self, tau=0.5, **kwargs):
super(Scaler, self).__init__(**kwargs)
self.tau = tau
def build(self, input_shape):
super(Scaler, self).build(input_shape)
self.scale = self.add_weight(
name='scale', shape=(input_shape[-1],), initializer='zeros'
)
def call(self, inputs, mode='positive'):
if mode == 'positive':
scale = self.tau + (1 - self.tau) * K.sigmoid(self.scale)
else:
scale = (1 - self.tau) * K.sigmoid(-self.scale)
return inputs * K.sqrt(scale)
def get_config(self):
config = {'tau': self.tau}
base_config = super(Scaler, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def sampling(inputs):
"""重参数采样
"""
z_mean, z_std = inputs
noise = K.random_normal(shape=K.shape(z_mean))
return z_mean + z_std * noise
e_outputs # 假设e_outputs是编码器的输出向量
scaler = Scaler()
z_mean = Dense(hidden_dims)(e_outputs)
z_mean = BatchNormalization(scale=False, center=False, epsilon=1e-8)(z_mean)
z_mean = scaler(z_mean, mode='positive')
z_std = Dense(hidden_dims)(e_outputs)
z_std = BatchNormalization(scale=False, center=False, epsilon=1e-8)(z_std)
z_std = scaler(z_std, mode='negative')
z = Lambda(sampling, name='Sampling')([z_mean, z_std])
文章内容小结 #
本文简单分析了VAE在NLP中的KL散度消失现象,并介绍了通过BN层来防止KL散度消失、稳定训练流程的方法。这是一种简洁有效的方案,不单单是原论文,笔者私下也做了简单的实验,结果确实也表明了它的有效性,值得各位读者试用。因为其推导具有一般性,所以甚至任意场景(比如CV)中的VAE模型都可以尝试一下。
转载到请包括本文地址:https://kexue.fm/archives/7381
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (May. 06, 2020). 《 变分自编码器(五):VAE + BN = 更好的VAE 》[Blog post]. Retrieved from https://kexue.fm/archives/7381
@online{kexuefm-7381,
title={ 变分自编码器(五):VAE + BN = 更好的VAE},
author={苏剑林},
year={2020},
month={May},
url={\url{https://kexue.fm/archives/7381}},
}
April 8th, 2021
苏神好,想请教您,从encoder采样得到的latent转为hidden之后,应该作为decoder的q还是k、v?另外就是encoder-transformer的输出应该用cls-pooling还是average-pooling来获得encoder的输出编码呢?
感谢回复。另:之前参考过您写诗的vae(github:bojone/vae/vae_shi.py),不知道您transformer版本的vae是否有在github上分享?
1、如果是Transformer结构,encoder的latent我是作为decoder的Conditional Layer Normalization的condition输入的;
2、cls-pooling还是average-pooling你随意,差别不大;
3、暂时没有
April 15th, 2021
请问 AMS-setcounter.js 我没有在 MathJax2.7.4中找到;这个文件是在哪里的?关于贵站mathjax的配置,有点疑惑……
我自己写的。
May 6th, 2021
不好意思,请教苏神一个简单的问题,p~(x)表示的是x的先验分布吗?和p(x)的差别是?谢谢苏神啦~
p~(x)指的是p上面有个波浪线
$\tilde{p}(x)$是数据分布,也就是训练集。本文好像没出现$p(x)$?
好的谢谢苏神呢:)
November 30th, 2021
博主您好,感谢您的分享。我有一个疑问是: VAE的后验坍塌和KL散度消失是什么样的关系?KL散度消失一定是后验坍塌导致的?
广义上来说,VAE的后验坍塌指的是后验分布完全跟输入$x$没关系,而KL散度消失指的是后验分布跟先验分布一致(也就是均值为0、方差为1),后者是前者的一个特例;
狭义上来说,两者就是一样的概念。
December 1st, 2021
苏神你好,我想生成表格型数据,GAN我也试过感觉效果不太行,VAE分布挺像的,但是还是误差稍大,因为原始数据的波动较小,所以问一下苏神有什么推荐的算法可以试试吗?
什么是表格型数据?看你的说法,似乎是VAE和GAN都行,但就是不大准?这种情况下就是一个调参问题,别人无法帮助,你就算换一个生成模型还是一样要自己调参。
July 28th, 2022
苏神好,能否请你详细解释下"第二项则称为KL散度项,这是它跟普通自编码器的显式差别,如果没有这一项,那么基本上退化为常规的AE"这句话的含义呢? 经过之前的vae阅读,我知道encoder的作用是计算隐变量$z$的均值和方差的,如果KL散度为0话,那代表了无论给什么样的$x$,计算出的均值都会是0,方差都是1。 这就达不到vae'给每个不同的$x_i$计算对应的$μ_i$和$σ_i$的作用了'。 但你说他基本退化为常规AE,在vae的模型下,encoder仅输出$μ$和$σ$(应该是这样吧),如果均值始终为0和方差始终为1的话,那对于不同的$x_i$就体现不出差别了,相当于encoder失效了,感觉这样模型就没用了; 常规AE的encoder应该还能根据不同的$x$输出不同的特征用于解码吧? 不知道我理解的有没有错误。
此外对于这一句"问题就出在$q(x|z)$太强了,训练时重参数操作会来噪声,噪声一大,$z$的利用就变得困难起来,所以它干脆不要$z$了"也不是很理解。 decoder太强为什么就会导致后验$p(z|x)$等于先验$p(z)$呢? 按道理说先验的方差为1不是会使噪声更大么,反而是如果通过训练使得方差趋近于0,就能让噪声变小啊?
没有KL散度项,不是说KL散度等于0呀,而是说的是KL散度的权重为0,即
$$L = L_{\text{重构}} + 0\times L_{\text{KL}}$$
所以,刚刚相反,KL散度本身可以任意大,对应于$z$的方差为0,从而退化为AE。
第一个问题理解了,但对于第二个问题还是不太懂
你说的第二个问题是“decoder太强为什么就会导致后验$p(z|x)$等于先验$p(z)$”?这不就是一个问题么?
首先,你要搞清楚,训练的目标是损失函数最小化,不是凭空想象出来的噪声最小还是最大,我们要关心损失函数的情况对应的结果。如果decoder很强,意味着decoder不需要条件输入,都可以将重构损失降得很小。既然如此,模型干脆可以将KL散度项优化到等于0,这就意味着$p(z|x)=q(z)$。
September 5th, 2022
苏老师您好,我想问一下我自己在训练VAE时batch为100,隐变量维数为64,得到的kl值为0.002这样,我想问这个KLloss算是后验坍塌了吗,重构损失也挺大的,出来的结果也不好,我不知道是网络结构的问题还是后验坍塌了,希望苏老师能解答一下
看上去是算的。
September 6th, 2022
感谢苏神!我用了kl_warmup调了很多次warm_up的速度都无法解决这个问题,我想试试其他的预防kl散度消失的算法,另外,我想请问一下苏神,vae用神经网络去近似分布,需要的数据量会比普通的神经网络多吗,我想生成的是1维的时序数据,我用的网络结构是nvae中的网络结构,隐变量层数是3层,编码器和解码器就是nvae中的残差网络,我想生成1024维的时序数据,如果用800个数据量能达到好的生成效果吗
800个估计没什么好效果吧,别忘了VAE本质上是估计分布,800个数据点就想确定高维空间的分布了?
好的,谢谢苏神!
September 21st, 2022
请问苏神,后验消失是指KL值直接为0,还是接近0也可以认为是后验消失?
过于接近于0也是消失。
September 21st, 2022
苏神您好,我想请问一下hierarchical的隐变量除了第一个group的隐变量先验是标准高斯分布的可以用上面的bn方法,但是其他group的隐变量先验是未知的,是不是无法使用您提出的那种bn的方法
顶多是“均值方差未知”的正态分布吧,不是同样可以转为标准正态分布么?
苏神,你好,关于多层隐变量我有两个问题想问一下:
1、您说的可以转成标准正态分布是直接在式6和式7中减去均值和除以方差来让等式左边变为0和1吗,但是两个普通高斯分布的kl散度和其中一个分布是标准正态分布的kl散度的式子不一样,我求了之后发现在式4中会多出来一项
2、我对第一层先验是标准高斯的隐变量用了您提出的BN的方法,在其他层先验未知的隐变量用了论文中手动调节gamma的方法,我发现随着训练,其他层的kl散度会变得越来越大,我不知道该怎么解释,另外,如果只用BN来解决kl vanishing的话,还需要在kl散度前加上小于1的权重来保持和重构误差一样的数量级吗,或者还需要加上kl_warmup吗?
我理解你的hierarchical,应该是指从第二层开始,每一层的先验的均值方差依赖于前面若干层的结果吧。
根据这里 https://kexue.fm/archives/8512 结果,考虑协方差矩阵是对角阵的情形,有
$$KL(p(\boldsymbol{x})\Vert q(\boldsymbol{x}))=\frac{1}{2}\text{Tr}\left[(\boldsymbol{\mu}_p-\boldsymbol{\mu}_q)^2\boldsymbol{\sigma}_q^{-2} + \boldsymbol{\sigma}_q^{-2}\boldsymbol{\sigma}_p^2 -\log \boldsymbol{\sigma}_q^{-2}\boldsymbol{\sigma}_p^2 - 1\right] $$
这里的运算都是element-wise的。
如果$\boldsymbol{\sigma}_q$不是全1,那么如果想起到类似本文的效果,那么应该对$(\boldsymbol{\mu}_p-\boldsymbol{\mu}_q)\boldsymbol{\sigma}_q^{-1}$来加BN,也就是说新建一个子网络,然后加上BN,然后乘上$\boldsymbol{\sigma}_q$再加上$\boldsymbol{\mu}_q$作为$\boldsymbol{\mu}_p$,而不是直接用子网络去算$\boldsymbol{\mu}_p$。
谢谢苏神抽空回复我,我理解苏神的意思了,是不是还是只能用第一种调节gamma超参数的方法,我试了也发现第二种方法无法实现
哦哦,你是这个意思。我分析了一下,发现这确实是一个有意思的问题,我尝试一下有没有实现的可行性再来讨论哈
好的,谢谢苏神!我也尝试一下