我们知道,在RoPE中频率的计算公式为θi=b2i/d,底数b默认值为10000。目前Long Context的主流做法之一是,先在b=10000上用短文本预训练,然后调大b并在长文本微调,其出发点是《Transformer升级之路:10、RoPE是一种β进制编码》里介绍的NTK-RoPE,它本身有较好长度外推性,换用更大的b再微调相比不加改动的微调,起始损失更小,收敛也更快。该过程给人的感觉是:调大b完全是因为“先短后长”的训练策略,如果一直都用长文本训练似乎就没必要调大b了?

上周的论文《Base of RoPE Bounds Context Length》试图回答这个问题,它基于一个期望性质研究了b的下界,由此指出更大的训练长度本身就应该选择更大的底数,与训练策略无关。整个分析思路颇有启发性,接下来我们一起来品鉴一番。

期望性质 #

RoPE这里就不再详细介绍了,它本质上是一个分块对角矩阵
Rn=(cosnθ0sinnθ00000sinnθ0cosnθ0000000cosnθ1sinnθ10000sinnθ1cosnθ1000000cosnθd/21sinnθd/210000sinnθd/21cosnθd/21)


然后利用恒等式
(Rmq)(Rnk)=qRmRnk=qRnmk

q,k注入绝对位置信息,并自动实现了相对位置的效果。其中θi=b2i/d,这里的b的取值就是本文要探讨的问题。

除了给模型注入位置信息外,我们期望RoPE能具备两个理想性质,以达到更好的效果:1、远程衰减,即位置相近的Token平均来说获得更多的注意力;2、语义聚合,即语义相似的Token平均来说获得更多的注意力。其中第一点我们早在《Transformer升级之路:2、博采众长的旋转式位置编码》有过相关讨论,RoPE确实有一定的远程衰减性质。

所以接下来我们来分析第二点。

不等关系 #

所谓语义聚合,指的是当kq相近时,不管它们的相对距离nm多大,其注意力qRnmk平均来说都应该更大(至少要比随机的两个Token更大)。为了得到一个量化的结论,我们进一步简化问题,假设q的每个分量都是独立同分布的,每个分量的均值为μ,方差为σ2

现在我们考虑两种不同的k:一种是在q的基础上,加上一个零均值的扰动ε,我们记˜k=q+ε,代表跟q语义相近的Token;另一种则是假设kq独立同分布,这代表两个随机的Token。根据第二点理想性质,我们希望有
Eq,k,ε[qRnm˜kqRnmk]0


注意我们刚才反复强调了“平均来说”,意味着我们只是期望一个平均的趋势,而不是每一点都能严格成立,所以我们在上式加了取数学期望Eq,k,ε。现在根据假设和RoPE的定义,我们可以把上式具体地算出来:
Eq,k,ε[qRnm(q+ε)qRnmk]=Eq[qRnmq]Eq,k[qRnmk]=Eq[qRnmq]Eq[q]RnmEk[k]=Eq[qRnmq]μ21Rnm1=Eq[d/21i=0(q22i+q22i+1)cos(nm)θi]d/21i=02μ2cos(nm)θi=d/21i=02(μ2+σ2)cos(nm)θid/21i=02μ2cos(nm)θi=d/21i=02σ2cos(nm)θi

如果训练长度最大为L,那么nmL1,因此第二点理想性质可以用如下不等式近似描述:
d/21i=0cosmθi0,m{0,1,2,,L1}

其中L是最大长度,是训练前就要选定的超参,而d是模型的head_size,按照LLAMA的一般设置是d=128,这也就意味着,上式的唯一可调参数就是θi=b2i/d中的b。在《Transformer升级之路:1、Sinusoidal位置编码追根溯源》中我们就简单探究过这个函数,它整体趋势是衰减的,b越大则衰减速度越慢,对应的连续非负区间就越大,所以存在一个最小的b使得上述不等式恒成立,即
b=inf{b|fb(m)d/21i=0cosmb2i/d0,m{0,1,2,,L1}}

