我们可以无损放大一个Transformer模型吗(一)
By 苏剑林 | 2021-06-02 | 65821位读者 |看了标题,可能读者会有疑惑,大家不都想着将大模型缩小吗?怎么你想着将小模型放大了?其实背景是这样的:通常来说更大的模型加更多的数据确实能起得更好的效果,然而算力有限的情况下,从零预训练一个大的模型时间成本太大了,如果还要调试几次参数,那么可能几个月就过去了。
这时候“穷人思维”就冒出来了(土豪可以无视):能否先训练一个同样层数的小模型,然后放大后继续训练?这样一来,预训练后的小模型权重经过放大后,就是大模型一个起点很高的初始化权重,那么大模型阶段的训练步数就可以减少了,从而缩短整体的训练时间。
那么,小模型可以无损地放大为一个大模型吗?本文就来从理论上分析这个问题。
含义 #
有的读者可能想到:这肯定可以呀,大模型的拟合能力肯定大于小模型呀。的确,从拟合能力角度来看,这件事肯定是可以办到的,但这还不是本文关心的“无损放大”的全部。
以BERT为例,预训练阶段主要就是一个MLM模型,那么“无损放大”的含义就是:
是否可以通过某种变换,把一个小模型直接变换成一个大模型,并且输出完全不改变?
这里的变换,指的是对权重做一些确定性的变换,而不用通过梯度下降来继续训练;输出完全不改变,指的是对于同一个输入,小模型和大模型给出的预测结果是完全一致的,也就是说它们表面上看起来不一样,但数学上它们是完全一致的函数,所以称为“无损放大”。由于是无损放大,我们至少可以保证大模型不差于小模型,所以继续预训练理论上有正的收益。至于先小后大这样预训练在效果上能不能比得上一开始就从大训练,这个需要实验来确定,并不是本文关心的问题。
直觉来想,这种放大也不困难,比如通过“重复”、“补零”等操作就可以实现模型权重的自然放大。事实上尝试的方向也是如此,但难点在于我们需要仔细分析模型的每一个模块在被放大之后所产生的后果,以确保最终的结果是无损的。
尝试 #
下面我们以“将一个BERT放大为2倍”为例子进行分析尝试,来确定最终的变换形式。这里的“放大”指的是仅仅扩大隐层向量的维度,并不改变模型的层数,也不改变多头注意力机制的头数。
Embedding #
首先,输入层是Embedding层,因此先要解决的是Embedding层的放大问题。这也是其中最简单的一环,就是直接将每个token的向量维度都放大为2倍即可,主要就是“重复”、“补零”两种操作:
重复:[x1,x2,x3,x4]→[x1,x1,x2,x2,x3,x3,x4,x4]补零:[x1,x2,x3,x4]→[x1,x2,x3,x4,0,0,0,0]
两种方案都可以作为候选方案,但直觉上来想,补零这种方式引入了太多的零,会导致过度稀疏和同一个值重复次数过多,不利于权重的多样性,因此我们还是选择了重复这种方案。不过,就算只看重复,也不指上述一种方式,比如[x1,x2,x3,x4,x1,x2,x3,x4]也是一种方案,但后面关于Attention层的分析表明,后一种方案是不可取的。
除此之外,我们通常还希望变换是正交的,这通常能最大程度上保证模型的稳定性,具体来说,正交变换的最基本性质是不改变向量的模型,所以我们将最终的重复变换调整为:
(x1x2⋮xd)→(˜x1˜x2˜x3˜x4⋮˜x2d−1˜x2d)=1√2(x1x1x2x2⋮xdxd)
或者简记成˜xi=x⌈i/2⌉/√2,其中⌈⋅⌉是上取整运算,我们称之为“重复再除以√2”。
LayerNorm #
Embedding的下一层就是LayerNorm了,变换前,LayerNorm的运算为
yi=xi−μσ×γi+βiμ=1dd∑i=1xiσ=√1dd∑i=1(xi−μ)2
而变换后,我们有
˜μ=12d2d∑i=1˜xi=1dd∑i=1xi√2=μ√2˜σ=√12d2d∑i=1(˜xi−˜μ)2=√1dd∑i=1(xi√2−μ√2)2=σ√2˜xi−˜μ˜σ=x⌈i/2⌉/√2−μ/√2σ/√2=x⌈i/2⌉−μσ
这也就是说,“减均值除以标准差”这一步自动帮我们消去了1/√2这个因子,其结果是放大前结果的直接重复。如果我们将参数向量β,γ也按照公式(2)进行变换,那么结果将是˜yi=y⌈i/2⌉/√2,跟Embedding层的变换结果一致,而我们就是要尽量使得每一层“净变换”都是同样的一个简单变换:“重复再除以√2”。
FeedForward #
按照顺序,接下来本来应该分析Attention层才对,不过FeedForward层相对简单一点,并且FeedForward层的分析结果也对后面理解Attention层的变换有所帮助,因此这里先来考虑FeedForward层的变换。
FeedForward层只是两个全连接层的复合,所以我们只需要分析单个全连接层:
yj=A(d∑i=1xiwi,j+bj)
这里的A(⋅)是激活函数。鉴于之前的经验,我们尝试如下变换
˜wi,j=12w⌈i/2⌉,⌈j/2⌉,˜bj=1√2b⌈j/2⌉
也就是将bj按照式(2)进行变换,而对于wi,j则尝试使用形式下述变换:
(w1,1w1,2⋯w1,Dw2,1w2,2⋯w2,D⋮⋮⋱⋮wd,1wd,2⋯wd,D)→12(w1,1w1,1w1,2w1,2⋯w1,Dw1,Dw1,1w1,1w1,2w1,2⋯w1,Dw1,Dw2,1w2,1w2,2w2,2⋯w2,Dw2,Dw2,1w2,1w2,2w2,2⋯w2,Dw2,D⋮⋮⋮⋮⋱⋮⋮wd,1wd,1wd,2wd,2⋯wd,Dwd,Dwd,1wd,1wd,2wd,2⋯wd,Dwd,D)
这里的D就是输出维度大小,这里我们假设模型放大2倍后,D也放大2倍。不难看出,该变换其实就是对变换矩阵wi,j行列两个方向都分别执行变换(2)。此时
2d∑i=1˜xi˜wi,j+˜bj=2d∑i=1xi√2wi,⌈j/2⌉2+b⌈j/2⌉√2=1√2(d∑i=1xiwi,⌈j/2⌉+b⌈j/2⌉)
这说明变换(6)对于线性变换层来说,能够满足我们的理想追求——放大后的结果就是“重复再除以√2”。然而,这还不够,因为全连接层还有个激活函数A(⋅),现在的问题在于A(x/√2)未必等于A(x)/√2,而如果不等,我们就没法让整体的变换等价于“重复再除以√2”。
事实上,BERT用的GeLU激活函数就不满足该恒等式;线性激活函数(不加激活函数)显然是满足这个等式的,而满足这个等式一个常见的非线性激活函数便是ReLU(也包括LeakyReLU)函数,因此一个直接的解决方式就是FeedForward层换用ReLU激活函数。事实上,这也已经是预训练模型的一个常见选择了,百度的Ernie和Google的T5模型,它们的FeedForward层激活函数都是用ReLU。
那么,像BERT这样的非ReLU激活函数的FeedForward层就没办法了吗?那也不至于,因为FeedForward层是两个全连接层的复合,我们只需要在变换第一个全连接的时候少除以一个√2,变换第二个全连接的时候多除以一个√2就行了。具体来说,第一个全连接权重变为:
˜wi,j=1√2w⌈i/2⌉,⌈j/2⌉,˜bj=b⌈j/2⌉
此时就有
A(2d∑i=1˜xi˜wi,j+˜bj)=A(d∑i=1xiwi,⌈j/2⌉+b⌈j/2⌉)
此时结果就是原结果的直接重复,没有除以√2,既然如此,后面紧接着的全连接层多除以一个√2就行了,即后面的全连接层权重变换为
˜wi,j=12√2w⌈i/2⌉,⌈j/2⌉,˜bj=12b⌈j/2⌉
这样整个FeedForward层的效果就等价于“重复再除以√2”了。
Attention #
现在到了最难啃的“硬骨头”——Attention层的变换。Attention层首先通过三个线性层将每个输入向量变换为q,k,v:
qj=d∑i=1xiw(q)i,j+b(q)j,kj=d∑i=1xiw(k)i,j+b(k)j,vj=d∑i=1xiw(v)i,j+b(v)j
根据前面对FeedForward层的分析可以得知,如果要想q,k,v都达到“重复再除以√2”的效果,只需要按照变换(6)进行。但Attention层不是单纯的全连接层,变换完之后,我们要检查Attention矩阵是否不变,我们来算内积:
2d′∑i=1˜qi˜ki=2d′∑i=1qi√2ki√2=d′∑i=1qiki
其中d′是对应的head_size。这个结果告诉我们,上述变换保持了内积不变,所以应该也保持Attention矩阵不变。但是,这里有一个陷阱!如果是T5这样的模型,它的内积之后是没有尺度缩放的,所以这样的确完事了;然而像BERT这样的模型,它是内积之后除了个√d′再做Softmax的,,而一旦放大模型后,除以√d′变成了除以√2d′,内积不变也不能保持Attention矩阵不变,而应当还需要往q,k的权重分别再乘个4√2,所以最终的变换应该是
˜w(q)i,j=4√22w(q)⌈i/2⌉,⌈j/2⌉,˜b(q)j=4√2√2b(q)⌈j/2⌉˜w(k)i,j=4√22w(k)⌈i/2⌉,⌈j/2⌉,˜b(k)j=4√2√2b(k)⌈j/2⌉˜w(v)i,j=12w(v)⌈i/2⌉,⌈j/2⌉,˜b(v)j=1√2b(v)⌈j/2⌉
经过这样变换后,Attention矩阵不变,而˜vi=v⌈i/2⌉/√2,所以最终的输出结果也是˜oi=o⌈i/2⌉/√2。
上述内容只是针对Attention的单个头进行分析,事实上Attention有多个头,多个头的输出结果还要拼接起来再接一个全连接层。当然,由于每个头都是平等的、独立的,因此上述结论基本不变,最后全连接层也只需要按照式(6)进行变换,就可以让Attention的变换效果。但是,多头带来的一个效应是,我们在重复的时候,必须局部地进行重复。
具体来说,我们在实现多头的时候,并非是真的做了多个全连接运算,而是做了一个大的全连接运算后再reshape,这样一来我们可以比较两种不同的重复方式的reshape结果:
[x1,x2,x3,x4,x5,x6][x1,x2,x3,x4,x5,x6]↓↓[x1,x1,x2,x2,x3,x3,x4,x4,x5,x5,x6,x6][x1,x2,x3,x4,x5,x6,x1,x2,x3,x4,x5,x6]↓↓(x1,x1,x2,x2x3,x3,x4,x4x5,x5,x6,x6)(x1,x2,x3,x4x5,x6,x1,x2x3,x4,x5,x6)
注意放大前reshape结果是(x1,x2x3,x4x5,x6),所以对比两种不同的重复方式的reshape结果,我们发现第二种重复方式reshape之后的结果全乱了,不等价于每个头分别重复。因此我们只能选择前一种重复方式。
输出概率分布 #
通过以上分析,我们可以使得整个Encoder在放大到2倍之后,实现“重复再除以√2”的效果。最后剩下的就是输出部分,即将Encoder的输出向量转化为token的概率分布,这里边包含几种情况。
像GPT、T5等模型,它们是直接在Encoder输出后面乘以了Embedding矩阵的转置来做作为概率分布的logits(当然有可能还有个偏置),由于Embedding矩阵本身就包含了“重复再除以√2”的操作,而Encoder的输出也是“重复再除以√2”,两者结合刚好抵消,所以从概率分布角度看,输出是完全不变的。
不过BERT多了一层全连接,也就是说它先接了一个GeLU激活的全连接层,然后才乘以Embedding矩阵的转置并加上偏置项作为logitis。在“FeedForward”那一节我们已经讨论了,非ReLU激活的全连接层无法实现“重复再除以√2”的效果,而只能通过变换(9)来实现单纯的“重复”效果,所以为了再达到“除以√2”的效果,它后面接的LayerNorm在变换的时候就要多除以一个√2了。
当然,如果是ReLU激活,那么按照式(6)来变换,那么可以实现完全不改变了。此外,如果是像mT5那样,最后转为logits的变换矩阵跟Embedding层不共享,那么也可以通过调整最后的变换矩阵来实现输出的完全不变。
RoPE位置编码 #
前面的分析都只适用于每个神经元都是不相关的情形,也就是说向量的任意两个分量xi,xj是没啥关联的。但如果我们在模型中用了“旋转式位置编码(RoPE)”,那么这个假设就不成立了,因为RoPE是以每两个分量为一组进行变换的,即[x1,x2]为一组、[x3,x4]为一组,依此类推。
如果还是按照之前式(2)进行重复变换,那么变换之后就变成了[x1,x1]为一组、[x2,x2]为一组、...,跟原来的分组不一致,所以会带来很大的偏差。这种情况下,重复的时候也应当按照两个为一组来进行:
[x1,x2,x3,x4,⋯,xd−1,xd]↓1√2[x1,x2,x1,x2,x3,x4,x3,x4,⋯,xd−1,xd,xd−1,xd]
当然,由于默认的RoPE是没有可训练权重的,它是按照固定的方式进行渐变的,所以哪怕按照该方式进行重复,那不能完全保证结果一致。也就是说,如果使用了RoPE,那么基本上不能实现无损放大。不过实际测试结果表明,按照该方式进行重复放大后,对应的RoFormer虽然性能有所损失,但不多,可以很快通过继续训练恢复。
结论 #
现在我们可以确认,对于BERT来说,如果非线性激活函数用ReLU,那么BERT是可以直接无损放大的,如果非线性激活函数不是ReLU,那么可以实现MLM准确率无损的放大(事实上经过更精细的调整,也可以实现完全无损放大,但每个层的变换有点不统一了,不够优雅);对于GPT、T5等模型来说,不管激活函数用啥(包括mT5用的GLU激活,也可以定制适当),其实都可以实现无损放大。
其中,将BERT权重进行放大为2倍的变换汇总如下:
Embedding˜xi=1√2x⌈i/2⌉LayerNorm˜βi=1√2β⌈i/2⌉,˜γi=1√2γ⌈i/2⌉Attention˜w(q)i,j=4√22w(q)⌈i/2⌉,⌈j/2⌉,˜b(q)j=4√2√2b(q)⌈j/2⌉˜w(k)i,j=4√22w(k)⌈i/2⌉,⌈j/2⌉,˜b(k)j=4√2√2b(k)⌈j/2⌉˜w(v)i,j=12w(v)⌈i/2⌉,⌈j/2⌉,˜b(v)j=1√2b(v)⌈j/2⌉˜w(o)i,j=12w(o)⌈i/2⌉,⌈j/2⌉,˜b(o)j=1√2b(o)⌈j/2⌉FeedForward˜w(1)i,j=1√2w(1)⌈i/2⌉,⌈j/2⌉,˜b(1)j=b(1)⌈j/2⌉˜w(2)i,j=12√2w(2)⌈i/2⌉,⌈j/2⌉,˜bj=12b(2)⌈j/2⌉输出概率分布˜wi,j=1√2w⌈i/2⌉,⌈j/2⌉,˜bj=b⌈j/2⌉
如果是其他略有不同的模型,那么就模仿前面的思想进行类似的分析即可。如果是RoPE,那么将重复的方案改为式(15)就好;如果是扩大k倍,那么将表格中的多数2换为k就好。简单来说,如果Attention没有尺度缩放(除以√d′),以及FeedForward的激活函数是ReLU(或者LeakyReLU),那么放大k倍的变换就最简单的,将权重的每一维都执行“重复k次并除以√k”就好了。
小结 #
本文从数学上分析了直接放大Transformer模型的可能性,最终得到了若干可用的变换,确定了无损放大Transformer模型的可行性,为实现大模型的渐进式训练提供了参考思路。
转载到请包括本文地址:https://kexue.fm/archives/8444
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Jun. 02, 2021). 《我们可以无损放大一个Transformer模型吗(一) 》[Blog post]. Retrieved from https://kexue.fm/archives/8444
@online{kexuefm-8444,
title={我们可以无损放大一个Transformer模型吗(一)},
author={苏剑林},
year={2021},
month={Jun},
url={\url{https://kexue.fm/archives/8444}},
}
June 2nd, 2021
苏神之后会考虑上实验吗~
线下验证过这样的放大确实能达到无损的效果,但是要完整实验放大之后继续预训练是否能比得上从零预训练大模型的效果,那成本太大了,目前没有做。
June 3rd, 2021
这个实验有一个BUG
反向传播的时候,相同的权重会吃到相同的梯度,于是进行相同的更新
这将导致,你的缩放模型事实上只缩放了学习率参数/wd之类的东西
如果想放大,应该补随机数,而不是补0(全0得到相同的梯度,除非你上dropout),更不应该补原权重(除非你上dropout)。
另外,建议对矩阵,行补0,列补随机数
这样,按列补的随机数会被按行补的0消掉,不影响后面的计算,且能保证得到一个非0且各不相同的梯度
大概思路是这样,具体实现什么的实在懒得扣细节
毕竟玩不起transformer
谢谢建议,如果要不同梯度,可以考虑在本文结果的接触上给每个权重加点随机噪声。本文主要是论证无损放大的存在性。
June 3rd, 2021
我投CIKM2021的文章是做这个,思路几乎一样,关于训练,我发现补随机值比补零收敛快很多。 整体来说,相比从头训练,还是可以节省很多时间。
https://arxiv.org/abs/2104.11390v1 请问下是这篇么?上传到arxiv的版本是不是简化过了啊?
June 3rd, 2021
另外,我证明了LayerNorm层无法做到数学等价的扩大,只能近似
你证明了对于什么变换无法做到扩大?任意变换?那本文的结果错了吗?
另外,贵作有电子版可以拜读一下吗
CIKM2021刚投稿,肯定是不能给的吧hhh
你这个证的应该没错,因为是扩大整数倍。扩大N倍可以通过复制N倍保证均值和方差不变,所以layernorm应该可以等效。我证明的是不是整数倍的情况。
不是整数倍的情况下,我也证明了可行了。估计下周能写一篇文章,推广到一般的正交变换。
请问苏神啥时候写第二篇啊,期待ing
后来发现了,此路其实不通,所以就放弃了~后面我抽时间写篇文章说说为啥不通。
好吧,不知道苏神啥时候有空讲讲为啥不通啊,能否评论区简单讲讲?
感谢~~
简单来说,如果我们将一个小模型训练到了最优,那么理想情况下它进入了一个局部最小值。可以证明的是,如果通过正交变换将模型无损放大,那么放大后的模型依然处于局部最小值,这意味着梯度下降不会带来提升了(局部最小值梯度为0,模型基本不动)。
如果想要继续能启动训练,那么需要的trick就比较多,比如小模型不能训练到最优,或者加扰动,或者GradMax的方案,总之会比较复杂。
另外,从工程上来讲,虽然这个先小后大的循环看起来比较诱惑,但实际上不会节省太多时间。举个例子,假设1B的模型的最优目标是50,10B的模型的最优目标是60,实际上10B模型从0~50这段时间,远小于50~60这段时间,所以先小(1B)后大(10B)的训练方案,稍微节省了0~50的时间,但总体节省幅度不多,反而导致工程复杂化。
June 8th, 2021
bert 的输入是 position + segment + word 三者的embedding之和, 放大时是对三者和进行放大的吧? 这样做的话对position信息应该会有损失吧? 对于后续的fintune 应该会有影响吧
不知道你想到哪里去了。三者分别用同一种方式放大不就行了?有什么问题么?
June 10th, 2021
我是CIKM2021的PC,我的paper id是多少,我关注一下
September 21st, 2022
了解下fpi
FPI 是什麼的縮寫?
May 16th, 2024
感谢苏神分享,2024年回过头来看这篇文章依旧收益许多。
这篇文章还是有不少问题的,因为没有任何随机性的复制式恒等变换,会导致梯度也是相同的,最终无法学出不同的权重。正确的方式应该要想办法引入带有随机性的权重,同时满足恒等性。