缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA
By 苏剑林 | 2024-05-13 | 87374位读者 |前几天,幻方发布的DeepSeek-V2引起了大家的热烈讨论。首先,最让人哗然的是1块钱100万token的价格,普遍比现有的各种竞品API便宜了两个数量级,以至于有人调侃“这个价格哪怕它输出乱码,我也会认为这个乱码是一种艺术”;其次,从模型的技术报告看,如此便宜的价格背后的关键技术之一是它新提出的MLA(Multi-head Latent Attention),这是对GQA的改进,据说能比GQA更省更好,也引起了读者的广泛关注。
接下来,本文将跟大家一起梳理一下从MHA、MQA、GQA到MLA的演变历程,并着重介绍一下MLA的设计思路。
MHA #
MHA(Multi-Head Attention),也就是多头注意力,是开山之作《Attention is all you need》所提出的一种Attention形式,可以说它是当前主流LLM的基础工作。在数学上,多头注意力MHA等价于多个独立的单头注意力的拼接,假设输入的(行)向量序列为$\boldsymbol{x}_1,\boldsymbol{x}_2,\cdots,\boldsymbol{x}_l$,其中$\boldsymbol{x}_i\in\mathbb{R}^d$,那么MHA可以形式地记为
\begin{equation}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\
\boldsymbol{k}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d\times d_k} \\
\boldsymbol{v}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d\times d_v}
\end{gathered}
\end{equation}
简单起见,这里省略了Attention矩阵的缩放因子。实践上,常见的设置是$d_k = d_v = d / h$,对于LLAMA2-7b有$d=4096, h=32, d_k = d_v = 128$,LLAMA2-70b则是$d=8192,h=64, d_k = d_v = 128$
由于这里只考虑了主流的自回归LLM所用的Causal Attention,因此在token by token递归生成时,新预测出来的第$t+1$个token,并不会影响到已经算好的$\boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}$,因此这部分结果我们可以缓存下来供后续生成调用,避免不必要的重复计算,这就是所谓的KV Cache。
而后面的MQA、GQA、MLA,都是围绕“如何减少KV Cache同时尽可能地保证效果”这个主题发展而来的产物。
瓶颈 #
一个自然的问题是:为什么降低KV Cache的大小如此重要?
众所周知,一般情况下LLM的推理都是在GPU上进行,单张GPU的显存是有限的,一部分我们要用来存放模型的参数和前向计算的激活值,这部分依赖于模型的体量,选定模型后它就是个常数;另外一部分我们要用来存放模型的KV Cache,这部分不仅依赖于模型的体量,还依赖于模型的输入长度,也就是在推理过程中是动态增长的,当Context长度足够长时,它的大小就会占主导地位,可能超出一张卡甚至一台机(8张卡)的总显存量。
在GPU上部署模型的原则是:能一张卡部署的,就不要跨多张卡;能一台机部署的,就不要跨多台机。这是因为“卡内通信带宽 > 卡间通信带宽 > 机间通信带宽”,由于“木桶效应”,模型部署时跨的设备越多,受设备间通信带宽的的“拖累”就越大,事实上即便是单卡H100内SRAM与HBM的带宽已经达到了3TB/s,但对于Short Context来说这个速度依然还是推理的瓶颈,更不用说更慢的卡间、机间通信了。
所以,减少KV Cache的目的就是要实现在更少的设备上推理更长的Context,或者在相同的Context长度下让推理的batch size更大,从而实现更快的推理速度或者更大的吞吐总量。当然,最终目的都是为了实现更低的推理成本。
要想更详细地了解这个问题,读者可以进一步阅读《FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness》、《A guide to LLM inference and performance》、《LLM inference speed of light》等文章,这里就不继续展开了(主要是笔者水平也有限,唯恐说多错多)。
MQA #
MQA,即“Multi-Query Attention”,是减少KV Cache的一次非常朴素的尝试,首次提出自《Fast Transformer Decoding: One Write-Head is All You Need》,这已经是2019年的论文了,这也意味着早在LLM火热之前,减少KV Cache就已经是研究人员非常关注的一个课题了。
MQA的思路很简单,直接让所有Attention Head共享同一个K、V,用公式来说,就是取消MHA所有的$\boldsymbol{k},\boldsymbol{v}$的上标${}^{(s)}$:
\begin{equation}\require{cancel}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{\color{#ccc}{\smash{\bcancel{(s)}}}} ,\boldsymbol{v}_{\leq t}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}}{}^{\top}\right)\boldsymbol{v}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\
\boldsymbol{k}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}} = \boldsymbol{x}_i\boldsymbol{W}_k^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_k} \\
\boldsymbol{v}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}} = \boldsymbol{x}_i\boldsymbol{W}_v^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_v}
\end{gathered}
\end{equation}
使用MQA的模型包括PaLM、StarCoder、Gemini等。很明显,MQA直接将KV Cache减少到了原来的$1/h$,这是非常可观的,单从节省显存角度看已经是天花板了。
效果方面,目前看来大部分任务的损失都比较有限,且MQA的支持者相信这部分损失可以通过进一步训练来弥补回。此外,注意到MQA由于共享了K、V,将会导致Attention的参数量减少了将近一半,而为了模型总参数量的不变,通常会相应地增大FFN/GLU的规模,这也能弥补一部分效果损失。
GQA #
然而,也有人担心MQA对KV Cache的压缩太严重,以至于会影响模型的学习效率以及最终效果。为此,一个MHA与MQA之间的过渡版本GQA(Grouped-Query Attention)应运而生,出自论文《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》,是去年的工作。
事后看来,GQA的思想也很朴素,它就是将所有Head分为$g$个组($g$可以整除$h$),每组共享同一对K、V,用数学公式表示为
\begin{equation}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{\color{red}{(\lceil sg/h\rceil)}} ,\boldsymbol{v}_{\leq t}^{\color{red}{(\lceil sg/h\rceil)}}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{red}{(\lceil sg/h\rceil)}}{}^{\top}\right)\boldsymbol{v}_i^{\color{red}{(\lceil sg/h\rceil)}}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{red}{(\lceil sg/h\rceil)}}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\
\boldsymbol{k}_i^{\color{red}{(\lceil sg/h\rceil)}} = \boldsymbol{x}_i\boldsymbol{W}_k^{\color{red}{(\lceil sg/h\rceil)}}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{\color{red}{(\lceil sg/h\rceil)}}\in\mathbb{R}^{d\times d_k} \\
\boldsymbol{v}_i^{\color{red}{(\lceil sg/h\rceil)}} = \boldsymbol{x}_i\boldsymbol{W}_v^{\color{red}{(\lceil sg/h\rceil)}}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{\color{red}{(\lceil sg/h\rceil)}}\in\mathbb{R}^{d\times d_v}
\end{gathered}
\end{equation}
这里的$\lceil\cdot\rceil$是上取整符号。GQA提供了MHA到MQA的自然过渡,当$g=h$时就是MHA,$g=1$时就是MQA,当$1 < g < h$时,它只将KV Cache压缩到$g/h$,压缩率不如MQA,但同时也提供了更大的自由度,效果上更有保证。GQA最知名的使用者,大概是Meta开源的LLAMA2-70B,以及LLAMA3全系列,此外使用GQA的模型还有TigerBot、DeepSeek-V1、StarCoder2、Yi、ChatGLM2、ChatGLM3等,相比使用MQA的模型更多(ChatGLM虽然在它的介绍中说自己是MQA,但实际是$g=2$的GQA)。
在llama2/3-70B中,GQA的$g=8$,其他用了GQA的同体量模型基本上也保持了这个设置,这并非偶然,而是同样出于推理效率的考虑。我们知道,70B这个体量的模型,如果不进行极端的量化,那么不可能部署到单卡(A100/H100 80G)上。单卡不行,那么就能单机了,一般情况下一台机可以装8张卡,刚才我们说了,Attention的每个Head实际上是独立运算然后拼接起来的,当$g=8$时,正好可以每张卡负责计算一组K、V对应的Attention Head,这样可以在尽可能保证K、V多样性的同时最大程度上减少卡间通信。
MLA #
有了MHA、MQA、GQA的铺垫,我们理解MLA(Multi-head Latent Attention)就相对容易一些了。DeepSeek-V2的技术报告里是从低秩投影的角度引入MLA的,以至于有部分读者提出“为什么LoRA提出这么久了,直到MLA才提出对KV Cache低秩分解的做法”之类的疑问。
然而,笔者认为低秩投影这个角度并不贴近本质,因为要说低秩投影的话,事实上只要我们将GQA的所有K、V叠在一起,就会发现GQA也相当于在做低秩投影:
\begin{equation}\underbrace{\left[\boldsymbol{k}_i^{(1)},\cdots,\boldsymbol{k}_i^{(g)},\boldsymbol{v}_i^{(1)},\cdots,\boldsymbol{v}_i^{(g)}\right]}_{\boldsymbol{c}_i\in\mathbb{R}^{g(d_k+d_v)}} = \boldsymbol{x}_i \underbrace{\left[\boldsymbol{W}_k^{(1)},\cdots,\boldsymbol{W}_k^{(g)},\boldsymbol{W}_v^{(1)},\cdots,\boldsymbol{W}_v^{(g)}\right]}_{\boldsymbol{W}_c\in\mathbb{R}^{d\times g(d_k+d_v)}}\end{equation}
这里我们将所有$\boldsymbol{k}_i^{(s)},\boldsymbol{v}_i^{(s)}$拼在一起记为$\boldsymbol{c}_i$,相应的投影矩阵也拼在一起记为$\boldsymbol{W}_c$,注意到一般都有$d_c = g(d_k+d_v) < d$,所以$\boldsymbol{x}_i$到$\boldsymbol{c}_i$的变换就是一个低秩投影。所以,MLA的本质改进不是低秩投影,而是低秩投影之后的工作。
Part 1 #
GQA在投影之后做了什么呢?首先它将向量对半分为两份分别作为K、V,然后每一份又均分为$g$份,每一份复制$h/g$次,以此来“凑”够$h$个Attention Head所需要的K、V。我们知道分割、复制都是简单的线性变换,所以MLA的第一个想法是将这些简单的线性变换换成一般的线性变换,以增强模型的能力:
\begin{equation}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\
\boldsymbol{k}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d_c\times d_k} \\
\boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_c\times d_v} \\[10pt]
\boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c}
\end{gathered}
\end{equation}
然而,理论上这样是能增加模型能力,但别忘了GQA的主要目的是减少KV Cache,出于节省计算和通信成本的考虑,我们一般会缓存的是投影后的$\boldsymbol{k}_i, \boldsymbol{v}_i$而不是投影前的$\boldsymbol{c}_i$或$\boldsymbol{x}_i$,而MLA的这个做法,通过不同的投影矩阵再次让所有的K、V Head都变得各不相同,那么KV Cache的大小就恢复成跟MHA一样大了,违背了GQA的初衷。
对此,MLA发现,我们可以结合Dot-Attention的具体形式,通过一个简单但不失巧妙的恒等变换来规避这个问题。首先,在训练阶段还是照常进行,此时优化空间不大;然后,在推理阶段,我们利用
\begin{equation}\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top} = \left(\boldsymbol{x}_t\boldsymbol{W}_q^{(s)}\right) \left(\boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\right){}^{\top} = \boldsymbol{x}_t\left(\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top}\right)\boldsymbol{c}_i^{\top} \end{equation}
这意味着推理阶段,我们可以将$\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top}$合并起来作为Q的投影矩阵,那么$\boldsymbol{c}_i$则取代了原本的$\boldsymbol{k}_i$,同理,在$\boldsymbol{o}_t$后面我们还有一个投影矩阵,于是$\boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}$的$\boldsymbol{W}_v^{(s)}$也可以吸收到后面的投影矩阵中去,于是等效地$\boldsymbol{v}_i$也可以用$\boldsymbol{c}_i$代替,也就是说此时KV Cache只需要存下所有的$\boldsymbol{c}_i$就行,而不至于存下所有的$\boldsymbol{k}_i^{(s)}$、$\boldsymbol{v}_i^{(s)}$。注意到$\boldsymbol{c}_i$跟${}^{(s)}$无关,也就是说是所有头共享的,即MLA在推理阶段它可以恒等变换为一个MQA。
再次强调,本文的主题是一直都是减少KV Cache,那到目前为止,MLA做到了什么呢?答案是通过不同的投影矩阵来增强了GQA的能力,并且推理时可以保持同样大小的KV Cache。那么反过来,如果我们只需要跟GQA相近的能力,那么是不是就可以再次减少KV Cache了?换言之,$d_c$没必要取$g(d_k+d_v)$,而是取更小的值(DeepSeek-V2取了512),从而进一步压缩KV Cache,这就是MLA的核心思想。
(注:这里有一个细节,就是$\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top}$合并成一个矩阵的恒等变换,理论上只有在无限精度下才成立,实际上如果我们使用单精度尤其是BF16的话,经过变换后的精度损失往往还是挺明显的,经过多层累积后可能放大到比较可观的程度,这里可能要根据实际误差看要不要做一些后处理。)
Part 2 #
一切似乎都很完美,看上去一个又好又省的理想设计就要出炉了。不过别急,当我们再深入思考一下就会发现,到目前为止的MLA有一个难以绕开的缺陷——不兼容RoPE(旋转位置编码)。
刚才我们说了,MLA之所以能保持跟GQA一样大小的KV Cache,其关键一步是“将$\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top}$合并成一个(跟位置无关的)矩阵作为Q的投影矩阵”,但如果加了RoPE的话,这一步就无法实现了。这是因为RoPE是一个跟位置相关的、$d_k\times d_k$的分块对角矩阵$\boldsymbol{\mathcal{R}}_m$,满足$\boldsymbol{\mathcal{R}}_m\boldsymbol{\mathcal{R}}_n^{\top}=\boldsymbol{\mathcal{R}}_{m-n}$,MLA加入RoPE之后会让$\boldsymbol{W}_q^{(s)}\boldsymbol{W}_k^{(s)}{}^{\top}$之间多插入了一项$\boldsymbol{\mathcal{R}}_{t-i}$:
\begin{equation}
\boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\quad,\quad\boldsymbol{k}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i} \\
\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top} = \left(\boldsymbol{x}_t\boldsymbol{W}_q^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_t}\right) \left(\boldsymbol{c}_i\boldsymbol{W}_k^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\right){}^{\top} = \boldsymbol{x}_t\left(\boldsymbol{W}_q^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_{t-i}}\boldsymbol{W}_k^{(s)}{}^{\top}\right)\boldsymbol{c}_i^{\top} \end{equation}
这里的$\boldsymbol{W}_q^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_{t-i}}\boldsymbol{W}_k^{(s)}{}^{\top}$就无法合并为一个固定的投影矩阵了(跟位置差$t-i$相关),从而MLA的想法无法结合RoPE实现。
前段时间,笔者也很荣幸跟DeepSeek团队讨论过这个问题,但这个问题可以说非常本质,所以当时笔者实际上也没能提出什么有效的建议。最简单的方式是放弃RoPE,换用其他基于Attention Bias的位置编码,如ALIBI,但DeepSeek的实验显示它明显不如RoPE(注意,MLA不是不能加RoPE,而是加了RoPE之后无法用恒等变换技巧来减少KV Cache),笔者也提议过换Sandwich,它不像ALIBI单调衰减到负无穷,估计效果会好些,但感觉是治标不治本。还有一个折中的办法是将$\boldsymbol{q}_i$的输入也改为$\boldsymbol{c}_i$,然后RoPE加在$\boldsymbol{c}_i$之后,即
\begin{equation}\boldsymbol{q}_i^{(s)} = \boldsymbol{c}_i\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\boldsymbol{W}_q^{(s)},\quad\boldsymbol{k}_i^{(s)} = \boldsymbol{c}_i\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\boldsymbol{W}_k^{(s)}\end{equation}
这样$\boldsymbol{\mathcal{R}}_i$就可以吸收到$\boldsymbol{c}_i$中去,但这样就没有$\boldsymbol{\mathcal{R}}_m\boldsymbol{\mathcal{R}}_n^{\top}=\boldsymbol{\mathcal{R}}_{m-n}$的运算了,此时的RoPE不再是通过绝对位置实现相对位置,而单纯是在Q、K上加绝对位置,让模型自己想办法提炼相对位置信息。
最后发布的MLA,采取了一种混合的方法——每个Attention Head的Q、K新增$d_r$个维度用来添加RoPE,其中K新增的维度每个Head共享:
\begin{equation}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \left[\boldsymbol{x}_i\boldsymbol{W}_{qc}^{(s)}, \boldsymbol{x}_i\boldsymbol{W}_{qr}^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_k + d_r},\quad \boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d\times d_k},\boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d\times d_r}\\
\boldsymbol{k}_i^{(s)} = \left[\boldsymbol{c}_i\boldsymbol{W}_{kc}^{(s)}, \boldsymbol{x}_i\boldsymbol{W}_{kr}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_k+d_r},\quad \boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c\times d_k}, \boldsymbol{W}_{kr}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_r} \\
\boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_c\times d_v} \\[10pt]
\boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c}
\end{gathered}
\end{equation}
这样一来,没有RoPE的维度就可以重复“Part 1”的操作,在推理时KV Cache只需要存$\boldsymbol{c}_i$,新增的带RoPE的维度就可以用来补充位置信息,并且由于所有Head共享,所以也就只有在K Cache这里增加了$d_r$个维度,原论文取了$d_r = d_k / 2 = 64$,相比原本的$d_c=512$,增加的幅度不大。
Part 3 #
最后有一个细节,就是MLA的最终版本,还将Q的输入也改为了低秩投影形式,这与减少KV Cache无关,主要是为了减少训练期间参数量和相应的梯度(原论文说的是激活值,个人表示不大理解)所占的显存:
\begin{equation}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \left[\boldsymbol{c}_i'\boldsymbol{W}_{qc}^{(s)}, \boldsymbol{c}_i'\boldsymbol{W}_{qr}^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_k + d_r},\quad \boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d_c'\times d_k},\boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d_c'\times d_r}\\
\boldsymbol{k}_i^{(s)} = \left[\boldsymbol{c}_i\boldsymbol{W}_{kc}^{(s)}, \boldsymbol{x}_i\boldsymbol{W}_{kr}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_k+d_r},\quad \boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c\times d_k}, \boldsymbol{W}_{kr}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_r} \\
\boldsymbol{v}_i^{(s)} = \boldsymbol{c}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_c\times d_v} \\[10pt]
\boldsymbol{c}_i' = \boldsymbol{x}_i \boldsymbol{W}_c'\in\mathbb{R}^{d_c'},\quad \boldsymbol{W}_c'\in\mathbb{R}^{d\times d_c'} \\
\boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\
\end{gathered}
\end{equation}
注意$\boldsymbol{k}_i^{(s)}$中的第二项,带RoPE的部分,其输入还是$\boldsymbol{x}_i$而不是$\boldsymbol{c}_i$,这里保持了原论文的设置,不是笔误,$d_c'$原论文的取值是1536,跟$d_c=512$不同。同时,我们把带RoPE的MHA放在下面,方便大家对比:
\begin{equation}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}, \boldsymbol{o}_t^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{(s)} ,\boldsymbol{v}_{\leq t}^{(s)}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)\boldsymbol{v}_i^{(s)}}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{(s)}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_q^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_q^{(s)}\in\mathbb{R}^{d\times d_k}\\
\boldsymbol{k}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_k^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\in\mathbb{R}^{d_k},\quad \boldsymbol{W}_k^{(s)}\in\mathbb{R}^{d\times d_k} \\
\boldsymbol{v}_i^{(s)} = \boldsymbol{x}_i\boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d_v},\quad \boldsymbol{W}_v^{(s)}\in\mathbb{R}^{d\times d_v}
\end{gathered}
\end{equation}
可以发现,其实在训练阶段,除了多了一步低秩投影以及只在部分维度加RoPE外,MLA与Q、K的Head Size由$d_k$换成$d_k + d_r$的MHA基本无异。
推理阶段的MLA则改为
\begin{equation}
\begin{gathered}
\boldsymbol{o}_t = \left[\boldsymbol{o}_t^{(1)}\boldsymbol{W}_v^{(1)}, \boldsymbol{o}_t^{(2)}\boldsymbol{W}_v^{(2)}, \cdots, \boldsymbol{o}_t^{(h)}\boldsymbol{W}_v^{(h)}\right] \\[10pt]
\boldsymbol{o}_t^{(s)} = Attention\left(\boldsymbol{q}_t^{(s)}, \boldsymbol{k}_{\leq t}^{\color{#ccc}{\smash{\bcancel{(s)}}}} ,\boldsymbol{c}_{\leq t}\right)\triangleq\frac{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}}{}^{\top}\right)\boldsymbol{c}_i}{\sum_{i\leq t}\exp\left(\boldsymbol{q}_t^{(s)} \boldsymbol{k}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}}{}^{\top}\right)} \\[15pt]
\boldsymbol{q}_i^{(s)} = \left[\boldsymbol{c}_i'\boldsymbol{W}_{qc}^{(s)}\boldsymbol{W}_{kc}^{(s)}{}^{\top}, \boldsymbol{c}_i'\boldsymbol{W}_{qr}^{(s)}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_c + d_r}\\
\boldsymbol{k}_i^{\color{#ccc}{\smash{\bcancel{(s)}}}} = \left[\boldsymbol{c}_i, \boldsymbol{x}_i\boldsymbol{W}_{kr}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\color{#3ce2f7}{\boldsymbol{\mathcal{R}}_i}\right]\in\mathbb{R}^{d_c+d_r}\\
\boldsymbol{W}_{qc}^{(s)}\in\mathbb{R}^{d_c'\times d_k},\boldsymbol{W}_{kc}^{(s)}\in\mathbb{R}^{d_c\times d_k},\boldsymbol{W}_{qr}^{(s)}\in\mathbb{R}^{d_c'\times d_r},\boldsymbol{W}_{kr}^{\color{#ccc}{\smash{\bcancel{(s)}}}}\in\mathbb{R}^{d\times d_r} \\[10pt]
\boldsymbol{c}_i' = \boldsymbol{x}_i \boldsymbol{W}_c'\in\mathbb{R}^{d_c'},\quad \boldsymbol{W}_c'\in\mathbb{R}^{d\times d_c'} \\
\boldsymbol{c}_i = \boldsymbol{x}_i \boldsymbol{W}_c\in\mathbb{R}^{d_c},\quad \boldsymbol{W}_c\in\mathbb{R}^{d\times d_c} \\
\end{gathered}
\end{equation}
此时Q、K的Head Size变成了$d_c + d_r$,V的Head Size 则变成了$d_c$,按照原论文的设置,这是$d_k$、$d_v$的4倍。所以实际上MLA在推理阶段做的这个转换,虽然能有效减少KV Cache,但其推理的计算量是增加的。
那为什么还能提高推理效率呢?这又回到“瓶颈”一节所讨论的问题了,我们可以将LLM的推理分两部分:第一个Token的生成(Prefill)和后续每个Token的生成(Generation),Prefill阶段涉及到对输入所有Token的并行计算,然后把对应的KV Cache存下来,这部分对于计算、带宽和显存都是瓶颈,MLA虽然增大了计算量,但KV Cache的减少也降低了显存和带宽的压力,大家半斤八两;但是Generation阶段由于每步只计算一个Token,实际上它更多的是带宽瓶颈和显存瓶颈,因此MLA的引入理论上能明显提高Generation的速度。
还有一个细节充分体现了这个特性。一般的LLM架构参数满足$h \times d_k = d$,即num_heads * head_size = hidden_size,但DeepSeek-V2不一样,它$d_k=128,d=5120$,但$h=128$,是一般设置的3倍!这是因为MLA的KV Cache大小跟$h$无关,增大$h$只会增加计算量和提升模型能力,但不会增加KV Cache,所以不会带来速度瓶颈。
小结 #
本文简单概述了多头注意力的演变历程,特别是从MHA向MQA、GQA,最终到MLA的变化理念,最后详细展开了对MLA的介绍。在本文中,MLA被视为GQA的一般化,它用投影矩阵的方式替代了GQA的分割、重复,并引入了一个恒等变换技巧来可以进一步压缩KV Cache,同时采用了一种混合方法来兼容RoPE。总的来说,MLA称得上是一种非常实用的注意力变体。
转载到请包括本文地址:https://kexue.fm/archives/10091
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (May. 13, 2024). 《缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA 》[Blog post]. Retrieved from https://kexue.fm/archives/10091
@online{kexuefm-10091,
title={缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA},
author={苏剑林},
year={2024},
month={May},
url={\url{https://kexue.fm/archives/10091}},
}
May 15th, 2024
苏老师您好,我对于将 Q 的输入也改为低秩投影形式这个地方也有一些疑惑。
其中 $c_{i}^{'} = x_{i}W_{c}^{'}$ 中的 $W_{c}^{'}$ 完全可以和后续计算 $q_{i}^{(s)}$ 时用到的 $W_{qc}^{(s)}$ 和 $W_{qr}^{(s)}$ 应用矩阵结合律合并,这种把一个矩阵拆成两个矩阵的乘积的做法有什么意义吗?
低秩分解,最明显的作用就是减少参数量。
是的。但是按照论文里的数据 $d = 5120, d_{c}' = 1536, d_{k} = 128$:
不分解时,$W_{qc}^{(s)}\in\mathbb{R}^{d\times d_{k}}$,对应的参数量和计算量为 $d\times d_{k}$;
分解时,$W_{c}'\in\mathbb{R}^{d\times d_{c}'},W_{qc}^{(s)}\in\mathbb{R}^{d_{c}'\times d_{k}}$,对应的参数量和计算量为 $d\times d_{c}' + d_{c}'\times d_{k}$,反而增加了。
如果拆出的维度不比原始矩阵的两个维度小,那么这个做似乎没有收益。是我哪里理解错了吗?
别忘了还有128个head(即$n_h=128$),每个head的投影矩阵都是独立的,所以不分解时是$d\times (d_{k}\times n_h)$,分解时是$d\times d_{c}' + d_{c}'\times (d_{k}\times n_h)$
您说的对,对于参数量和计算量而言如果考虑到 128 个 head 低秩分解后确实都减少了(减少为原来的 40% 左右)。但是这样对于论文中提到的减少 train 中的 activation memory 似乎还是没有帮助。
如果不使用低秩分解,为了在反向传播时计算 $W_Q$ 的梯度需要存储输入 $x$ 作为 activation memory;如果使用低秩分解,为了计算 $W_{c}'$ 的梯度依然需要存储输入 $x$,除此之外,为了计算 $W_{qc}^{(s)}$ 和 $W_{qr}^{(s)}$ 的梯度还需要额外存储压缩后的 $c_{i}'$。
这样看来,需要的 activation memory 反而增加了?难道说低秩投影矩阵 $W_{c}'$ 不属于可训练参数?
activation memory方面,我们普遍认为是结合了block-wise的recompute实现来减少activation,具体细节未知。
recompute的使用和注意力机制的设计没有因果关系呀,MHA也可以recompute,都能节省训练显存
@TomFoxxxx|comment-25037
但不同算法、不同recompute细节的成本不一样。
May 16th, 2024
苏神,你好,请问一下,用Qwen1.5-7B-Chat在做推理的时候,为什么会出现同一个问题,结果不一致的情况,temperature没有修改,设置为0。以及为什么在推理的时候要用到set_seed这个参数呀,seed不是在训练的时候用的吗?
我不了解Qwen~
请问一下,在做量化模型推理结果验证对比的时候,是不是一定要保证量化前后模型的seed参数是一致的?
需要保持一致。
seed只是生成的随机种子,用在哪里都可以,训练的时候随机初始化参数、随机选择batch算梯度等,推理的时候在有效的top-k范围随机采样token等,只要涉及随机,就有seed的用处。
1. 结果不一致的情况,一定要检查代码确认没使用采样(vllm中使用时temperature=0可以保证),如果不清楚不一致的原因,记得把模型输出的log记录下来,就是 model.generate 这个结果全部打印,日志默认会保存累积对数softmax的值,对比值大小就能看出差异在哪,采样参数有个 logprobs,这个是记录new_token对数概率值的前几维,设置的越大溯源越方便。
2. set_seed 在vllm中有两处地方,一个是服务启动时的设置(目前不清楚这个seed作用),一个是采样时的参数(保证采样可复现),采样是按照概率分布选取index,这个index还原为vocab就是推理生成的 token,这时候的 set_seed 就是保证采样的结果一致性。可惜官方仓库说过,即便设置了 set_seed,结果存在微小差异也是正常的。如果有强迫症,可以试下 min_p 这个参数。
May 20th, 2024
在式(4)中,ci是ki和vi沿着一个方向拼成一个超长的向量吗?所以ci的维度是1×g(dk+dv)
是的
May 29th, 2024
latent vector变换后做广播其实是MQA的变种吧,在Lite模型的源码中就是这样写的。
说是GQA的变种也没毛病 都可以解释通的 看待的角度不同
June 8th, 2024
苏神你好,我理解deepseek的这篇文章是通过牺牲效果(在q,k,v上使用使用类似lora的低秩分解,减少了总参数量)来提升模型在推理和训练时的性能。虽然对q也使用了lora,但主要侧重点应该在kv上,因为kv能减少kv cache的存储量。但有个疑问想请教一下,使用在预训练时就全使用lora与非lora的效果差异会很大吗,毕竟lora在微调时候就是没有全参数微调效果好的,或者说效果可以等同于一个更小的矩阵。那我直接用一个更小的矩阵,然后只存每个transformer layer的outoput,kv通过小$W_k$和小$W_q$复原出来是不是也可以呢
1、MLA跟LoRA没什么关系,不建议从LoRA角度理解MLA,MLA纯粹是GQA的一般化,所谓低秩投影,在GQA也有;
2、“效果可以等同于一个更小的矩阵”不知道想表达什么;
3、“kv通过小$W_k$和小$W_q$复原出来是不是也可以呢”,如果我没理解错,你是想表达只存$x$,然后实时投影生成$q,k,v$?这个做法本文已经有评论。
kv和Wq有啥关系。。
June 18th, 2024
大神,请教一下,为什么HF的deepseek的apply_rotary_pos_emb函数中,进行了q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d), 以前不都是直接用吗?这和MLA有关?
https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/4461458f186c35188585855f28f77af5661ad489/modeling_deepseek.py#L364
这应该对应RoPE两种不同的实现而已,理论上等价的,不知道它出于什么考虑。
July 2nd, 2024
苏神好,请问应用GQA时,是不是也要适配下RoPE,毕竟key 分组共享了,可以看做hidden_dim维度和query 不同了,而RoPE里面的旋转角度是hidden_dim的函数,两者该如何匹配呢 ?
$\theta$从来就不是hidden_dim的函数,它是head_dim的函数。
July 8th, 2024
苏神,12式的$\boldsymbol{k}_i^{(s)}$,是否不应该有上标s?这里应该没有多头才对
你也可以理解为是有多头,只不过每个头共享同一个数值,这就好比说$f(x)=1$(恒等于1)虽然不随着$x$的变化而变化,但我们也可以称之为$x$的函数。
谢谢苏神回复,我还是有些疑惑。加上多头隐含的含义是,这里必须对key进行broadcast,才能与query计算Attention。但这种实现方式,会使得推理成本大幅上升(头的维度从128变成了512+64)。实际上,是可以通过优化flash-attention,使得这里不需要对key进行broadcast。Deepseek没详细讲他们在推理阶段做的优化,但我认为如果不魔改flashattention,长文的首字符延迟会很恐怖。 还请赐教。
MLA的意思是,推理时head_dims增加所带来的计算量增加,还比不上带宽本身的速度瓶颈,所以不会影响生成速度。
按照这个思路,魔改flash-attention估计不太行,head_size的维度变为(512+64)了,但是flash-attention目前head_size也就支持到256;还得再看看
这就超出我的理解能力范畴了。
这里每个头的维度是128+64,问题不大
August 6th, 2024
MLA看起来只是缩小了需要缓存的内容大小,最终还是要拆成$128*128$维度的,无法压缩的QKV向量,进行attention操作,对于显存带宽瓶颈似乎没有优化啊。
总量少了不就行了吗
August 14th, 2024
您好,想请教一下既然Wq和Wk的合并可以节省计算量,为什么在训练的时候就合并在一起呢,这样既训得快,并且免去了训练和推理不一致带来的麻烦。
本文似乎从来没说过这一步合并能够节省计算量。本文要表达的是,对于MLA,这一步合并会增加计算量,但最终的推理时间还是能缩短,因为推理阶段计算量不是推理时间的瓶颈,但训练阶段计算量是瓶颈,所以训练阶段不能合并。
苏神辛苦啦,这篇文章认认真真看了一天,很有收获~