变分自编码器(七):球面上的VAE(vMF-VAE)
By 苏剑林 | 2021-05-17 | 156204位读者 |在《变分自编码器(五):VAE + BN = 更好的VAE》中,我们讲到了NLP中训练VAE时常见的KL散度消失现象,并且提到了通过BN来使得KL散度项有一个正的下界,从而保证KL散度项不会消失。事实上,早在2018年的时候,就有类似思想的工作就被提出了,它们是通过在VAE中改用新的先验分布和后验分布,来使得KL散度项有一个正的下界。
该思路出现在2018年的两篇相近的论文中,分别是《Hyperspherical Variational Auto-Encoders》和《Spherical Latent Spaces for Stable Variational Autoencoders》,它们都是用定义在超球面的von Mises–Fisher(vMF)分布来构建先后验分布。某种程度上来说,该分布比我们常用的高斯分布还更简单和有趣~
KL散度消失 #
我们知道,VAE的训练目标是
L=Ex∼˜p(x)[Ez∼p(z|x)[−logq(x|z)]+KL(p(z|x)‖q(z))]
其中第一项是重构项,第二项是KL散度项,在《变分自编码器(一):原来是这么一回事》中我们就说过,这两项某种意义上是“对抗”的,KL散度项的存在,会加大解码器利用编码信息的难度,如果KL散度项为0,那么说明解码器完全没有利用到编码器的信息。
在NLP中,输入和重构的对象是句子,为了保证效果,解码器一般用自回归模型。然而,自回归模型是非常强大的模型,强大到哪怕没有输入,也能完成训练(退化为无条件语言模型),而刚才我们说了,KL散度项会加大解码器利用编码信息的难度,所以解码器干脆弃之不用,这就出现了KL散度消失现象。
早期比较常见的应对方案是逐渐增加KL项的权重,以引导解码器去利用编码信息。现在比较流行的方案就是通过某些改动,直接让KL散度项有一个正的下界。将先后验分布换为vMF分布,就是这种方案的经典例子之一。
vMF分布 #
vMF分布是定义在d−1维超球面的分布,其样本空间为Sd−1={x|x∈Rd,‖x‖=1},概率密度函数则为
p(x)=e⟨ξ,x⟩Zd,‖ξ‖,Zd,‖ξ‖=∫Sd−1e⟨ξ,x⟩dSd−1
其中ξ∈Rd是预先给定的参数向量。不难想象,这是Sd−1上一个以ξ为中心的分布,归一化因子写成Zd,‖ξ‖的形式,意味着它只依赖于ξ的模长,这是由于各向同性导致的。由于这个特性,vMF分布更常见的记法是设μ=ξ/‖ξ‖,κ=‖ξ‖,Cd,κ=1/Zd,‖ξ‖,从而
p(x)=Cd,κeκ⟨μ,x⟩
这时候⟨μ,x⟩就是μ,x的夹角余弦,所以说,vMF分布实际上就是以余弦相似度为度量的一种分布。由于我们经常用余弦值来度量两个向量的相似度,因此基于vMF分布做出来的模型,通常更能满足我们的这个需求。当κ=0的时候,vMF分布是球面上的均匀分布。
从归一化因子Zd,‖ξ‖的积分形式来看,它实际上也是vMF的母函数,从而vMF的各阶矩也可以通过Zd,‖ξ‖来表达,比如一阶矩为
Ex∼p(x)[x]=∇ξlogZd,‖ξ‖=dlogZd,‖ξ‖d‖ξ‖ξ‖ξ‖
可以看到Ex∼p(x)[x]在方向上跟ξ一致。Zd,‖ξ‖的精确形式可以算出来,但比较复杂,而且很多时候我们也不需要精确知道这个归一化因子,所以这里我们就不算了。
至于参数κ的含义,或许设τ=1/κ我们更好理解,此时p(x)∼e⟨μ,x⟩/τ,熟悉能量模型的同学都知道,这里的τ就是温度参数,如果τ越小(κ越大),那么分布就越集中在μ附近,反之则越分散(越接近球面上的均匀分布)。因此,κ也被形象地称为“凝聚度(concentration)”参数。
从vMF采样 #
对于vMF分布来说,需要解决的第一个难题是如何实现从它里边采样出具体的样本来。尤其是如果我们要将它应用到VAE中,那么这一步是至关重要的。
均匀分布 #
最简单是κ=0的情形,也就是d−1维球面上的均匀分布,因为标准正态分布本来就是各向同性的,其概率密度正比于e−‖x‖2/2只依赖于模长,所以我们只需要从d为标准正态分布中采样一个z,然后让x=z/‖z‖就得到了球面上的均匀采样结果。
特殊方向 #
接着,对于κ>0的情形,我们记x=[x1,x2,⋯,xd],首先考虑一种特殊的情况:μ=[1,0,⋯,0]。事实上,由于各向同性的原因,很多时候我们都只需要考虑这个特殊情况,然后就可以平行地推广到一般情形。
此时概率密度正比于eκx1,然后我们转换到球坐标系:
{x1=cosφ1x2=sinφ1cosφ2x3=sinφ1sinφ2cosφ3⋮xd−1=sinφ1⋯sinφd−2cosφd−1xd=sinφ1⋯sinφd−2sinφd−1
那么(超球坐标的积分变换,请直接参考“维基百科”)
eκx1dSd−1=eκcosφ1sind−2φ1sind−3φ2⋯sinφd−2dφ1dφ2⋯dφd−1=(eκcosφ1sind−2φ1dφ1)(sind−3φ2⋯sinφd−2dφ2⋯dφd−1)=(eκcosφ1sind−2φ1dφ1)dSd−2
这个分解表明,从该vMF分布中采样,等价于先从概率密度正比于eκcosφ1sind−2φ1的分布采样一个φ1,然后从d−2维超球面上均匀采样一个d−1维向量ε=[ε2,ε3,⋯,εd],通过如下方式组合成最终采样结果
x=[cosφ1,ε2sinφ1,ε3sinφ1,⋯,εdsinφ1]
设w=cosϕ1∈[−1,1],那么
|eκcosφ1sind−2φ1dφ1|=|eκw(1−w2)(d−3)/2dw|
所以我们主要研究从概率密度正比于eκw(1−w2)(d−3)/2的分布中采样。
然而,笔者所不理解的是,大多数涉及到vMF分布的论文,都采用了1994年的论文《Simulation of the von mises fisher distribution》提出的基于beta分布的拒绝采样方案,整个采样流程还是颇为复杂的。但现在都2021年了,对于一维分布的采样,居然还需要拒绝采样这么低效的方案?
事实上,对于任意一维分布p(w),设它的累积概率函数为Φ(w),那么w=Φ−1(ε),ε∼U[0,1]就是一个最方便通用的采样方案。可能有读者抗议说“累积概率函数不好算呀”、“它的逆函数更不好算呀”,但是在用代码实现采样的时候,我们压根就不需要知道Φ(w)长啥样,只要直接数值计算就行了,参考实现如下:
import numpy as np
def sample_from_pw(size, kappa, dims, epsilon=1e-7):
x = np.arange(-1 + epsilon, 1, epsilon)
y = kappa * x + np.log(1 - x**2) * (dims - 3) / 2
y = np.cumsum(np.exp(y - y.max()))
y = y / y[-1]
return np.interp(np.random.random(size), y, x)
这里的实现中,计算量最大的是变量y
的计算,而一旦计算好之后,可以缓存下来,之后只需要执行最后一步来完成采样,其速度是非常快的。这样再怎么看,也比从beta分布中拒绝采样要简单方便吧。顺便说,实现上这里还用到了一个技巧,即先计算对数值,然后减去最大值,最后才算指数,这样可以防止溢出,哪怕κ成千上万,也可以成功计算。
一般情形 #
现在我们已经实现了从μ=[1,0,⋯,0]的vMF分布中采样了,我们可以将采样结果分解为
x=w×[1,0,⋯,0]⏟参数向量μ+√1−w2×[0,ε2,⋯,εd]⏟与μ正交的d−2维超球面均匀采样
同样由于各向同性的原因,对于一般的μ,采样结果依然具有同样的形式:
x=wμ+√1−w2νw∼eκw(1−w2)(d−3)/2ν∼与μ正交的d−2维超球面均匀分布
对于ν的采样,关键之处是与μ正交,这也不难实现,先从标准正态分布中采样一个d维向量z,然后保留与μ正交的分量并归一化即可:
ν=ε−⟨ε,μ⟩μ‖ε−⟨ε,μ⟩μ‖,ε∼N(0,1d)
vMF-VAE #
至此,我们可谓是已经完成了本篇文章最艰难的部分,剩下的构建vMF-VAE可谓是水到渠成了。vMF-VAE选用球面上的均匀分布(κ=0)作为先验分布q(z),并将后验分布选取为vMF分布:
p(z|x)=Cd,κeκ⟨μ(x),z⟩
简单起见,我们将κ设为超参数(也可以理解为通过人工而不是梯度下降来更新这个参数),这样一来,p(z|x)的唯一参数来源就是μ(x)了。此时我们可以计算KL散度项
∫p(z|x)logp(z|x)q(z)dz=∫Cd,κeκ⟨μ(x),z⟩(κ⟨μ(x),z⟩+logCd,κ−logCd,0)dz=κ⟨μ(x),Ez∼p(z|x)[z]⟩+logCd,κ−logCd,0
前面我们已经讨论过,vMF分布的均值方向跟μ(x)一致,模长则只依赖于d和κ,所以代入上式后我们可以知道KL散度项只依赖于d和κ,当这两个参数被选定之后,那么它就是一个常数(根据KL散度的性质,当κ≠0时,它必然大于0),绝对不会出现KL散度消失现象了。
那么现在就剩下重构项了,我们需要用“重参数(Reparameterization)”来完成采样并保留梯度,在前面我们已经研究了vMF的采样过程,所以也不难实现,综合的流程为:
L=‖x−g(z)‖2z=wμ(x)+√1−w2νw∼eκw(1−w2)(d−3)/2ν=ε−⟨ε,μ⟩μ‖ε−⟨ε,μ⟩μ‖ε∼N(0,1d)
这里的重构loss以MSE为例,如果是句子重构,那么换用交叉熵就好。其中μ(x)就是编码器,而g(z)就是解码器,由于KL散度项为常数,对优化没影响,所以vMF-VAE相比于普通的自编码器,只是多了一项稍微有点复杂的重参数操作(以及人工调整κ)而已,相比基于高斯分布的标准VAE可谓简化了不少了。
此外,从该流程我们也可以看出,除了“简单起见”之外,不将κ设为可训练还有一个主要原因,那就是κ关系到w的采样,而在w的采样过程中要保留κ的梯度是比较困难的。
参考实现 #
vMF-VAE的实现难度主要是重参数部分,也就还是从vMF分布中采样,而关键之处就是w的采样。前面我们已经给出了w的采样的numpy实现,但是在tf中未见类似np.interp
的函数,因此不容易转换为纯tf的实现。当然,如果是torch或者tf2这种动态图框架,直接跟numpy的代码混合使用也无妨,但这里还是想构造一种比较通用的方案。
其实也不难,由于w只是一个一维变量,每步训练只需要用到batch_size
个采样结果,所以我们完全可以事先用numpy函数采样好足够多(几十万)个w存好,然后训练的时候直接从这批采样好的结果随机抽就行了,参考实现如下:
def sampling(mu):
"""vMF分布重参数操作
"""
dims = K.int_shape(mu)[-1]
# 预先计算一批w
epsilon = 1e-7
x = np.arange(-1 + epsilon, 1, epsilon)
y = kappa * x + np.log(1 - x**2) * (dims - 3) / 2
y = np.cumsum(np.exp(y - y.max()))
y = y / y[-1]
W = K.constant(np.interp(np.random.random(10**6), y, x))
# 实时采样w
idxs = K.random_uniform(K.shape(mu[:, :1]), 0, 10**6, dtype='int32')
w = K.gather(W, idxs)
# 实时采样z
eps = K.random_normal(K.shape(mu))
nu = eps - K.sum(eps * mu, axis=1, keepdims=True) * mu
nu = K.l2_normalize(nu, axis=-1)
return w * mu + (1 - w**2)**0.5 * nu
一个基于MNIST的完整例子可见:
至于vMF-VAE用于NLP的例子,我们日后有机会再分享。本文主要还是以理论介绍和简单演示为主~
文章小结 #
本文介绍了基于vMF分布的VAE实现,其主要难度在于vMF分布的采样。总的来说,vMF分布建立在余弦相似度度量之上,在某些方面的性质更符合我们的直观认知,将其用于VAE中,能够使得KL散度项为一个常数,从而防止了KL散度消失现象,并且简化了VAE结构。
转载到请包括本文地址:https://kexue.fm/archives/8404
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (May. 17, 2021). 《变分自编码器(七):球面上的VAE(vMF-VAE) 》[Blog post]. Retrieved from https://kexue.fm/archives/8404
@online{kexuefm-8404,
title={变分自编码器(七):球面上的VAE(vMF-VAE)},
author={苏剑林},
year={2021},
month={May},
url={\url{https://kexue.fm/archives/8404}},
}
May 17th, 2021
想知道K的取值影响是怎样的,我用较高的K,生成图像很清晰,而较低的K很模糊。
在正文补充了一下,κ也被称为“凝聚度”,κ越大,说明越接近单点分布,说明重参数带来的噪声越弱,重构自然就越清晰,反之噪声越大,重构也就越难。
换成温度参数就明白了,谢大佬
May 18th, 2021
关于最后的代码实现里面,有些疑惑:K.l2_normalize 是会对整个 batch*dim 的矩阵进行 l2_norm,而不是只对 -1 axis 进行 l2_norm 吧。。。
好像还真是...谢谢指出,已经修正。
(怪不得我说我的结果好像总是差一点,看来不能偷懒)
^_^,关注苏神一年来受益匪浅。最近刚好在做 VAE 相关的工作,这篇文章以及历史上的文章,对我帮助都非常大!
May 19th, 2021
请问一下,KL(vMF(µ + Δµ , κ)||vMF(µ , κ)) 在固定k和d时,是否仍然是一个常数
已经找到了这篇论文《A Note on the Kullback-Leibler Divergence for
the von Mises-Fisher distribution》,发现并不是常数了
不是,其实就是根据vMF分布的期望公式,就可以很容易推出两个不同的非零均值的vMF分布的KL散度,事实上就是两个均值的cos值的函数。
谢谢,想做类似NVAE的事,每次在前面估计的u上加一个Δµ,也就是说需要加一个kl项为k < u1,u2 > 对吗
并不是κ倍,你可以具体算一下,应该是cos值的若干倍,但这个倍数比较复杂。
May 20th, 2021
第5步到第6步能否写详细一些,中间推导过程写一下。 水平太低了没有看明白。
这个是超球坐标的积分变换问题,直接参考 https://en.wikipedia.org/wiki/N-sphere 就好,也不方便罗列推导过程。
June 7th, 2021
那有没有可能会这样:比如数据集中有N张图,训练的结果只是编码得到N个球面上相互远离的μ embedding,在mui中心重采样结果都能得到很好的重建效果,但是对随机采样的任何其他的球面向量,没有很好的泛化能力能力
那肯定有可能啊,任何带隐变量的生成模型都可能存在这个问题~这个可以理解为你选的隐变量维度过大,也可以理解为你的训练数据不足。
不是吧,高斯先验的话KL=0就意味着随机从标准正态分布采样的任何噪声都是没问题的,只不过面临着文中提到的梯度消失的问题
KL=0意味着VAE训练失败,意味着编码器完全失效,通常也意味着任何噪声都不能解码出正常样本。
那GAN也是一种隐变量的方法,充分训练的GAN网路可以认为是拟合了先验分布中的任意一点,是基本不存在训练和推断不一致的问题的
极端情况下,你可以想象z的维度是100维,但是训练图片只有10张。
当然,理想情况下,只要无限地采样训练覆盖充分,并且判别器生成器能理想地工作,那么GAN也能训练成功,结果就是一个降维映射,100维的空间分为了10块区域,每个区域对应一张图片;
然而,理想情况下,VAE也能训练成功,常规VAE可以自己学习方差,理论上编码器可以自己调整方差,并且充分采样的情况下也可以实现一个样本对应一个区域(而不单单是一个)的隐变量,从而覆盖整个分布;vMF-VAE不能自己学习方差,需要自己调κ,只要你调得好,理论上也可以覆盖。
总之,都谈理想,大家都能做到,没必要五十步笑百步。GAN在图像生成方面比VAE好,但也不是好在所谓“充分拟合先验分布”这一点。但事实上,这种“理想”都不容易达到,甚至GAN更容易训练失败。
June 8th, 2021
大佬很厉害,留个爪印~
August 30th, 2021
vMF分布就是高维球面上的指数族分布呀!
September 12th, 2021
想问以下大佬,在您的代码中损失里面为什么没计算KL散度呢。我在S-VAE的源码中看到是有计算散度这一项的,并且我查看了lossKL散度会越来越大,但是重建loss会变小
我没看过S-VAE的具体实现,但本文已经论证了vMF-VAE的KL散度是一个常数。所以如果真如你所说,那么答案只有一个:S-VAE的实现是错误的。
我看代码跑起来后,最后KL散度会收敛到一个值,这是不是意味着就是一个常数呢。感谢您的回答
KL散度自始至终都是一个常数,跟训练步数没关系,跟参数没关系,跟训练数据没关系,只跟你选择的κ和d有关系。
S-VAE的论文里计算了KL散度对κ的偏导,估计它代码实现里也是将κ作为一个可训练参数进行优化,不过前辈这篇文章里(以及其他我看到的代码实现里)将其作为一个超参数所以KL散度在训练中也就是常数了。
特意去看了一下S-VAE的代码,也是通过拒绝采样来采样w的,难道拒绝采样的过程是可求梯度的?这个真不大清楚。如果不能求梯度,那么对κ的导数就是有偏的了~
September 18th, 2021
苏神,公式4的前一部分相等可以给一个简单的推导吗?或者给一个链接学习一下。
这直接就是根据定义算啊
∇ξlogZd,‖ξ‖=∇ξZd,‖ξ‖Zd,‖ξ‖=∇ξ∫Sd−1e⟨ξ,x⟩dSd−1Zd,‖ξ‖=∫Sd−1∇ξe⟨ξ,x⟩dSd−1Zd,‖ξ‖=∫Sd−1e⟨ξ,x⟩xdSd−1Zd,‖ξ‖=∫Sd−1e⟨ξ,x⟩xZd,‖ξ‖dSd−1=∫Sd−1p(x)xdSd−1=Ex∼p(x)[x]
January 5th, 2022
拒绝采样的过程是可求梯度的。S-VAE拒绝采样参考的是这篇文章 《Reparameterization Gradients through Acceptance-Rejection Sampling Algorithms》
谢谢分享参考