为什么Pre Norm的效果不如Post Norm?
By 苏剑林 | 2022-03-29 | 99090位读者 |Pre Norm与Post Norm之间的对比是一个“老生常谈”的话题了,本博客就多次讨论过这个问题,比如文章《浅谈Transformer的初始化、参数化与标准化》、《模型优化漫谈:BERT的初始标准差为什么是0.02?》等。目前比较明确的结论是:同一设置之下,Pre Norm结构往往更容易训练,但最终效果通常不如Post Norm。Pre Norm更容易训练好理解,因为它的恒等路径更突出,但为什么它效果反而没那么好呢?
笔者之前也一直没有好的答案,直到前些时间在知乎上看到 @唐翔昊 的一个回复后才“恍然大悟”,原来这个问题竟然有一个非常直观的理解!本文让我们一起来学习一下。
基本结论 #
Pre Norm和Post Norm的式子分别如下:
\begin{align}
\text{Pre Norm: } \quad \boldsymbol{x}_{t+1} = \boldsymbol{x}_t + F_t(\text{Norm}(\boldsymbol{x}_t))\\
\text{Post Norm: }\quad \boldsymbol{x}_{t+1} = \text{Norm}(\boldsymbol{x}_t + F_t(\boldsymbol{x}_t))
\end{align}
在Transformer中,这里的$\text{Norm}$主要指Layer Normalization,但在一般的模型中,它也可以是Batch Normalization、Instance Normalization等,相关结论本质上是通用的。
在笔者找到的资料中,显示Post Norm优于Pre Norm的工作有两篇,一篇是《Understanding the Difficulty of Training Transformers》,一篇是《RealFormer: Transformer Likes Residual Attention》。另外,笔者自己也做过对比实验,显示Post Norm的结构迁移性能更加好,也就是说在Pretraining中,Pre Norm和Post Norm都能做到大致相同的结果,但是Post Norm的Finetune效果明显更好。
可能读者会反问《On Layer Normalization in the Transformer Architecture》不是显示Pre Norm要好于Post Norm吗?这是不是矛盾了?其实这篇文章比较的是在完全相同的训练设置下Pre Norm的效果要优于Post Norm,这只能显示出Pre Norm更容易训练,因为Post Norm要达到自己的最优效果,不能用跟Pre Norm一样的训练配置(比如Pre Norm可以不加Warmup但Post Norm通常要加),所以结论并不矛盾。
直观理解 #
为什么Pre Norm的效果不如Post Norm?知乎上 @唐翔昊 给出的答案是:Pre Norm的深度有“水分”!也就是说,一个$L$层的Pre Norm模型,其实际等效层数不如$L$层的Post Norm模型,而层数少了导致效果变差了。
具体怎么理解呢?很简单,对于Pre Norm模型我们迭代得到:
\begin{equation}\begin{aligned}
\boldsymbol{x}_{t+1} =&\,\boldsymbol{x}_t + F_t(\text{Norm}(\boldsymbol{x}_t)) \\
=&\, \boldsymbol{x}_{t-1} + F_{t-1}(\text{Norm}(\boldsymbol{x}_{t-1})) + F_t(\text{Norm}(\boldsymbol{x}_t)) \\
=&\, \cdots \\
=&\, \boldsymbol{x}_0 + F_0 (\text{Norm}(\boldsymbol{x}_0)) + \cdots + F_{t-1}(\text{Norm}(\boldsymbol{x}_{t-1})) + F_t(\text{Norm}(\boldsymbol{x}_t))
\end{aligned}\end{equation}
其中每一项都是同一量级的,那么有$\boldsymbol{x}_{t+1}=\mathcal{O}(t+1)$,也就是说第$t+1$层跟第$t$层的差别就相当于$t+1$与$t$的差别,当$t$较大时,两者的相对差别是很小的,因此
\begin{equation}\begin{aligned}
&\,F_t(\text{Norm}(\boldsymbol{x}_t)) + F_{t+1}(\text{Norm}(\boldsymbol{x}_{t+1})) \\
\approx&\,F_t(\text{Norm}(\boldsymbol{x}_t)) + F_{t+1}(\text{Norm}(\boldsymbol{x}_t)) \\
=&\, \begin{pmatrix} 1 & 1\end{pmatrix}\begin{pmatrix} F_t \\ F_{t+1}\end{pmatrix}(\text{Norm}(\boldsymbol{x}_t))
\end{aligned}\end{equation}
这个意思是说,当$t$比较大时,$\boldsymbol{x}_t,\boldsymbol{x}_{t+1}$相差较小,所以$F_{t+1}(\text{Norm}(\boldsymbol{x}_{t+1}))$与$F_{t+1}(\text{Norm}(\boldsymbol{x}_t))$很接近,因此原本一个$t$层的模型与$t+1$层和,近似等效于一个更宽的$t$层模型,所以在Pre Norm中多层叠加的结果更多是增加宽度而不是深度,层数越多,这个层就越“虚”。
说白了,Pre Norm结构无形地增加了模型的宽度而降低了模型的深度,而我们知道深度通常比宽度更重要,所以是无形之中的降低深度导致最终效果变差了。而Post Norm刚刚相反,在《浅谈Transformer的初始化、参数化与标准化》中我们就分析过,它每Norm一次就削弱一次恒等分支的权重,所以Post Norm反而是更突出残差分支的,因此Post Norm中的层数更加“足秤”,一旦训练好之后效果更优。
相关工作 #
前段时间号称能训练1000层Transformer的DeepNet想必不少读者都听说过,在其论文《DeepNet: Scaling Transformers to 1,000 Layers》中对Pre Norm的描述是:
However, the gradients of Pre-LN at bottom layers tend to be larger than at top layers, leading to a degradation in performance compared with Post-LN.
不少读者当时可能并不理解这段话的逻辑关系,但看了前一节内容的解释后,想必会有新的理解。
简单来说,所谓“the gradients of Pre-LN at bottom layers tend to be larger than at top layers”,就是指Pre Norm结构会过度倾向于恒等分支(bottom layers),从而使得Pre Norm倾向于退化(degradation)为一个“浅而宽”的模型,最终不如同一深度的Post Norm。这跟前面的直观理解本质上是一致的。
文章小结 #
本文主要分享了“为什么Pre Norm的效果不如Post Norm”的一个直观理解。
转载到请包括本文地址:https://kexue.fm/archives/9009
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Mar. 29, 2022). 《为什么Pre Norm的效果不如Post Norm? 》[Blog post]. Retrieved from https://kexue.fm/archives/9009
@online{kexuefm-9009,
title={为什么Pre Norm的效果不如Post Norm?},
author={苏剑林},
year={2022},
month={Mar},
url={\url{https://kexue.fm/archives/9009}},
}
March 29th, 2022
苏神好,感谢分享~ 本人也凑巧在知乎上刷到过这个回答,有恍然大悟之感。
有个问题想讨论一下,当今千亿、万亿级的参数模型一般是Pre-Norm还是Post-Norm呢?(他们是有采用trick来让Post-Norm训练成功还是用了Pre-Norm)。CV、NLP、多模态大规模预训练在这方面有所不同嘛?(ViT系列大模型似乎流行Pre-Norm?)
当前主流大模型更倾向于模型的尺寸(可能更吸引眼球)以及更容易训练,另外在DeepNet之前也没有找到特别漂亮的训练Post Norm模型的技巧,所以主流的大模型都是以“大+Pre Norm”居多。
但其实事后来看,很多百亿甚至千亿的模型,其实都处于“欠训练”状态,也就是当前训练数据或者训练步数都还没能达到大模型的最优点,因此主流大模型的结果其实对Post Norm还是Pre Norm好这个问题并没有参考价值~
感谢苏神耐心回答!受益匪浅~
March 29th, 2022
bert不都是post norm嘛
BERT的含义比较广。
如果你说的是Google开源的BERT,那确实全都是Post Norm;但有些人也倾向于将BERT视为“Transformer+MLM+NSP”,所以不管是Post Norm还是Pre Norm,只要同时用MLM和NSP预训练的都叫做BERT;有些人也倾向于只要用MLM作为主要预训练任务的都叫BERT(不管辅助任务是NSP还是SOP等)。
March 31st, 2022
多年前kaiming he那篇Idendity Mapping里preact-resblock却是更好的,那篇文章的结论可是做了较为充分的对比实验后得出来的。请问苏神,不考虑具体模块差异的话,preact-resblock实质也是把网络深度变浅了,为何表现就没那么糟糕呢?
你说的是resnetv1和resnetv2的区别吧,这个我也有所耳闻。
我没细看它的对比实验究竟是怎么进行的,但如果是“严格控制变量对比”,那实际上是不公平的对比,因为同一超参设置下Post Norm大概率是不如Pre Norm的,但这不能说明Post Norm结构不如Pre Norm,只是因为Post Norm更难训练,需要更精细调整训练策略。
另一方面,五年前的对比实验,对于Post Norm的训练想必也没有像DeepNet这样的技术可以用,因此模型深度上去之后,Post Norm变得更为难以训练,因此效果就更不理想了。
最后,还有一种可能是:本文的观点在于Post Norm更突出模型深度,从而使得效果更好。但问题是原始任务是不是需要这个深度呢?如果它不需要,那么更深就未必效果更好了。
苏神,有可能是因为activation和normalization所起的作用不同吗?kaiming的文章里主要讨论的是pre-activation与post-activation,而这里讨论的是pre-norm与post-norm,有些地方道理是相通的,但应该也有些由activation和norm之间的区别带来的差异吧?
个人觉得activation也算是某种意义上的norm(当然可能有些勉强)。总的来说,即便考虑activation和norm的差异,@苏剑林|comment-18838所说的大部分原因也是成立的,即当时的post效果更差,也有可能是没训练好的原因。
April 22nd, 2022
请问如果要加入dropout的话,preln和postln加dropout的位置一样吗?还是有其他什么讲究?
这个我没研究,都是参考已有的结构来的。
April 26th, 2022
《On Layer Normalization in the Transformer Architecture》这篇文章的结论是移除了warm up阶段之后的pre-norm效果好于post-norm。
如果Post Norm不做Warmup,又不像Deepnet那样做好适当的初始化,那么不如Pre Norm是很自然的事,但这不是什么公平比较,因此不在考虑范围内。
架构的公平比较,应该是在同样的语料和设备之下,大家用各自的方法训练到各自的最优才比较,严格的控制变量都不算公平比较。
October 17th, 2022
苏神请教一下:
pre-norm:x+f(norm(x))
post-norm:norm(x+f(x))
当多层连续使用时,为啥我感觉除了首、尾是否有Norm有差异,中间的过程一模一样呢。如果post-norm的原始数据被norm一下,pre-norm最后一层之后被norm一下这俩就一模一样了吧。。
假如只有两层。
Pre Norm:
$$\begin{aligned}x_1 =&\, x_0 + F(N(x_0)) \\
x_2 =&\, x_1 + F(N(x_1)) = x_0 + F(N(x_0)) + F(N(x_1))\end{aligned}$$
依你,最后一层加上$N$,结果是
$$N(x_0 + F(N(x_0)) + F(N(x_1)))$$
Post Norm:依你,对输入加上$N$
$$\begin{aligned}x_1 =&\, N(N(x_0) + F(N(x_0))) \\
x_2 =&\, N(x_1 + F(N(x_1))) = N(N(N(x_0) + F(N(x_0))) + F(N(x_1)))\end{aligned}$$
你这么有信心,它们两个是恒等式?除去最后一个$N$,它们分别是
$$x_0 + F(N(x_0)) + F(N(x_1))$$
和
$$N(N(x_0) + F(N(x_0))) + F(N(x_1))$$
已经有显著差异了。如果是三层、四层的,差异更大,何来“一模一样”?
我就算加上更宽松的条件$x_0 = N(x_0)$,那么两者也是$x_0 + F(x_0) + F(N(x_1))$和$N(x_0 + F(x_0)) + F(N(x_1))$的区别,能保证$x_0 + F(x_0)$恒等于$N(x_0 + F(x_0))$?
多动手试试,不要看了几下就下结论。
您说的是对的,是我想差了,不好意思。
October 26th, 2022
苏神你好,我有一点没太理解,文中写道“而Post Norm刚刚相反,在《浅谈Transformer的初始化、参数化与标准化》中我们就分析过,它每Norm一次就削弱一次恒等分支的权重,所以Post Norm反而是更突出残差分支的”。但在那篇文章里写的是“这我们可以称为Post Norm结构,...但事实上已经严重削弱了残差本身”,这是否有冲突呢?
这里想表达的是同一个意思,那篇文章说的“严重削弱了残差本身”的“残差本身”想指的是恒等分支。不过为了防止误解,我已经修改了那篇文章的表述了。感谢指出。
May 8th, 2023
或許可以考慮雙殘差連接(https://arxiv.org/abs/2304.14802)
看到了,很简单的融合,测了一下确实也是有效的。
May 9th, 2023
既然 Pre Norm 會加寬 Neural Network,對於一個 L 層(L 足夠大)的由 Pre Norm 所主導的 Neural Network,何不直接用 Neural Tangent Kernel 取而代之?
比赛中我见过有用NTK的,但是一般场景很少见。
June 26th, 2023
https://arxiv.org/abs/2210.06423
Deepnet的后续工作, 进一步提出了subLN以及对应的初始化参数. 是否可以和deepnet结合起来使用, 把deepnet里面的beta换成subLN里面的gamma?
这个架构我好像有印象,但没跟进,谢谢推荐。
看了看,好像就是多一个LN?当时好像感觉不大优雅,所以放弃了。