我在Performer中发现了Transformer-VQ的踪迹
By 苏剑林 | 2023-11-29 | 44761位读者 |前些天我们在《VQ一下Key,Transformer的复杂度就变成线性了》介绍了“Transformer-VQ”,这是通过将Key序列做VQ(Vector Quantize)变换来实现Attention复杂度线性化的方案。诚然,Transformer-VQ提供了标准Attention到线性Attentino的一个非常漂亮的过渡,给人一种“大道至简”的美感,但熟悉VQ的读者应该能感觉到,当编码表大小或者模型参数量进一步增加时,VQ很可能会成为效果提升的瓶颈,因为它通过STE(Straight-Through Estimator)估计的梯度大概率是次优的(FSQ的实验结果也算是提供了一些佐证)。此外,Transformer-VQ为了使训练效率也线性化所做的梯度截断,也可能成为将来的效果瓶颈之一。
为此,笔者花了一些时间思考可以替代掉VQ的线性化思路。从Transformer-VQ的$\exp\left(QC^{\top}\right)$形式中,笔者联想到了Performer,继而“顺藤摸瓜”地发现原来Performer可以视为Soft版的Transformer-VQ。进一步地,笔者尝试类比Performer的推导方法来重新导出Transformer-VQ,为其后的优化提供一些参考结果。
前情回顾 #
首先,让我们花一些时间回顾一下Transformer-VQ。设$Q,K\in\mathbb{R}^{n\times d_k},V\in\mathbb{R}^{n\times d_v}$,Transformer-VQ的关键,是对$K$做了如下VQ近似:
\begin{equation}K\approx\hat{K}\triangleq\Delta C\end{equation}
这里的$\Delta\in\{0,1\}^{n\times c},C\in\mathbb{R}^{c\times d_k}$都是矩阵,其中$C$是可训练的参数,$\Delta$则定义为:
\begin{equation}\Delta_{i,j} = \left\{\begin{aligned}& 1, \quad j=\mathop{\text{argmin}}_{k=1,2,\cdots,c} \Vert K_i - C_k\Vert \\
& 0, \quad\text{其他}\end{aligned}\right.\end{equation}
说白了,VQ就是用与$K_i$最相近的那个$C_j$来近似$K_i$。在这个近似之下,我们有(简单起见,以Encoder为例)
\begin{equation}\exp\left(Q\hat{K}{}^{\top}\right)V = \exp\left(QC^{\top}\Delta^{\top}\right)V = \exp\left(QC^{\top}\right)\Delta^{\top}V = \exp\left(QC^{\top}\right)(\Delta^{\top}V)\label{eq:transformer-vq}\end{equation}
了解线性Attention的读者很容易认出来,最后一个式子的运算就是线性复杂度的,它就是本文的主角之一Transformer-VQ(的分子,还有分母同理)。
没有很复杂的推导,线性Attention就出来了,这就给我们一种感觉,仿佛我们是在对Key做近似的“不经意间”就将Attention的复杂度降为了线性,美感十足。因此,再次回到了我们已经提过多次的评价——Transformer-VQ提供了标准Attention到线性Attentino的一个非常漂亮的过渡。
似曾相识 #
Transformer-VQ的$\exp\left(QC^{\top}\right)$让笔者联想到了之前的文章《Transformer升级之路:3、从Performer到线性Attention》。在那篇文章中,笔者对Performer的结果做了一些简化,然后断言线性Attention的$Q,K$的最佳激活函数是$\exp$,而Transformer-VQ同样出现了$\exp$,所以它们之间也许有着某种相关性。
为了挖掘这种联系,让我们请出Performer,它基于一个漂亮的近似:
\begin{equation}
e^{\boldsymbol{q}\cdot \boldsymbol{k}}=\mathbb{E}_{\boldsymbol{\omega}\sim \mathcal{N}(\boldsymbol{\omega};0,\boldsymbol{1}_d)}\left[e^{\boldsymbol{\omega}\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \,e^{\boldsymbol{\omega}\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2}\right]\approx\underbrace{\frac{1}{\sqrt{m}}\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{q}}}
\cdot \underbrace{\frac{1}{\sqrt{m}}\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{k}}}
\label{eq:performer}\end{equation}
由于最后还要对所有$\boldsymbol{k}$的注意力归一化,所以去掉上式中的$\frac{1}{\sqrt{m}}$、$-\Vert \boldsymbol{q}\Vert^2/2$都不会影响最终结果,同时,如果假设$\boldsymbol{\omega}_1,\boldsymbol{\omega}_2,\cdots,\boldsymbol{\omega}_m$的模长都相等(参考JL引理),那么$\boldsymbol{k}$的指数都减去$\Vert\boldsymbol{\omega}_i\Vert^2/2$也不会影响结果。于是,Performer等价于用以下的格式做$\tilde{\boldsymbol{q}},\tilde{\boldsymbol{k}}$:
\begin{equation}\underbrace{\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}} \end{pmatrix}}_{\tilde{\boldsymbol{q}}}
\cdot \underbrace{\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2-\Vert \boldsymbol{\omega}_1\Vert^2 / 2} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2-\Vert \boldsymbol{\omega}_2\Vert^2 / 2}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2-\Vert \boldsymbol{\omega}_m\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{k}}} = \underbrace{\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}} \end{pmatrix}}_{\tilde{\boldsymbol{q}}}
\cdot \underbrace{\begin{pmatrix}e^{-\Vert \boldsymbol{k}-\boldsymbol{\omega}_1\Vert^2 / 2} \\
e^{-\Vert \boldsymbol{k} - \boldsymbol{\omega}_2\Vert^2 / 2}\\
\vdots\\
e^{-\Vert \boldsymbol{k} - \boldsymbol{\omega}_m\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{k}}} \propto \underbrace{\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}} \end{pmatrix}}_{\tilde{\boldsymbol{q}}}
\cdot \underbrace{softmax\begin{pmatrix}e^{-\Vert \boldsymbol{k}-\boldsymbol{\omega}_1\Vert^2 / 2} \\
e^{-\Vert \boldsymbol{k} - \boldsymbol{\omega}_2\Vert^2 / 2}\\
\vdots\\
e^{-\Vert \boldsymbol{k} - \boldsymbol{\omega}_m\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{k}}} \end{equation}
对比最后一个式子和$\eqref{eq:transformer-vq}$,就会发现它们有诸多相似之处:$\boldsymbol{\omega}_1,\boldsymbol{\omega}_2,\cdots,\boldsymbol{\omega}_m$不就相当于编码表$C$?$\tilde{\boldsymbol{q}}$不就相当于$\exp\left(QC^{\top}\right)$?至于最后的$\tilde{\boldsymbol{k}}$,它以$-\Vert \boldsymbol{k} - \boldsymbol{\omega}_i\Vert^2 / 2$为logits做softmax,突出的不就是与$\boldsymbol{k}$最相近的那个$\boldsymbol{\omega}_i$?而softmax的极限就是one hot,所以这不正好对应着Transformer-VQ的$\Delta$矩阵?因此,这不能说一模一样,但也有六七分相似了。
依样葫芦 #
当然,上述结果更多的是一种形象的类比而不是等价性,因为Performer本质上基于完全不同的近似思路,比如它里边的$\boldsymbol{\omega}_1,\boldsymbol{\omega}_2,\cdots,\boldsymbol{\omega}_m$是随机采样并固定下来的,这意味它们作为中心向量的近似程度其实是很差的。但这种类似引发了一个思考:能否模仿Performer的思路来重新推导一遍Transformer-VQ呢?即像式$\eqref{eq:performer}$一样,先构造一个精确相等的结果,然后再转化为采样近似来得到线性版本。
经过几天的思考,笔者发现了一种可以构造出期望推导的方案。首先,我们借助狄拉克函数写出
\begin{equation}e^{\boldsymbol{q}\cdot \boldsymbol{k}} = \int e^{\boldsymbol{q}\cdot \boldsymbol{\omega}}\delta(\boldsymbol{\omega} - \boldsymbol{k})d\boldsymbol{\omega}\end{equation}
这是纯粹有狄拉克函数的定义给出的恒等式,还没涉及到任何精巧的运算或者近似。然而,当我们将它代入Attention(的分子)时,出现了一些有意思的结果:
\begin{equation}\sum_j e^{\boldsymbol{q}\cdot \boldsymbol{k}_j} \boldsymbol{v}_j = \sum_j \boldsymbol{v}_j\int e^{\boldsymbol{q}\cdot \boldsymbol{\omega}}\delta(\boldsymbol{\omega} - \boldsymbol{k}_j)d\boldsymbol{\omega} = \int e^{\boldsymbol{q}\cdot \boldsymbol{\omega}} \left[\sum_j \delta(\boldsymbol{\omega} - \boldsymbol{k}_j) \boldsymbol{v}_j\right]d\boldsymbol{\omega}\label{eq:inf-vq}\end{equation}
最后一个等号,不就正好是线性Attention的形式?!当然,由于需要对$\boldsymbol{\omega}$积分,所以上式跟《Transformer升级之路:5、作为无限维的线性Attention》一样,都是“无限维”的线性Attention,暂时只有形式上的价值。
通常来说,我们会将$\delta(\boldsymbol{\omega} - \boldsymbol{k}_j)$理解为正态分布$\mathcal{N}(\boldsymbol{\omega};\boldsymbol{k}_j,\sigma^2\boldsymbol{I})$在$\sigma\to 0$的极限,这也意味着$\delta(\boldsymbol{\omega} - \boldsymbol{k}_j)$具有条件分布$p(\boldsymbol{\omega}|\boldsymbol{k}_j)$的意义。不过,从生成模型的角度来看,狄拉克函数就是单点分布,说白了就是把训练集背下来,所以它没有抽象和泛化能力。为了缓解这一点,我们将$p(\boldsymbol{\omega}|\boldsymbol{k}_j)$用GMM(Gaussian Mixture Model,高斯混合模型)来近似:
\begin{equation}p(\boldsymbol{\omega}|\boldsymbol{k}_j) \approx \sum_{y=1}^m \mathcal{N}(\boldsymbol{\omega};\boldsymbol{c}_y,\sigma^2\boldsymbol{I}) \,p(y|\boldsymbol{k}_j) \end{equation}
代入式$\eqref{eq:inf-vq}$,然后取$\sigma\to 0$的极限,我们就得到
\begin{equation}\sum_j e^{\boldsymbol{q}\cdot \boldsymbol{k}_j} \boldsymbol{v}_j \approx \sum_{y=1}^m e^{\boldsymbol{q}\cdot \boldsymbol{c}_y} \left[\sum_j p(y|\boldsymbol{k}_j) \boldsymbol{v}_j\right]\end{equation}
这就得到一个有限维的线性Attention。如果将$p(y|\boldsymbol{k}_j)$对齐Transformer-VQ的one hot分布$\Delta$的定义,那么得到的结果就是Transformer-VQ的式$\eqref{eq:transformer-vq}$。
文章小结 #
本文介绍了笔者的一个发现:早期的线性Attention工作“Peformer”可以视为一个“Soft”版的Transformer-VQ。然后,在这个观察上进一步得到了Transformer-VQ的一个新推导:利用狄拉克函数将标准Attention转化为无限维线性Attention,然后加上GMM近似就可以得到Transformer-VQ。
转载到请包括本文地址:https://kexue.fm/archives/9862
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Nov. 29, 2023). 《我在Performer中发现了Transformer-VQ的踪迹 》[Blog post]. Retrieved from https://kexue.fm/archives/9862
@online{kexuefm-9862,
title={我在Performer中发现了Transformer-VQ的踪迹},
author={苏剑林},
year={2023},
month={Nov},
url={\url{https://kexue.fm/archives/9862}},
}
November 29th, 2023
[...]Read More [...]
November 30th, 2023
妙啊
January 16th, 2024
其实如果把Q、K都vq,近似误差会小很多
$q\cdot k=(q-\langle q \rangle_i+\langle q \rangle_i)\cdot(k-\langle k \rangle _j+\langle k \rangle_j)$
$=(q-\langle q \rangle_i)\cdot(k-\langle k \rangle_j)+\langle q \rangle_i\cdot\langle k \rangle_j+(q-\langle q \rangle_i)\cdot\langle k \rangle_j+\langle q \rangle_i\cdot(\langle k \rangle-\langle k \rangle_j)$ $\approx \langle q \rangle_i\cdot\langle k \rangle_j+(q-\langle q \rangle_i)\cdot\langle k \rangle_j+\langle q \rangle_i\cdot(k -\langle k \rangle_j)$,
第一项是平方项(Q和K的误差相成),所以扔掉了;其实这就是平均场近似的trick。
相比于只有K进行vq
$q\cdot k=q\cdot(k-\langle k \rangle _j+\langle k \rangle_j)\approx q\cdot(k-\langle k \rangle _j)$
误差是正比于K的误差。
计算量确实比transformer-vq计算attention matrix直接翻倍,但是总归还是线性的,而且一样可以在训练完成进行后处理来近似…
只对K进行vq的第二个式子是不是写错了?
另外,你写的$\langle q \rangle_i\cdot(k -\langle k \rangle_j)$这一项,貌似不是线性的?线性依赖于对K进行vq,这一项关于K不是vq的。
抱歉确实存在type error,只对K进行vq的公式最后等号那应该是$q\cdot \langle k \rangle_j$ ,而 $q \cdot (k - \langle k \rangle_j)$ 是被忽略掉的一阶小量。
对Q和K都进行vq的公式没错,可以看到只有一个二阶小量($(q - \langle q \rangle_i) \cdot (k - \langle k \rangle_j)$)被忽略了,而且公式最后一个等号中,第一项$\langle q \rangle_i\langle k \rangle_j$ 是可以在推理前提前计算的,另外两项都是线性的,所以乘法计算量刚好是只对K进行vq的两倍。
因为对Q、K进行vq没法和positon embbeding兼容,这个想法在我脑子里待了一年多了卡住了,从这篇点进去您之前的博文中提到可以从rope显式的改成相对位置编码,比较激动……
我想通过这种办法尽可能减小vq带来的误差,应该能最大限度减少性能损失吧
是线性的
$k$ 不是vq的,就像只对K进行vq时,$q \cdot \langle k \rangle_j$ 中 $q $不是vq的一样,所有的公式中都不存在两个没vq过的量相乘,最多只有一阶,所以是线性的;最极端的例子,把所有一阶项都扔掉,那就是0阶了。
当然,vq本身的计算量也是需要考虑的,差不多也是 $L \cdot V \cdot D$次乘法; 所以如果Q、K的词表大小一样,还是用两倍乘法,去让vq的误差减小到二阶
也可以这样看,改写一下就是$q\cdot \langle k \rangle _j + k \cdot \langle q \rangle _ i - \langle q \rangle _ i\cdot \langle k \rangle _ j$ ,平均场近似时经常看到类似的公式
January 22nd, 2024
我明白了,内存不是线性的,我过度关注计算了…
如果K没有vq,softmax的分母也不是线性的;看来K的vq才是关键,不能出现K没被vq的项…
是这个意思@wayne_chiu|comment-23543
从数学分析的角度看属于线性近似,但算Attention的时候不是线性的复杂度~
April 28th, 2024
感谢您的分享,您的想法和推导非常有意思!我们也发现了比较相关的性质,整理到了之前的文章(Linear Complexity Randomized Self-attention Mechanism:https://arxiv.org/pdf/2204.04667)和博客(https://hkunlp.github.io/blog/2022/lara/)里。
我们推导出Performer等价于一个full attention的importance sampling estimator,其中proposal是standard Gaussian。在这个框架下proposal可以推广到任意其他的分布,甚至还可以使用多个proposal去估计full attention。
事实上,如果使用C个Gaussian proposal,把每个分布的mean设为每个code,就会得到和Transformer-VQ相似的形式。我们在博客的https://hkunlp.github.io/blog/2022/lara/#a-unified-view-of-lara-ra-and-rfa;以及文章里的Section 4.3给出了具体的形式:
\begin{align} \mathsf{Ours}\left(q_{n},K,V\right) &= \frac{\sum_{c=1}^C \alpha'_{nc}(\omega_c) \xi(q_n,\omega_c)^\top \sum_{m=1}^M\xi(k_m, \omega_c) v_{m}^{\top}}{\sum_{c=1}^C \alpha'_{nc}(\omega_c) \xi(q_n,\omega_c)^\top \sum_{m=1}^M \xi(k_{m}, \omega_c)}, &&\omega_c \sim q_c(\omega)\notag\\ \mathsf{Performer}\left(q_{n},K,V\right) &= \frac{ \sum_{s=1}^S\xi(q_n,\omega_s)^{\top}\sum_{m=1}^M\xi(k_m, \omega_s)v_{m}^{\top}}{\sum_{s=1}^S \xi(q_n,\omega_s)^{\top}\sum_{m'=1}^M\xi(k_{m'}, \omega_s)}, &&\omega_1,\dots,\omega_S \sim \mathcal{N}(\omega;0, \mathbf{I})\notag \end{align}
其中$\xi(x,\omega) = \exp{\left(\omega^\top x - \frac{1}{2}||x||^2\right)}$,$\alpha'_{nc}(\omega_c) = \alpha_{nc}(\omega_c)\mathcal{N}(\omega_c;0, \mathbf{I})/q_c(\omega_c)$, $\alpha_{nc}$表示任意的系数(只要$\sum_{c=1}^C \alpha_{nc} = 1$)。形式上可以看出来是soft VQ,也即您提到的GMM。
感觉linear attention还有很多可以挖掘的地方,可以从许多不同的视角去理解。非常感谢您分享的见解,十分且一如既往地深刻有趣:)
感谢大佬莅临指导。我理解你这里的结果,是不是有点类似于本文最后的基于狄拉克函数和高斯混合模型的推导结果?
Liner Attention我觉得太形式化了,本质上还是没有翻书能力的RNN,我目前倾向于认为单纯的潜力有限,需要想办法补上翻书能力。
感谢回复(●'◡'●)!是的,跟您的狄拉克函数和高斯混合模型推导非常相似,主要目的都是为了研究构建feature map的样本分布该怎么选。
十分同意您对linear attention的看法。把所有$\phi(q)^\top\phi(k)$相加(而非像full attention一样显性地存储每个$\exp(q^\top k)$)会模糊很多上下文的信息,降低了整个序列的分辨率,如何重建上下文信息让模型可以翻书确实是个大问题。多谢您的讨论!