数值求解 #

由于fb(m)涉及到多个三角函数的求和,并且θi关于i还是非线性的,很难想象上述问题会有解析解,因此只能诉诸数值求解了。然而,fb(m)越到后面震荡越频繁且不规律,因此即便数值求解也不是那么简单的事情。

笔者一开始以为,如果b0使得fb0(m)0恒成立,那么bb0都恒成立fb(m)0,所以用二分法就可以了。但事实上这个假设并不成立,所以二分法宣告破产。继续想了一段时间,依然没什么优化思路,期间向原论文作者请教过,他们采用的是逆函数法,即给定b求使得fb(m)0恒成立的最大L是比较简单的,于是我们可以得到很多(b,L)对,理论上只要枚举的b足够多,那么对于任意L都可以找出最小的b。然而这里有个精度问题,原论文最大的L计算到了106b至少要枚举到108,如果枚举间隔小,那么计算成本非常大,如果枚举间隔大,那么可能漏掉很多解。

最后,笔者决定还是用“Jax + GPU”进行暴力搜索,以求得到更高精度的结果,大致流程是:

1、初始化b=1000L(在106b=1000L可以使得fb(m)0恒成立);

2、遍历k=1,2,3,4,5,执行以下操作:

  2.1)将[0,b]等分为10k份,遍历等分点,判断fb(m)0是否恒成立;

  2.2)取最小的使得fb(m)0恒成立的等分点,更新b

3、返回最终的b

最终结果普遍要比原论文的更紧一些
L1k2k4k8k16k32k64k128k256k512k1Mb(原文)4.3e31.6e42.7e48.4e43.1e56.4e52.1e67.8e63.6e76.4e75.1e8b(本文)4.3e31.2e42.7e48.4e42.3e56.3e52.1e64.9e62.4e75.8e76.5e7

参考代码:

from functools import partial
import numpy as np
import jax.numpy as jnp
import jax

@partial(jax.jit, static_argnums=(2,))
def f(m, b, d=128):
    i = jnp.arange(d / 2)
    return jnp.cos(m[:, None] * b ** (-2 * i[None] / d)).sum(axis=1)

@np.vectorize
def fmin(L, b):
    return f(np.arange(L), b).min()

def bmin(L):
    B = 1000 * L
    for k in range(1, 6):
        bs = np.linspace(0, 1, 10**k + 1)[1:] * B  
        ys = fmin(L, bs)
        for b, y in zip(bs, ys):
            if y >= 0:
                B = b
                break
    return B

bmin(1024 * 128)

渐近估计 #

除了数值求解外,我们也可以通过渐近分析来得到一个解析的估计结果,这个估计比数值结果要小,本质上是d的解,但同样能够得出“b应该随着L增大而增大”的结论。

渐近估计的思路,是用积分代替求和:
fb(m)=d/21i=0cosmb2i/d10cosmbsdst=mbs=mmb1costtlnbdt


其中我们记
Ci(x)=xcosttdt

这是被前人研究过的三角积分(参考 Trigonometric integral ),利用这个记号,我们可以写出
fb(m)Ci(m)Ci(mb1)lnb

Ci(x)的图像长这样:

Ci(x)的图像【来自维基百科】

Ci(x)的图像【来自维基百科】

它的第一个零点是x0=0.6165,对于m1,可以看出|Ci(m)|1/2,所以其实Ci(m)相对来说是小项,对于渐近估计来说可以忽略,那么问题近似地变成了Ci(mb1)0对于m=1,2,,L恒成立,我们只需要让相应的mb1都落在[0,x0]区间内就可以实现,这意味着Lb1x0,即
bL/x02L


或者简单点b=O(L)。不出意料这个结果比精确的数值结果要小,因为它对应于d,无限个三角函数叠加会使得函数图像的震荡更少,看起来更加平稳(相比于有限的d),从而对于固定的bfb(m)的连续非负区间更长,或者反过来,对于固定的L,保持m=0,1,2,,L1fb(m)都非负的b更小。

