Transformer升级之路:20、MLA究竟好在哪里?
By 苏剑林 | 2025-05-04 | 14342位读者 |自从DeepSeek爆火后,它所提的Attention变体MLA(Multi-head Latent Attention)也愈发受到关注。MLA通过巧妙的设计实现了MHA与MQA的自由切换,使得模型可以根据训练和推理的不同特性(Compute-Bound or Memory-Bound)选择最佳的形式,尽可能地达到效率最大化。
诚然,MLA很有效,但也有观点认为它不够优雅,所以寻找MLA替代品的努力一直存在,包括我们也有在尝试。然而,经过一段时间的实验,我们发现很多KV Cache相同甚至更大的Attention变体,最终效果都不如MLA。这不得不让我们开始反思:MLA的出色表现背后的关键原因究竟是什么?
接下来,本文将详细介绍笔者围绕这一问题的思考过程以及相关实验结果。
观察 #
MLA提出自DeepSeek-V2,本文假设读者已经熟悉MLA,至少了解之前的博客《缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA》所介绍的内容,因此MLA自身的细节将不会过多展开。
MLA的主要特点如下:
1、MLA在训练阶段是一个qk_head_dims=(128+64)、v_head_dims=128的MHA;
2、MLA在解码阶段是一个qk_head_dims=(512+64)、v_head_dims=512、KV-Shared的MQA;
3、MLA的[qc, qr]、[kc, kr]拼接,可以理解为一种Partial RoPE。
猜测 #
MHA、GQA常用的head_dims是128,而对于MLA来说,不管是从训练看的128+64,还是从推理看的512+64,都要大于128,再结合《突破瓶颈,打造更强大的Transformer》的经验,我们有:
猜测1: 增大head_dims是MLA好的关键之一。
另外,KV-Shared这个特性,可以在同等KV Cache大小下,增大GQA的head_dims或者num_groups,所以有:
猜测2: KV-Shared是MLA好的关键之一。
最后,此前有一些理论和实验显示Partial RoPE可能会对效果有正面帮助(参考《Transformer升级之路:18、RoPE的底数选择原则》),所以有
猜测3: Partial RoPE是MLA好的关键之一。
实验 #
现在我们通过实验逐一检验以上猜测。
设置 #
所有实验公共部分的超参数如下:
1、类似LLAMA3的Dense模型;
2、hidden_size=2048,num_layers=12,num_heads=16;
3、优化器是Muon,Attention部分per head更新;
4、训练长度为4096,总tokens数为16B,总训练步数为16k;
5、所有实验都是只改变Attention,所以参数量不会严格对齐。
Part I #
MLA的KV Cache大小是512+64,约等于GQA2-128(第一个数字是num_groups,第二个数字是head_dims),所以对比的baseline为GQA2-128和GQA1-256。为了验证Partial RoPE,我们增加了GQA1-256-PR,具体做法是将Q、K的256 dims分成192+64两部分,在64上加RoPE,192不加。
结果如下:
$$\begin{array}{c|ccc}
\hline
& \text{Params} & \text{Loss} & \text{Cache} \\
\hline
\text{MLA} & 894M & 2.721 & 576 \\
\text{GQA2-128} & 842M & 2.75 & 512 \\
\text{GQA1-256} & 943M & 2.72 & 512 \\
\text{GQA1-256-PR} & 943M & 2.711 & 512 \\
\hline
\end{array}$$
即
$$\text{GQA2-128} < \text{MLA} \lesssim \text{GQA1-256} < \text{GQA1-256-PR}$$
初步验证了增大head_dims和Partial RoPE的作用。这样看来,MLA的设计中,RoPE和NoPE拼接这部分看似无奈的设计,极有可能是它效果优异的关键原因!原论文声称MLA甚至优于MHA,大概率也是因为所对比的MHA的head_dims只有128。
Part II #
为了进一步验证增大head_dims的作用,我们另外跑了MHA、GQA2-192、MLA-256三个实验,MHA是head_dims=128的常规MHA,GQA2-192是直接增大GQA2的head_dims到192,MLA-256是将MLA的128+64提升到192+64,对照如下
$$\begin{array}{c|ccc}
\hline
& \text{Params} & \text{Loss} & \text{Cache} \\
\hline
\text{MHA} & 931M & 2.721 & 4096 \\
\text{MLA} & 894M & 2.721 & 576 \\
\text{MLA-256} & 989M & 2.705 & 576 \\
\text{GQA2-128} & 842M & 2.75 & 512 \\
\text{GQA2-192} & 899M & 2.729 & 768 \\
\text{GQA1-256} & 943M & 2.72 & 512 \\
\text{GQA1-256-PR} & 943M & 2.711 & 512 \\
\hline
\end{array}$$
可以看到,MHA总参数量更多,KV Cache更是7倍于MLA,但Loss才堪堪追平MLA,这跟DeepSeek-V2里边的结论接近。此外,GQA2-192优于GQA2-128,但不如GQA1-256;MLA的head_dims升到(192+64)后,相比(128+64)也还能进一步提升效果。这些现象都表明,增加head_dims远比增加num_groups更有效。
Part III #
接下来我们验证KV-Shared,即K、V共享全部或大部分dims。这里我们主要考虑的替代品是head_dims不超过256的GQA,并且控制KV Cache的总大小跟MLA接近,所以当KV-Shared时,我们可以至多可以考虑GQA2-256。
由于KV-Shared跟RoPE不完全兼容,参考MLA的做法,我们将256分成192+64两部分,其中
1、192部分不加RoPE,在K、V间共享;
2、64部分加RoPE,只用于K;
3、V另外再投影64 dims,concat到共享的192 dims上去。
这样一来,K、V的head_dims都是256,KV Cache总大小是(192+64+64)*2=640,略大于MLA的512+64=576,这个版本我们简记为“GQA2-(192+64)-S1”,其实“S1”是“Shared-1”的缩写。
Part IV #
另外一种KV-Shared的方案是:
1、192部分不加RoPE,在K、V间共享;
2、64部分加RoPE,同样在K、V间共享;
3、做Attention,由于V带RoPE,此时是绝对位置编码效果;
4、为了保证相对位置编码,将输出分成192+64两部分,64部分再加一次逆向RoPE。
这种做法是K、V完全共享,KV Cache大小是(192+64)*2=512,略小于MLA。这个版本我们称为“GQA2-(192+64)-S2”,“S2”是“Shared-2”的缩写,背后的原理是笔者新提出的VO-RoPE,参考《Transformer升级之路:19、第二类旋转位置编码》。
Part V #
另外,根据同样思路补了几个GQA4和GQA1的实验。所有实验结果汇总如下:
$$\begin{array}{c|ccc|c}
\hline
& \text{Params} & \text{Loss} & \text{Cache} & \text{备注} \\
\hline
\text{MLA} & 894M & 2.721 & 576 & \\
\text{MLA-256} & 989M & 2.705 & 576 & \\
\text{GQA2-(192+64)-S1} & 946M & 2.714 & 640 & \\
\text{GQA2-(192+64)-S2} & 943M & 2.708 & 512 & \text{引入VO-RoPE} \\
\text{GQA4-(64+64)-S2} & 842M & 2.738 & 512 & \\
\text{GQA4-(128+64)-S2} & 899M & 2.713 & 768 & \text{KV Cache最大} \\
\text{GQA1-(512+64)-S3} & 1171M & 2.677 & 576 & \text{head_dims最大} \\
\hline
\end{array}$$
这里“GQA1-(512+64)-S3”是按照MLA的推理形式实现的MQA,形式介乎S1与S2之间,它的主要特点是head_dims大。
结果解读:
1、KV-Shared的GQA自带Partial RoPE;
2、KV-Shared的GQA2-256,也能超过MLA;
3、VO-RoPE的引入,似乎有利于效果(S1 ≲ S2);
4、同等KV Cache下,head_dims越大越好;
5、GQA2-(192+64)-S2 略微超过 GQA1-256-PR;
6、GQA4-(128+64)-S2 的KV Cache最大,但效果不是最优,再次表明head_dims更关键。
关于KV-Shared,还有两点观察:
1、训练过程中,GQA1-256-PR前期是明显领先GQA2-(192+64)-S2,但后期被追平甚至略微反先,猜测GQA1-256-PR可能有后劲不足的嫌疑;
2、如果没有KV-Shared,GQA顶多是GQA1-256,也就是说head_dims顶天了256,但有KV-Shared的话,GQA可以做到GQA1-512-S,单纯从head_dims看,KV-Shared天花板更高。
Part VI #
由于没有严格对齐参数量,可能读者会有“到底是增加参数量还是增加head_dims更本质”的疑虑,所以这里补充几个对齐参数量的实验。
这里考虑的对齐参数量的方式有三种:
1、double-heads:以“GQA2-128 vs GQA1-256”为例,将GQA2-128的num_heads翻倍,可以让GQA2-128的参数量跟GQA1-256相同;
2、缩减MLP:缩小MLP(SwiGLU)的intermediate_size,也可以使得GQA1-256的参数量跟GQA2-128大致相同;
3、Q&O LoRA:GQA的主要参数量来自Query和Output的投影矩阵,对这两个矩阵改用LoRA,也可以降低GQA1-256的参数量。
实验结果如下:
$$\begin{array}{c|ccc|ccc}
\hline
& \text{Params} & \text{Loss} & \text{Cache} & \text{num_heads} & \text{intermediate_size} & \text{qo_lora} \\
\hline
\text{MLA} & 894M & 2.721 & 576 & 16 & 5456 & \text{No}\\
\hline
\text{GQA2-128} & 842M & 2.75 & 512 & 16 & 5456 & \text{No}\\
\text{GQA1-256} & 943M & 2.72 & 512 & 16 & 5456 & \text{No}\\
\hline
\text{GQA2-128} & 943M & 2.723 & 512 & \color{red}{32} & 5456 & \text{No} \\
\text{GQA1-256} & 843M & 2.747 & 512 & 16 & \color{red}{4096} & \text{No} \\
\text{GQA1-256} & 842M & 2.726 & 512 & 16 & 5456 & \color{red}{\text{Yes}} \\
\hline
\text{GQA4-(64+64)-S2} & 842M & 2.738 & 512 & 16 & 5456 & \text{No} \\
\text{GQA2-(192+64)-S2} & 943M & 2.708 & 512 & 16 & 5456 & \text{No} \\
\hline
\text{GQA4-(64+64)-S2} & 943M & 2.711 & 512 & \color{red}{32} & 5456 & \text{No} \\
\text{GQA2-(192+64)-S2} & 843M & 2.733 & 512 & 16 & \color{red}{4096} & \text{No} \\
\text{GQA2-(192+64)-S2} & 842M & 2.708 & 512 & 16 & 5456 & \color{red}{\text{Yes}} \\
\hline
\end{array}$$
结果主要分三块:
1、heads翻倍相比head_dims翻倍,loss稳定差0.003左右;
2、缩小MLP比head_dims减半,loss稳定优0.004左右;
3、Q&O LoRA性能损失最小,可以实现head_dims翻倍但参数量不增,且loss明显降。
结论:如果从增加参数量角度看,增大head_dims可能是效果增益较大的方向,配合Q&O LoRA可以实现参数量几乎不增,但收益仍相当。
小结 #
初步结论是:
1、增大head_dims收益最大;
2、Partial RoPE对Loss也有一定帮助;
3、KV-Shared应该也有一定作用。
这样看来,此前我们一直在head_dims=128下找MLA的替代品,感觉是起点就先天不足了,难怪一直比不上MLA。要想追平MLA,head_dims应该要192起步了,并辅以Partial RoPE。至于KV-Shared,也可能有用,但应该还需要更大规模的验证。
意义 #
其实这里边的意义,就看我们换掉MLA的决心有多强。
假设 GQA2-(192+64)-S2 可以替代MLA,但MLA也可以升到256,目前看来 GQA2-(192+64)-S2 比不上 MLA-256 。那么换掉MLA的唯二好处是:
1、结构更简单,可以方便加QK-Norm;
2、解码阶段的head_dims由512+64变成了256,同时num_groups变为2,可以TP。
转载到请包括本文地址:https://kexue.fm/archives/10907
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (May. 04, 2025). 《Transformer升级之路:20、MLA究竟好在哪里? 》[Blog post]. Retrieved from https://kexue.fm/archives/10907
@online{kexuefm-10907,
title={Transformer升级之路:20、MLA究竟好在哪里?},
author={苏剑林},
year={2025},
month={May},
url={\url{https://kexue.fm/archives/10907}},
}
May 4th, 2025
感谢苏神文章。有个小问题想请教苏神:
前提(MLA参考图):
q k 计算 attention的时候都包含两个部分没有 rope 部分(计为 nope)以及有 rope 部分。以 q 为例,$q_{nope}$ and $q_{rope}$,同理 k 也有这两个部分。
符号解释: $h_t$ 表示原始输出,$q_c$ 表示 q 压缩向量,$kv_{c}$ 表示 kv 压缩向量。
疑问:问什么在论文中,q 的 rope 和 nope 编码内容来自于 $q_c$,而 k 的 $k_{nope}$ 来自于 $h_t$,$k_{rope}$ 来自于 $kv_c$?
因为$q_{rope}$是head-wise的,而$k_{rope}$则是head共享的,所有head共享的部分($q_c$, $kv_c$, $k_{rope}$)均是低秩的(由$h_t$导出)。(btw你最后一句话写反了)
有道理,我理解你的意思了~感谢(是写反了)
接下来想再问一个:那意思是只要是 head-wise 的内容就需要从 compressed 结果中变换而来吗?
$k_{nope}$来自$h_t$的话,就达不到节省kv cache的效果了。
上面确实写反了:应该是 $k_{rope}$ 来自于 $h_t$, $k_{nope}$ 来自于 $kv_{c}$
苏神可能没有理解我的疑问:$k_{nope}$ 来自于 $kv_c$ 这个我没有疑问,cache kv_c 可以节约显存;
我的问题重点是:q 又不需要 cache,$k_{rope}$ 取自 $h_t$,为什么 $q_{rope}$ 不从 $h_t$ 变换而来呢?
上面 Namoe 有按照 head wise 回答,但我不理解这里的假设是什么?
你可以从推理角度来理解,推理时是MQA,key是c与kr的拼接,c、kr都来自h,这样就很舒服了。至于query,此时qc、qr都来自q_lora,总比一个来自q_lora、另一个来自h要舒服。其实没有太多原因,主要看审美。
May 4th, 2025
Partial RoPE 可能有效的一个解释由https://arxiv.org/abs/2410.06205给出了,论文指出在d_model上的开头部分(对应$\frac{1}{10000^{-2h/d}}$的h比较小的时候)旋转角度太大,导致qk乘积发生较大变化,实际上可能破坏了这部分的语义信息,变成了纯粹的相对位置信息。
论文table2也给出了p-RoPE(去掉RoPE旋转角度较大的部分dim)和RoPE-partial(只对部分head施加RoPE)的对比结果,看起来p-RoPE会好一点。不确定苏神文中的Partial RoPE是怎么实现的($\frac{1}{10000^{-2h/d}}$的h是从0开始,还是从192开始,check了一下DeepSeek-V3的实现是前者)。
想问问,很奇怪我看这篇文章提出的p_rope是发现移除low frequency(旋转角度小的部分)比移除 high frequency更加有效。那这样就和MLA 和上面的实验都相反。大佬能解答一下吗? 是不是我的理解有误。
其实完全不矛盾,p-RoPE的结论可以理解成,需要一部分dim的RoPE系数是无穷大,使得这部分不受到旋转的影响。p-RoPE和苏神的操作分别让低频部分和高频部分不旋转,都有改善,说明部分不旋转是关键要素之一。至于是改低频部分还是改高频部分,还是得看实验效果
Partial RoPE是从192开始。p-RoPE是去掉低频部分吧,是旋转角度较小部份,这跟Partial RoPE本身是异曲同工的,都是让更多dims变成NoPE,不矛盾
May 6th, 2025
是否将底数调整一下就可以了?不过调整之后后面的旋转角度太小,也类似partial ROPE了
对,增大base效果也类似。
May 6th, 2025
要是head_dim这么重要的话,是不是可以考虑减少头的数量来腾出更多dim给每个head?
有点奇怪..为了更好的效果(因为n个理由)所以发明了多头注意力机制,然后现在为了腾出更多head_dim又要把多头减少吗?
这个感觉也不好说吧,当时发明的时候也没有考虑过kvcache的问题,所以也不好说现在调大会不会就是更好的方案,虽然带来了更多的计算量,但是可能相较于带来的提升是可以接受的
在目前的flash attention背景下,head_dims一般来说只能做到256,所以也没多少挪腾空间。其实没必要减少num_heads,因为参数量多或者计算量大都不是问题,主要问题是把这些参数量或计算量放到哪里会让模型收益最大,本文只是提出一个可能性。
May 7th, 2025
> GQA的主要参数量来自Query和Output的投影矩阵,对这两个矩阵改用LoRA
意思是把 $W_q \in \mathbb{R}^{H\times d}$ 改写为 $W_q^A\cdot W_q^B$ 其中 $W_q^A \in \mathbb{R}^{H\times c}, W_q^B \in \mathbb{R}^{c\times d}$ 的形式吗?
May 8th, 2025
苏神大佬,这里的增大 head dim 有效会不会是模型参数比较小。 因此各个的head dim 不够大,multihead 导致了每个头的维度太小了。想知道如果扩到3B, 7B, 72B也会有类似的效果吗?
我相信增大head_dims的普遍有效的,唯一疑虑的是它是否效果梯度最大的方向。
May 10th, 2025
感谢分享。想问一下苏神对于RWKV有什么看法,可以出一期相关的文章嘛
暂时没什么看法,对当前linear的关注并不多。
May 14th, 2025
https://papers.cool/arxiv/2410.05258
这篇论文提出了一种差分注意力,效果看上去不错。
论文作者认为可以抵消注意力“噪声”,这种差分似乎和裁剪类似,将不相关位置的注意力降到很低。
差分注意力每个head实际也用到了两倍的head_dim,head数量减半,也可能是单个head使用的参数量上升带来的收益。
这个方法苏神怎么看
听说很多人验证过,scaleup上去之后就没收益了。
May 15th, 2025
苏老师,看到你之前发的一个博客其实讨论过提高 head_dim 能够涨点的分析,看来核心确实是要增加 head_dim :
https://spaces.ac.cn/archives/7325
谢谢,我还真忘记这茬了,已经补到正文里边去。