Lion/Tiger优化器训练下的Embedding异常和对策
By 苏剑林 | 2023-08-28 | 29958位读者 |打从在《Tiger:一个“抠”到极致的优化器》提出了Tiger优化器之后,Tiger就一直成为了我训练模型的“标配”优化器。最近笔者已经尝试将Tiger用到了70亿参数模型的预训练之中,前期效果看上来尚可,初步说明Tiger也是能Scale Up的。不过,在查看训练好的模型权重时,笔者发现Embedding出现了一些异常值,有些Embedding的分量达到了$\pm 100$的级别。
经过分析,笔者发现类似现象并不会在Adam中出现,这是Tiger或者Lion这种带符号函数$\text{sign}$的优化器特有的问题,对此文末提供了两种参考解决方案。本文将记录笔者的分析过程,供大家参考。
现象 #
接下来,我们的分析都以Tiger优化器为例,但分析过程和结论同样适用于Lion。
首先,笔者观察到的现象是:
1、部分Token的Embedding分量变成了$\pm 100$;
2、还有一小部分Token的Embedding分量正在趋于$\pm 100$;
3、这些token看上去都是相当低频的token;
4、整个Embedding矩阵的最大值就是100,最小值就是-100;
5、除Embedding外,其他权重没有这个问题;
6、模型的总体表现(比如训练Loss、生成测试)都正常。
可能有读者想问,既然模型表现正常,那还管它干嘛呢?在笔者看来,至少有两方面的原因:第一,如果后面想要微调,有可能某些低频Token重新变得高频,如果这些Token的Embedding太糟糕,那么微调也救不回来;第二,有些能力在Loss体现不出来,比如中英的预训练模型,通常因为训练语料夹杂着非常少的多语种语料,就体现出一定的多语种能力,很明显这种能力依赖于低频Token的Embedding质量,如果被优化器所连累而失去这种能力,就“亏大发”了。
当然,不管是什么优化器,都有可能训着训着就把模型训崩了,这并不让人意外,很多时候也难以深究。但这里最耐人寻味的地方是“崩”得这么有规律——刚好是整齐的$\pm 100$,这不能不让笔者想要进一步找出它背后的原因。
思考 #
根据以上观察结果,初步可以得出这些异常值只出现在“低频Token的Embedding”上,这让笔者不禁联想到《Keras实现两个优化器:Lookahead和LazyOptimizer》讨论过的带动量的优化器会导致Embedding层过度优化问题。
具体来说,只要一个token出现过,那么该token的Embedding对应的动量就被更新为非零(假设该token的梯度不会正好是零),于是在后面的更新中,即便当前样本没有出现过该token(梯度为零),但该token的Embedding依然会被更新(动量不为零),这就是低频token的过度优化问题。这个问题会出现在所有带动量的优化器中,包括Adam和Tiger,不过在Adam中,这可能不会有明显感知,因为Adam的更新量跟动量成正比,如果一个token长期不重复出现,那么动量就会指数下降,所以很快就趋于零了,换句话说更新量也很快趋于零,即过度更新很快就会消失。
然而,在Tiger中情况有点不一样。Tiger的更新量是跟动量的符号函数$\text{sign}(\boldsymbol{m}_t)$成正比,尽管动量$\boldsymbol{m}_t$会指数下降,但符号函数不会,在$\boldsymbol{m}_t$由于舍入误差变成0之前,$\text{sign}(\boldsymbol{m}_t)$都保持$\pm 1$的值不变,也就是更新量一直都是常数,所以Tiger的Embedding过度更新问题更加严重。“屋漏偏逢连夜雨”的是,一个token的Embedding由于过度更新偏向了某个方向之后,它的梯度可能会适应并助长这种变化,也就是说下一次它出现时的梯度是同一方向而不是相反方向,这就导致了它长期在同一方向上过度更新,最终导致了异常值。
计算 #
那么异常值为什么偏偏是$\pm 100$呢?这就要邀请权重衰减登场了。Tiger总的优化公式是:
\begin{equation}\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta_t \left[\text{sign}(\boldsymbol{m}_t) + \lambda \boldsymbol{\theta}_{t-1}\right]\end{equation}
也就是说,除了动量的符号函数外,还有一个权重衰减项。在文章开头提到的异常实验中,衰减率$\lambda$设为了0.01。
不难发现,如果$\text{sign}(\boldsymbol{m}_t)$长期为常量,那么上述迭代公式将会有一个平衡点,它出现在$\text{sign}(\boldsymbol{m}_t) + \lambda \boldsymbol{\theta}^*=\boldsymbol{0}$时,即
\begin{equation}\boldsymbol{\theta}^* = -\frac{\text{sign}(\boldsymbol{m}_t)}{\lambda}\end{equation}
这正好对应一个元素是$\pm 100$的向量,这就解释了异常值为$\pm 100$的结果。如果有兴趣,读者还可以假设$\eta_t$也是常数,那么可以直接求出$\boldsymbol{\theta}_t$的解析式,从而进一步分析收敛速度等。这里笔者就不继续展开了。
对策 #
既然问题出现在对低频Token的Embedding的过度更新,那么一个自然的解决方案就是像《Keras实现两个优化器:Lookahead和LazyOptimizer》所提的那样,将Embedding的更新Lazy化,即只有当Token出现过的时候,才更新相应的Embedding。如果能获取到所有的输入Token Ids的集合,那么直接只更新这些Token的Embedding即可,如果不能,我们可以通过判断Embedding的梯度模长是否非零,来判断该Embedding是否需要被更新。
另一方面,从更一般的视角看,该问题是Lion/Tiger优化器对于梯度稀疏的参数的共同缺陷,包括但不限于Embedding层。于是,解决问题的另一个思路是将Embedding的梯度变得不再稀疏,为此我们可以考虑Tied Embeddings,即输入和输出的Embedding共享,这样由于输出端重用了整个Embedding矩阵,因此整个Embedding矩阵都有非零梯度,从而让$\boldsymbol{m}_t$不至于长期为常量。当然Tied Embedding可能会带来另外的一些问题,相应的解决方案可以参考《语言模型输出端共享Embedding的重新探索》。在笔者的实验中,使用将模型特征的channels对半交换的Tied Embedding,能解决以上问题,并且效果似乎比Untied Embedding还要好一点。
最后,笔者也就此问题请教了Lion优化器的作者,得到的回复是他们之前也留意到了这个问题,他们的解决方案是混合优化器,比如Embedding层就用Adam,其他层才用Lion/Tiger。呃,这个解决方案是笔者没想到的,感觉不是特别优雅,但也确实能解决,读者自行选择就好。
小结 #
本文介绍了Lion/Tiger优化器训练下的Embedding异常现象,并分析了背后的原因,最后给出了参考的解决方案。
转载到请包括本文地址:https://kexue.fm/archives/9736
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Aug. 28, 2023). 《Lion/Tiger优化器训练下的Embedding异常和对策 》[Blog post]. Retrieved from https://kexue.fm/archives/9736
@online{kexuefm-9736,
title={Lion/Tiger优化器训练下的Embedding异常和对策},
author={苏剑林},
year={2023},
month={Aug},
url={\url{https://kexue.fm/archives/9736}},
}
August 28th, 2023
可以考虑将 sign 换成一个分段线性函数么,比如说 -1 (x < -1), x (-1 < x < 1), 1 ( x > 1)。然后这个线性的宽度可以随着 m 的增大而增大,这样就可以在 m 比较大的时候做一个衰减,而比较小的时候就退化回 lion。而且这种简单的线性计算也不会增加太多计算量。
你这个就是clip操作了,直接替换的话不大行,因为一般情况下动量就没几个分量是超出$[-1,1]$的,所以直接替换的后果是退化为带动量的SGD。一个改进方式是先将动量乘以一个大的数字,然后再clip,但这样又引入了一个新的参数。
那可不可以在$\lambda$那里加一个和模值有关的衰减惩罚呢,比如说$\lambda \to|\theta_{t-1}|\lambda$,这样按照公式$(2)$进行计算后,会得到$|\theta^*|\theta^*=\pm 1/\lambda$,模值大概 0.1 左右。
不过不知道这样对$\lambda$的处理会出什么奇怪的问题么。
不好意思,这里的结果写错了,而且我又想了想感觉也不太靠谱,不过还有另外一个问题想请教苏老师。
如果模型学成这样还不太影响结果的话,能不能反其道而行之,稍微再鼓励鼓励这个事。如果 embedding 的各个维度在一定程度上都表现为$\pm 100$还能表现的不错,那么这种情况下的每个维度可能会有更强的可解释性?
避免出现异常值的出发点是保留二次训练的可能,如果只看预训练效果,这个现象没有明显影响。
感觉上,这个东西很难翻过来利用?因为它出现在低频、动量符号长期不变的场景,并没有对梯度本身的规律有所限制。
August 28th, 2023
[...]Read More [...]
September 6th, 2023
最近我们正在寻找一位【北京海淀】生信AI高级工程师的候选人,不知道您身边有没有厉害的朋友或者师弟师妹呀?
我的email:jakexue@halibre.com
工作职责:
1. 熟悉经典的生物信息算法,配合集群软硬件资源对已有生物信息算法进行性能调优。
2. 参与生物信息平台核心业务系统模块的功能设计和开发。
3. 对现有工具性能进行优化,版本迭代更新。
4. 辅助AI生物信息科学家和算法分析师完成数据整理和清洗。
任职要求:
1. 熟悉常规生物信息数据资源的使用和分析。
2. 熟练掌握Linux, Python, Java/C++, R。
3. 对生物信息学领域的经典算法和工具有较为深入的理解,有工作经验者优先。
4. 具有良好的沟通能力,有极强的学习能力,有上进心,有理想。
暂无,抱歉。
September 7th, 2023
苏神,想问以下为什么很多生成损失用的是L1 loss,而不用L2 loss呢
你说的是图像?好像是有助于提高清晰度,减少模糊吧。
September 13th, 2023
训练时每次前向的时候 对lm head的weight 做normalize 会有帮助吗
应该没有。
December 4th, 2023
苏神,可以谈谈Untied embedding吗?
谈哪方面呢?https://kexue.fm/archives/9698 这个不够?
January 9th, 2024
用powerball代替sign应该可以解决吧?Tiger:一个“抠”到极致的优化器中的公式8
查了下,就是换成$\text{sign}(x)$换成$\text{sign}(x) \times |x|^{\gamma}$,这个我倒是想过,问题是多了个不好调的参数啊...
gamma不是0的话可能会崩得更惨,有一个scale-invariant条件,不满足的话模型可能学不好
这是一个基本形式,我一般还会加Lamb优化器的weight-norm scale,即假设优化器出来的基本更新向量是$\boldsymbol{u}$,我会选择$\frac{\boldsymbol{u}}{\Vert\boldsymbol{u}\Vert}\times \Vert\boldsymbol{\theta}\Vert$作为最终的更新量。
August 16th, 2024
用RMSProp或者Adam更新嵌入层参数,Lion更新剩余参数,就解决了嘛?
可以解决