Transformer升级之路:3、从Performer到线性Attention
By 苏剑林 | 2021-04-22 | 54263位读者 |看过笔者之前的文章《线性Attention的探索:Attention必须有个Softmax吗?》和《Performer:用随机投影将Attention的复杂度线性化》的读者,可能会觉得本文的标题有点不自然,因为是先有线性Attention然后才有Performer的,它们的关系为“Performer是线性Attention的一种实现,在保证线性复杂度的同时保持了对标准Attention的近似”,所以正常来说是“从线性Attention到Performer”才对。
然而,本文并不是打算梳理线性Attention的发展史,而是打算反过来思考Performer给线性Attention所带来的启示,所以是“从Performer到线性Attention”。
激活函数 #
线性Attention的常见形式是
\begin{equation}Attention(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i, \boldsymbol{k}_j)} = \frac{\sum\limits_{j=1}^n \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)\boldsymbol{v}_j}{\sum\limits_{j=1}^n \phi(\boldsymbol{q}_i)^{\top} \varphi(\boldsymbol{k}_j)}\end{equation}
其中$\phi(\cdot)$、$\varphi(\cdot)$是值域非负的激活函数。那么如何选取这个激活函数呢?Performer告诉我们,应该选择指数函数
\begin{equation}\phi(x)=\varphi(x)=e^x\end{equation}
首先,我们来看它跟已有的结果有什么不一样。在《Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention》给出的选择是:
\begin{equation}\phi(x)=\varphi(x)=1 + \text{elu}(x) = \left\{\begin{aligned}1 + x,\, x \geq 0\\ e^x,\, x < 0\end{aligned}\right.\end{equation}
我们知道$1+x$正是$e^x$在$x=0$处的一阶泰勒展开,因此$1+\text{elu}(x)$这个选择其实已经相当接近$e^x$了。
此外,$\phi(x)=\varphi(x)=e^x$这个方案还跟《Efficient Attention: Attention with Linear Complexities
》一文中引入的双重softmax来构建线性Attention的设计很相似,在那种设计中有$\phi(\boldsymbol{q})=softmax(\boldsymbol{q}),\varphi(\boldsymbol{k})=e^{\boldsymbol{k}}$,相比直接$\phi(x)=\varphi(x)=e^x$只不过归一化的位置有所不同。
简单推导 #
为什么说Performer告诉我们激活函数的最佳选择是$e^x$呢?我们来看Performer找到的将标准Attention线性化的映射:
\begin{equation}\begin{aligned}
e^{\boldsymbol{q}\cdot \boldsymbol{k}}&=\mathbb{E}_{\boldsymbol{\omega}\sim \mathcal{N}(\boldsymbol{\omega};0,\boldsymbol{1}_d)}\left[e^{\boldsymbol{\omega}\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \times e^{\boldsymbol{\omega}\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2}\right]\\[6pt]
&\approx\underbrace{\frac{1}{\sqrt{m}}\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}-\Vert \boldsymbol{q}\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{q}}}
\cdot \underbrace{\frac{1}{\sqrt{m}}\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2} \end{pmatrix}}_{\tilde{\boldsymbol{k}}}
\end{aligned}\end{equation}
简单来说,Performer找到了一个映射,使得$d$维向量$\boldsymbol{q},\boldsymbol{k}$被映射为了$m$维向量$\tilde{\boldsymbol{q}},\tilde{\boldsymbol{k}}$,并且满足近似关系$e^{\boldsymbol{q}\cdot \boldsymbol{k}}\approx \tilde{\boldsymbol{q}}\cdot\tilde{\boldsymbol{k}}$,此时
\begin{equation}a_{i,j} = \frac{e^{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}}{\sum\limits_j e^{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}}\approx \frac{\tilde{\boldsymbol{q}}_i\cdot\tilde{\boldsymbol{k}}_j}{\sum\limits_j \tilde{\boldsymbol{q}}_i\cdot\tilde{\boldsymbol{k}}_j} = \frac{(\lambda(\tilde{\boldsymbol{q}}_i)\tilde{\boldsymbol{q}}_i)\cdot\tilde{\boldsymbol{k}}_j}{\sum\limits_j (\lambda(\tilde{\boldsymbol{q}}_i)\tilde{\boldsymbol{q}}_i)\cdot\tilde{\boldsymbol{k}}_j}\end{equation}
最后一个等式表明,往$\tilde{\boldsymbol{q}}$里边乘以一个常数(哪怕这个常数跟$\tilde{\boldsymbol{q}}$有关),Performer的结果完全不改变,这意味着将映射改为
\begin{equation}
\tilde{\boldsymbol{q}} = \begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}} \end{pmatrix},\qquad
\tilde{\boldsymbol{k}}=\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{k}-\Vert \boldsymbol{k}\Vert^2 / 2} \end{pmatrix}
\end{equation}
Performer的结果不会有任何变化。当然,这里$\Vert \boldsymbol{k}\Vert^2$这一项还不能去掉,但是如果我们假设$\Vert \boldsymbol{k}\Vert^2$不会波动太大,它并不是Attention的主要因素,那么这一项也相当于一个常数,于是最终的映射(近似地)等价为
\begin{equation}
\tilde{\boldsymbol{q}} = \begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{q}} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{q}}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{q}} \end{pmatrix},\qquad
\tilde{\boldsymbol{k}}=\begin{pmatrix}e^{\boldsymbol{\omega}_1\cdot \boldsymbol{k}} \\
e^{\boldsymbol{\omega}_2\cdot \boldsymbol{k}}\\
\vdots\\
e^{\boldsymbol{\omega}_m\cdot \boldsymbol{k}} \end{pmatrix}
\end{equation}
这个看上去已经简化很多的映射该怎么理解呢?其实$m$个随机向量$\boldsymbol{\omega}_1,\boldsymbol{\omega}_2,\cdots,\boldsymbol{\omega}_m$拼成了一个$d\times m$的矩阵,它将$d$维的$\boldsymbol{q},\boldsymbol{k}$映射为了$m$维的向量,然后加上激活函数$e^x$得到了$\tilde{\boldsymbol{q}},\tilde{\boldsymbol{k}}$。我们知道Attention的$\boldsymbol{q},\boldsymbol{k}$都有一个全连接层变换,如果我们将这个$d\times m$的映射矩阵整合到全连接层中,那么剩下的就是一个激活函数$e^x$了!
所以这就是最优激活函数$e^x$的来源了,只要我们将$\boldsymbol{q},\boldsymbol{k}$的输出维度从$d$维改为$m$维,然后配合$e^x$的激活函数,那么理论上它就有Performer的拟合能力,甚至更强,因为Performer的$d\times m$矩阵是一个固定的随机矩阵,而这里我们相当于把该矩阵也设为可训练了,还去掉了低秩约束,空间是比Performer更大的。
低秩问题 #
不管是本文的主角Performer,还是之前在《Nyströmformer:基于矩阵分解的线性化Attention方案》介绍的Nyströmformer,它们的思路都是“寻找一个能逼近标准Attention的线性Attention”。那么一个很自然的问题就是:标准Attention有什么好的?哪里值得大家向它对齐?
从信息损失的角度来看,标准Attention矩阵的“秩”可能更大,即更接近可逆矩阵,这意味着它能保留更多有效信息。具体来说,Attention矩阵是一个$n\times n$的矩阵,它由$\boldsymbol{Q},\boldsymbol{K}\in\mathbb{R}^{n\times d}$通过$softmax(\boldsymbol{Q}\boldsymbol{K}^{\top})$而来,要注意的是,这里的$d$是Attention的key_size,比如对于BERT base来说它只是64,而$n$往往比较大,这说明$\boldsymbol{Q}\boldsymbol{K}^{\top}$的秩不超过$d$,而且$d\ll n$,即离满秩还远得很。不过,$softmax$的关键运算是$e^{\boldsymbol{Q}\boldsymbol{K}^{\top}}$,一个矩阵如果每个元素取指数的话,那么新矩阵的秩是可能增加的!所以标准Attention矩阵有升秩的可能性,意味着它蕴含了更有效处理信息的能力。
相比之下,线性Attention矩阵是$\tilde{\boldsymbol{Q}}\tilde{\boldsymbol{K}}^{\top}$的形式,所以线性Attention矩阵的秩一定不超过$m$,而为了弥补秩的损失,所以一般要设置$m > d$,在Performer的实验中选择的是$m = 4d$,也就是key_size扩大为4倍,秩的重要性可见一斑。当然,扩大了key_size,一个直接的后果是处理短序列的时候,线性Attention还比标准Attention要慢,这是线性Attention的固有瓶颈。
关于Attention矩阵的秩的理论分析,也有一些论文可以参考,比如《Low-Rank Bottleneck in Multi-head Attention Models》就指出哪怕在标准Attention中,低秩性也是一个严重的瓶颈,增大key_size可以提升性能;上个月的《Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth》则指出,如果没有残差和FFN,那么标准Attention有极大的风险退化为秩等于1的简单变换。连标准Attention这个有“升秩潜力”的模型都有低秩问题,更不用说线性Attention这种本身秩就有上限的模型了。
所以,一句话就是:用线性Attention需要用更大的key_size来维持矩阵的秩。
集中注意 #
我们还可以从稀疏性角度来理解标准Attention的好处。直观来想,既然是“注意力机制”,那么肯定需要“集中注意力”,如果太分散,那么可能就相当于平均池化了,而“集中注意力”,意味着每个token应该只能显著地关联到若干个token,用数学的话说,那就是意味着Attention矩阵是稀疏的,或者说至少要具备变得稀疏的可能性。
对于标准Attention来说,它通过softmax来归一化
\begin{equation}a_{i,j} = \frac{e^{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}}{\sum\limits_j e^{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}}\end{equation}
其中指数函数$e^x$起到了一个放大的作用,只要各个$\boldsymbol{q}_i\cdot \boldsymbol{k}_j$本身能拉开一定差距,那么$e^{\boldsymbol{q}_i\cdot \boldsymbol{k}_j}$会进一步放大这种差距,结果就是归一化之后除了最大值的那几个位置之外,剩下的概率都很接近于0了,这说明标准Attention是有潜力“集中注意力”的。而对于线性Attention来说,它是直接内积的结果,没有得到$e^x$的进一步放大,所以它的注意力是比较稠密的,在序列长度较大的时候,它往往就很接近平均池化了。要缓解这一点,还是需要增大key_size,来放大差距,直观来说,就是$n$向量放到一个低维空间太“挤”了,换到更高维的空间就“松”一些了。
怎么样验证稀疏的重要性呢?笔者曾经尝试过,将线性Attention的Attention矩阵先算出来,然后强行截断Attention矩阵(也就是每个token只跟前后几个token做attention,变成局部形式的Attention)让它变得稀疏,结果发现这种截断后的线性Attention效果明显好于全矩阵的线性Attention。这就肯定了稀疏的重要性了,当然,这样把Attention矩阵先算出来然后前行截断的方式,使得线性Attention的复杂度不再是线性的了,因此不具备实用价值,仅用于理论验证。
还有一个实验现象可以辅助证明稀疏的重要性,那就是线性Attention做语言模型或者解码器的时候,效果是跟标准Attention差不了多少的,这时候线性Attention变成了单向的RNN(参考《Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention》),等价于Attention矩阵变成了下三角阵,也是更稀疏了。相比之下,如果用不稀疏的双向的线性Attention直接做MLM模型,则掉点会相当明显。
更重要的是,稀疏性和前一节提到的秩是有密切关联的,甚至可以说它们是“一体两面”:适当的稀疏化方法能提高矩阵的秩!比如做语言模型的下三角Attention矩阵,只要对角线元素非零(往往都能达到),那么这时候的矩阵直接就是满秩可逆阵了!还有笔者实验的局部Attention截断,也能增加矩阵的秩,比如极端情况下,每个token只跟自身做attention,那么Attention矩阵就是满秩的单位阵了!
文章小结 #
本文从Performer出发思考了线性Attention的一些问题,包括关于线性Attention的激活函数选择,以及线性Attention的瓶颈所在(低秩性、稀疏性),总的结论是,线性Attention的最佳激活函数应当是指数函数,而有效的Attention机制应当具备更高的秩和更大的稀疏性。
转载到请包括本文地址:https://kexue.fm/archives/8338
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Apr. 22, 2021). 《Transformer升级之路:3、从Performer到线性Attention 》[Blog post]. Retrieved from https://kexue.fm/archives/8338
@online{kexuefm-8338,
title={Transformer升级之路:3、从Performer到线性Attention},
author={苏剑林},
year={2021},
month={Apr},
url={\url{https://kexue.fm/archives/8338}},
}
April 22nd, 2021
一个矩阵如果每个元素取指数的话,那么新矩阵的秩是可能增加的;
你好,请问下这句话该怎么理解,感谢赐教
比如$\begin{pmatrix}0 & 0 \\ 0 & 0\end{pmatrix}$的秩是0,但是$\begin{pmatrix}e^0 & e^0 \\ e^0 & e^0\end{pmatrix}$的秩是1。
您好,一个矩阵如果每个元素取指数的话,那么新矩阵的秩是可能增加的,但是也应该存在降低矩阵秩的可能吧?比如\begin{pmatrix}0&1&2\\3&4&5\\3&4&5\end{pmatrix}的秩为2,但是每个元素取指数时,\begin{pmatrix}1&e&e^{2}\\e^{3}&e^{4}&e^{5}\\e^{3}&e^{4}&e^{5}\end{pmatrix}的秩应该是1,是不是说明也存在降秩的风险,但是我们还是普遍认为标准Transformer的性能更好,是因为网络能够通过学习避开一些降秩的解吗?关于这点您怎么理解呢?感谢赐教
如果模型需要更高的秩,它自然会学到更高的秩。取exp是提供了这种可能,而不取exp的线性Attention是连这种可能都没有。
苏神, 请教一下, 新矩阵的秩最多能增加多少呢?
从上面的例子来看似乎只能+1? 有没有能增加得更多的例子呢?
感谢赐教.
import numpy as np
n = 100
x = np.random.randn(n)
y = x[:, None].dot(x[None])
np.linalg.matrix_rank(y) # 秩为1
np.linalg.matrix_rank(np.exp(y)) # 秩大概率为17、18
September 8th, 2021
最近的Perceiver,Fastformer都开始将SA压缩 lantent space 中的特定length进行随机投影,再进行下游任务decoder,是否标志着多模态的e2e框架就会是这种形式来发展?
多模态我没研究,所以Perceiver没仔细想过。但是FastFormer我自己做了实验,训练mlm任务并没有收敛。
March 24th, 2023
苏神,如果说我选择了指数函数作为核函数,那这个能跟旋转位置编码混用吗?
感觉exp会破坏掉旋转位置编码的相对位置特性,因为核函数是在Q/K内积前发生作用的。
所以,如果要用rotate,应该顺序是 q -> q~ -> exp -> rotate(k也是这个顺序),然后q * (k * v)这样。
这种情况下,可能rotate会对exp建立起来的高秩特性有所损伤(因为是要对q的每一对位置乘以正余弦因子并求和),但是可能损伤没那么大,而且最多是把损伤建立在一对相邻位置上的;这样基本上能保证高秩+相对位置特性
如果顺序是q -> q~ -> rotate -> exp,那这个rotate基本上就是废了(rotate比较脆弱,必须放在Q*K之前的最后一步操作),exp会破坏掉rotate的相对位置特性
这里有讨论过我的参考做法:https://kexue.fm/archives/8265#%E7%BA%BF%E6%80%A7%E5%9C%BA%E6%99%AF
实际上加在exp之前应该也问题不大,反正performer的意思是能够充分模拟标准attention,那么理论上加在exp之前它也能充分模拟相对位置。
April 27th, 2023
Hi,请问一下,类似文中提到有去实验局部attention这种情况,具体对比的方式是怎样的,只是看收敛的更快or更好,还是整体预训练对齐某版本的bert去某个数据集比较指标?
当时只看预训练的收敛情况吧。
September 27th, 2023
[...]本文来自 苏神博客 :Transformer升级之路:3、从Performer到线性Attention[...]