线性Transformer应该不是你要等的那个模型
By 苏剑林 | 2021-08-09 | 102047位读者 |在本博客中,我们已经多次讨论过线性Attention的相关内容。介绍线性Attention的逻辑大体上都是:标准Attention具有$\mathcal{O}(n^2)$的平方复杂度,是其主要的“硬伤”之一,于是我们$\mathcal{O}(n)$复杂度的改进模型,也就是线性Attention。有些读者看到线性Attention的介绍后,就一直很期待我们发布基于线性Attention的预训练模型,以缓解他们被BERT的算力消耗所折腾的“死去活来”之苦。
然而,本文要说的是:抱有这种念头的读者可能要失望了,标准Attention到线性Attention的转换应该远远达不到你的预期,而BERT那么慢的原因也并不是因为标准Attention的平方复杂度。
BERT之反思 #
按照直观理解,平方复杂度换成线性复杂度不应该要“突飞猛进”才对嘛?怎么反而“远远达不到预期”?出现这个疑惑的主要原因,是我们一直以来都没有仔细评估一下常规的Transformer模型(如BERT)的整体计算量。
很多读者都已经知道,Transformer的结构大体上是Embedding层加若干个Transformer层,Embedding层的计算量很少,我们主要关心Transformer层。忽略残差、Layer Normalization等计算量比较小的层不计,每个Transformer层主要组成就是两个子层:Self Attention(简称SA)和FeedForward Network(简称FFN)。虽然Transformer的开山之作声称“Attention is all you need”,但是也有不少工作论证了残差、FFN等模块的必要性了,比如《Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth》。
现在问大家一个问题:
你觉得是SA计算量大还是FFN计算量大?
评估计算量 #
毋庸置疑,SA的复杂度是$\mathcal{O}(n^2)$,而FFN的复杂度则是$\mathcal{O}(n)$,如果你直接凭此就想当然地说SA计算量比FFN大,那就错了!
我们知道加法比乘法快很多,所以在估计计算量的时候我们主要计算要做多少次乘法,神经网络里边,主要的运算是矩阵相乘,不难估计按照定义一个$a\times b$的矩阵乘以一个$b\times c$的矩阵要做$abc$次乘法,所以$abc$就是两个矩阵相乘的复杂度了,这是我们估算Transformer复杂度的依据。
设$n$为序列长度,$d$为head_size(base版是64),$h$为head的数目(base版是12),那么$hd$就是我们通常说的“hidden_size”(base版是768)。对于SA来说,一开始是$Q,K,V$的投影变换,即$n\times hd$的矩阵乘以$hd\times hd$的矩阵做3次,因此计算量是$3n(hd)^2$;然后是$h$个Attention头的运算,每个头先是$n\times d$的$Q$与$d\times n$的$K^{\top}$相乘得到$n\times n$的Attention矩阵(softmax和归一化的计算量暂且忽略),然后$n\times n$的矩阵与$n\times d$的$V$相乘得到$n\times d$的矩阵,这两步的计算量都是$n^2 d$,所以总计算量是$h(n^2 d + n^2 d)$;最后的输出还有一个投影变换,也是$n\times hd$的矩阵乘以$hd\times hd$的矩阵,计算量是$n(hd)^2$。所以,SA的总计算量是
\begin{equation}3n(hd)^2 + h(n^2 d + n^2 d) + n(hd)^2 = 4nh^2 d^2 + 2n^2 hd\end{equation}
至于FFN就比较简单了,它就是两个全连接层,也就是两个矩阵变换(激活函数的计算量也忽略不计),一般的参数设置是:第一层是$n\times hd$的矩阵乘以$hd\times 4hd$的矩阵,第二层就是$n\times 4hd$的矩阵乘以$4hd\times hd$的矩阵。所以总计算量是
\begin{equation}n\times hd\times 4hd + n\times 4hd\times hd = 8nh^2 d^2\end{equation}
这样一来,如果SA的计算量比FFN大,就意味着
\begin{equation}4nh^2 d^2 + 2n^2 hd > 8nh^2 d^2\quad\Leftrightarrow\quad n > 2hd\end{equation}
对于base版来说,这意味着$n > 1536$!也就是说,只有当序列长度超过1536时,SA的计算量才大于FFN,在这之前,都是线性复杂度的FFN占主导!
这还不止,由上面的结果我们可以得到Transformer层总的计算量为
\begin{equation}4nh^2 d^2 + 2n^2 hd + 8nh^2 d^2 = 12nh^2 d^2 + 2n^2 hd\end{equation}
它是关于$n$的一次项和二次项的求和,当$n$足够大时,复杂度自然是$\mathcal{O}(n^2)$,然而二次项占主导的条件是
\begin{equation}2n^2 hd > 12nh^2 d^2\quad\Leftrightarrow\quad n > 6hd\end{equation}
对于base版来说,这意味着$n > 4608$!也就是说,当序列长度接近5000时,Transformer的复杂度才真正体现出二次性!
综合的结论 #
综合上述结果,我们可以得到结论:对于base版来说,当序列长度不超过1536时,Transformer的复杂度都是近乎线性的;当序列长度超过1536时,Transformer的计算量逐渐以Attention为主,复杂度慢慢趋于二次方,直到长度超过4608,才真正以二次项为主。当然这个边界只是一个估计,实际情况可能有所偏差,大家就此感知一下范围和数量级就好。
笔者以前也建议过很多读者,对于不超过2000长度的“长文本”任务,直接用NEZHA或者RoFormer这种不限长度的模型试试,不要想太多的技巧,原因也是如此。你想再多技巧,也顶多是降到线性复杂度,而在这个长度范围内模型本身就是近乎线性的,各种技巧也省不了多少。
对于老老实实用BERT base的读者来说,maxlen一般不超过512,远低于上述界限,因此就不要再说Attention的平方复杂度费硬件之类的吐槽了,因为事实是:
BERT之所以慢,主要是因为它真的大,而不是因为Attention的平方复杂度。
“线性”含义 #
至于对线性Attention“远远达不到预期”而感到疑惑的另一个原因,则是没有从实际情况分析线性Attention的计算量,以至于对线性Attention期待过高。
线性Attention的介绍可以参考《线性Attention的探索:Attention必须有个Softmax吗?》,这里不做重复。简单来说,线性Attention就是按照$Q(K^{\top} V)$的顺序算注意力。所以按照前面的估算方法,线性Attention每个头运算的计算量就是$2nd^2$,而标准Attention则是$2n^2 d$,因此如果$n > d$,那么线性Attention是比标准Attention要省计算量的。(注:实现线性效率的Attention也不止这一种思路,但总的而言复杂度是相似的,因此下面的结论也有代表性。)
对于base版来说,那就是$n > 64$,这个界还是很容易达到的,所以有些读者可能会想“能省一点是一点”、“不用白不用”。然而,这是假设了标准Attention与线性Attention都用同一个$d$的前提下得出的结果。而认真琢磨过《Performer:用随机投影将Attention的复杂度线性化》、《Transformer升级之路:3、从Performer到线性Attention》的读者都知道,线性Attention有着比标准Attention更严重的“低秩瓶颈”,所以如果切换为线性Attention后还用同一个$d$,那么线性Attention的效果将会明显下降,而如果要保留大致相同的效果,那么线性Attention要用更大的$d$(一般是原来的4倍左右)。
这样一来,线性Attention的计算量应该是$2n(4d)^2$,如果线性Attention要比标准Attention快,那么就要$n > 16d$,对于base版来说,就是$n > 1024$,这也超出了一般读者所能用到的范围了。况且换成线性Attention后,前面关于SA和FFN的计算量结论依然存在,即大部分序列长度下占主导计算量的还是FFN等线性运算,换了线性Attention后也无法感觉到明显的速度提升。所以,总的来说
你要不是成千上万的序列长度,就不要想着换线性Attention了。
再翻翻论文 #
事实上,就算不进行上述分析,只要认真读过关于Attention效率改进相关工作的读者,从论文中的某些图片就可以得到类似的结论:所谓更“高效”的Attention,一般都只适用于成千上万的序列长度,只有在这个场景下性能才有明显提升。
比如较早的工作Sparse Transformers,里边有一张图显示出处理的序列长度都是3000+的:
比如大名鼎鼎的Reformer,演示性能的序列长度都是以K为单位的:
大家颇多好评的Longformer也是如此:
还有Google关于线性Attention的经典之作Performer,显示出哪怕序列长度是$2^{12}=4096$,Performer与Transformer的差距也不能说特别显著:
最后是比较新的工作Luna,提供了一个比较综合的对比表格,同样支持我们的结论:
从已有的各个高效Attention的工作中,我们可以得出结论:这些改进工作所关心的序列长度主要都是以千为单位的,有明显计算效率提升的序列长度基本上都要好几千;当然,我们前面的讨论主要针对的还是时间复杂度,对于空间复杂度,也就是显存占用量,降低的幅度一般要比时间复杂度提升的幅度的要大,但总体而言都是长序列才有价值。
换个期待吧 #
所以,如果你的序列长度还只是一两百,那么就完全不要期望Attention本身的改进了,老老实实换个小点的模型就好。你可以期望未来会有更小的模型能达到同样好的效果,但是不要期望同样大的模型通过修改Attention来提升效率,因为说白了,就算把Attention完全去掉,也提升不了多少性能。
转载到请包括本文地址:https://kexue.fm/archives/8610
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Aug. 09, 2021). 《线性Transformer应该不是你要等的那个模型 》[Blog post]. Retrieved from https://kexue.fm/archives/8610
@online{kexuefm-8610,
title={线性Transformer应该不是你要等的那个模型},
author={苏剑林},
year={2021},
month={Aug},
url={\url{https://kexue.fm/archives/8610}},
}
August 9th, 2021
如何加速FFN是加速Transformer的下一个方向?
FFN加速有的玩吗??
FFN本身比较简单,优化空间不大。
一个自然的想法是缩小FFN的尺寸,但是缩小FFN的尺寸好像会带来明显的效果下降;另一个想法就是提升矩阵乘法的效率,这是一个学界难题,有一定的结果,当前两个$n\times n$的矩阵相乘效率已经推进到$\mathcal{O}(n^{2.37286})$,参考 https://mp.weixin.qq.com/s/OU8ara3IEDj14FnXZU7Bqw ,但是据我所知,这些成果似乎都没正式用到机器学习上。
刚刚相反的是,T5觉得FFN太简单了,于是把它做得更复杂了一些,而且效果还更好了。参考:https://kexue.fm/archives/7867
为何无法应用到机器学习? 计算精度吗?
因为那些都是理论效率,都是在$n$足够大(大到我们100年后可能都用不上这么大的矩阵)的情况下相比传统算法才有明显的效率提升。
FFN就是线性层,加速线性层,然后就可以使用CNN了
August 9th, 2021
苏神有一个笔误,'虽然Transformer的开山之作生成' $\to$ '虽然Transformer的开山之作声称'。
谢谢指出,已经修正。
August 9th, 2021
苏神这个'评估计算量'这个部分还是很惊讶的,从目前使用场景上的确用不到5000以上token的序列数据,的确是不需要纠结线性 Attention 这个问题。
August 9th, 2021
苏神有打算后续的引用格式加上 bibtex 的考虑么?
```
@online{su2021bert,
author = {苏剑林},
title = {线性Transformer应该不是你要等的那个模型},
year = 2021,
howpublished = {科学空间},
url = {https://kexue.fm/archives/8610},
urldate = {2021-08-09}
}
```
感谢你的建议,这就着手补充试试。
August 9th, 2021
苏神好,我们用nvidia profile测了一下,发现即使在token 长度512 的时候, self attention模块的耗时比FFN依然多了三倍, 请问为什么会这样呢?
这可能跟并行程度有关。
FFN虽然大,但是它简单,就两个矩阵乘法,每个矩阵乘法可以充分并行地算;SA至少分了4步:Q/K/V的投影变换、算Attention矩阵、算Attention结果、输出投影,然后Q/K/V的投影变换很多框架还分三步来实现,所以它少则要4步,多则要6步,每步之间是串联的,因此显示出来慢。
最后的输出还有一个投影变换,也是n×hd的矩阵乘以hd×hd的矩阵,请问这个投影变换是说哪个呀?我看https://0809zheng.github.io/2021/07/12/efficienttransformer.html的分析就没有这个投影变换,FFN和MHA之间还有啥我脑子有点懵,希望指点
所有的head的输出concat起来之后,会接一个全连接层,指的就是这个层。如果你看过Attention的实现,那就更容易理解了。就是最后一个Dense(或者叫Linear)层。
nvidia有个fasttransformer,实测512长度时,ffn和sa的耗时占比是相当的,隐藏层规格就是用的bert的规格
感谢反馈。总的来说,理论计算量是那样,但是FFN可以并行优化得更好,所以FFN的速度比估计的要更快些。
September 16th, 2021
剑林你好!
感谢你的分享!这和我们在YOLOS(https://arxiv.org/abs/2106.00666)里面对ViT Scaling Law的讨论不谋而合。
在Image classification中,输入的token数目较少,此时FFN占主导,因此我们猜测之前针对CNN的heuristic的model scaling law依旧适用。
而在object detection中,输入的token数目较多,此时atten占主导,因此需要新的model scaling law。
是的,幸会幸会~
October 12th, 2021
苏神,您好,目前我是刚入门用自注意力机制解决动态图演化的问题,也就是把动态图转化为快照序列来处理,本来是想通过自注意力线性化去改进的,您的网站也提了好多怎么线性化,现在为什么说线性化模型不是要等的模型,按照分析来说,线性attention只适用于长序列,实际应用中可能不需要那么长的序列,那之前提的线性优化升级之路还能有效吗?还是说线性Attention实际应用中也不太能用的到,这样的改进也没有太大的作用
反正对我自己来说,我用不上线性Attention,原因这篇文章已经说了。
October 13th, 2021
苏神有兴趣解读一下这篇文章吗~似乎提出了新的线性Attention并进行了solid的证明。
Choose a Transformer: Fourier or Galerkin
谢谢推荐,我有空看看。
March 12th, 2022
有一个理论和实际的不同,就是按(4)式,只要(hd)固定,计算量就是一样的。但实际训练时,即使固定(hd),h=1和h>1但同比例缩小d,训练速度有明显的不同,h=1速度要快一些。
理论计算量跟实际会有些差别,取决于实现时的可并行程度。纯粹的$a\times b$与$b\times c$的矩阵相乘是最快的。
July 3rd, 2023
序列长度如果达到30000-50000这个范围, 用Transformer的哪个变体效果比较好? FLASH里面的那个GAU可以吗?
如果你追求省显存,那么GAU配合flash attention估计还不错。
谢谢回复!
因为需要对attention matrix进行直接操作, 所以, flash attention可能就用不了了吧.
你说谁需要对attention matrix进行直接操作?