相关思考 #

《Transformer升级之路:10、RoPE是一种β进制编码》中,我们将RoPE类比为一种β进制表示,其中β=b2/d,那么b1=βd/21正好是d/2β进制编码能够表示的最大数字,于是要表示0,1,2,,L1L个位置编码,至少有bL,这个朴素的类比再次给出了“b应该随着L增大而增大”的结论,其结果跟上一节的渐近分析结果更为接近。

另一方面,Meta最新发布的LLAMA3,训练长度为8192,但RoPE的底数选择了惊人的500000(5e5),这比前面的数值结果(8.4e4)还要大将近一个数量级,不管从哪个角度看,这个数值笔者都认为是偏大的,可能LLAMA3的这个底数本就是给更大文本长度预留的。但不论如何,更大的文本长度选择更大的RoPE底数,似乎已经成为了很多训练人员的共识。

其实不管是数值结果还是渐近估计,都只是一个参考值,实际上对于给定的L,一个相当大范围内的b都应该会有相近的效果。所以具体的数值都不重要,关键是原论文通过语义聚合的出发点和一系列推导,澄清了“b应该随着L增大而增大”的结论及其原理,这是笔者所认为的原论文的核心贡献。

此外,其实语义聚合的出发点和结论也可以用来解释Position Interpolation(PI)。刚才我们说了,同一个bfb(m)的连续非负区间是固定的,如果要使0,1,2,,L1都落在非负区间内,就需要随着L的增大而相应的增加b。但反过来,我们也可以不增加b,而是减少相邻位置的间隔(即位置ID改成0,1/k,2/k,),那么就可以在同样大小的非负区间内表示k倍的位置了,这便是语义聚合视角下的Position Interpolation。

部分旋转 #

RoPE提出于2021年,当时只有一篇中文博客,后来得到了EleutherAI组织的认可和实验,继而才逐渐向学术界推广。当时EleutherAI实验发现,如果只对部分维度加RoPE,会取得稍优的结果,相关内容可以参考这里这里这里,后来这个操作用到了它们的GPT-NeoX中。

当然,部分旋转还不是当前LLM的主流选择,但这不妨碍我们研究它,也许它未成为主流选择只是因为我们对它还不够了解。那为什么部分旋转反而可能会更优呢?笔者发现可以用本文的结论来一定程度上解释它。以只旋转一半维度为例,它在数学上等价于选择如下的θi
θi={b4i/d,i<d/40,id/4


此时我们有
d/21i=0cosmθi=d/41i=0(1+cosmb4i/d)0

也就是不论m,b如何,我们所期望的不等式(5)都自动成立,这意味着从本文的观点来看,部分旋转在赋予位置信息的同时有更好的语义聚合能力,这对模型的效果也许更加有利。同时,部分旋转对模型的长文本能力或许也更有利,因为不等式恒成立,所以按照本文的观点,不论长短文本训练都不用修改b

值得一提的是,DeepSeek提出的MLA也应用了部分旋转,虽然在MLA的原始推导中,部分旋转更多是为了整合RoPE的无奈之举,但结合以往的部分旋转实验结果来看,也许MLA的优异效果有部分旋转的一分功劳。

文章小结 #

本文简单介绍了论文《Base of RoPE Bounds Context Length》,它从语义聚合的期望性质讨论了RoPE的底数下界,由此指出更大的训练长度应该选择更大的底数,而不单单是为了配合“先短后长”的训练策略、继而利用NTK-RoPE来降低初始损失的折中选择。

转载到请包括本文地址:https://kexue.fm/archives/10122

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (May. 29, 2024). 《Transformer升级之路:18、RoPE的底数选择原则 》[Blog post]. Retrieved from https://kexue.fm/archives/10122

@online{kexuefm-10122,
        title={Transformer升级之路:18、RoPE的底数选择原则},
        author={苏剑林},
        year={2024},
        month={May},
        url={\url{https://kexue.fm/archives/10122}},
}