MoE环游记:3、换个思路来分配
By 苏剑林 | 2025-03-05 | 22912位读者 |这篇文章我们继续探讨MoE的负载均衡问题。在上一篇文章《MoE环游记:2、不患寡而患不均》中,我们主要讨论了通过Aux Loss来促进负载均衡的思路。Aux Loss固然简单直观,但它也有一个明显的缺点——权重不好调——调低了无法促进均衡,调高了容易损害LM Loss,所以业界一直有寻找替代方案的尝试。
本文要分享的是名为“Loss-Free”的方案,由DeepSeek在《Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts》提出。和DeepSeek众多耀眼的开源作品相比,这篇论文也许不算起眼,但在笔者看来,它潜在的学术影响力可能远超其他工作,因为所提方法不仅简单有效,而且极具普适性,堪称经典。
方法大意 #
面对负载不均衡,Aux Loss的应对思路是通过额外的损失引导Router给出均衡的打分,而Loss-Free的想法则是换个新的分配思路,即不改变Router现有打分结果,而是改变$\mathop{\text{argtop}}_k \boldsymbol{\rho}$这个分配方式。
其实这个方向此前也有过一些努力。比如2021年Facebook提出了BASE Layer,将Expert的分配视为线性指派问题,即以负载均衡为约束条件,求在该约束之下Router总打分尽可能高的分配结果,这可以用匈牙利算法等来解决。但该方案需要知道全体Token的打分,所以对于自回归式LLM来说,它只适用于训练,推理还是只能用$\mathop{\text{argtop}}_k \boldsymbol{\rho}$,训练推理存在不一致性,并且由于目前求解算法的限制,它只适用于$k=1$的场景。
相比之下,Loss-Free的做法非常简单且有效,它留意到一个事实,即我们总可以引入一个偏置项$\boldsymbol{b}$,使得$\mathop{\text{argtop}}_k \boldsymbol{\rho} + \boldsymbol{b}$的分配是均衡的,所以它将MoE的形式改为
\begin{equation}\boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho}} \rho_i \boldsymbol{e}_i\qquad\to\qquad \boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho} + \boldsymbol{b}} \rho_i \boldsymbol{e}_i\end{equation}
这里的$\boldsymbol{b}$是输入无关的向量,由训练过程确定下来,训练完后它就保持不变,因此推理阶段也可以用,换言之训练和推理具有一致的形式。注意乘以$\boldsymbol{e}_i$的还是$\rho_i$而不是$\rho_i + b_i$,也就是说$\boldsymbol{b}$仅仅参与分配过程而不参与MoE的前向计算,所以我们对$\boldsymbol{b}$或$\boldsymbol{\rho} + \boldsymbol{b}$的正负性都没有特殊要求。
手搓梯度 #
怎么训练$\boldsymbol{b}$呢?我们知道,$\boldsymbol{b}$的优化方向自然是促进负载均衡,为此按照上一篇的记号,我们先定义$\boldsymbol{f}=[f_1,f_2,\cdots,f_n]$:
\begin{equation}f_i = \left\{\begin{aligned}1/k, \quad i\in \mathop{\text{argtop}}\nolimits_k \boldsymbol{\rho}+\boldsymbol{b} \\
0, \quad i\not\in \mathop{\text{argtop}}\nolimits_k \boldsymbol{\rho}+\boldsymbol{b}\end{aligned}\right.\end{equation}
以及$\boldsymbol{F}=\mathbb{E}[\boldsymbol{f}]$,这里的$\boldsymbol{F}$自然就是在$\boldsymbol{b}$偏置下Expert当前的负载分布了。借着我们定义均匀分布为$\boldsymbol{Q}=(1/n,1/n,\cdots,1/n)$,那么负载均衡就相当于最小化
\begin{equation}\mathcal{L}_{\text{aux}} = \frac{1}{2}\Vert\boldsymbol{F} - \boldsymbol{Q}\Vert^2 = \frac{1}{2}\sum_{i=1}^n (F_i - 1/n)^2\end{equation}
这个目标是不可导的,但有了上一篇的经验,我们知道STE(Straight-Through Estimator)可以解决这个问题。STE的关键是找一个可导且跟$\boldsymbol{F}$具有同增减趋势的量作为$\boldsymbol{F}$的光滑近似,这里我们的优化参数只有$\boldsymbol{b}$,而它正好具有我们期望的性质(增大$b_i$,$i$被选中的概率就更高,那么$F_i$就更大),所以答案就呼之欲出了:
\begin{equation}\mathcal{L}_{\text{aux}} = \frac{1}{2}\Vert\boldsymbol{b} + \text{sg}[\boldsymbol{F}-\boldsymbol{b}] - \boldsymbol{Q}\Vert^2 = \frac{1}{2}\sum_{i=1}^n (b_i + \text{sg}[F_i - b_i] - 1/n)^2\end{equation}
它的梯度是
\begin{equation}\nabla_{\boldsymbol{b}}\mathcal{L}_{\text{aux}} = \frac{1}{2}\nabla_{\boldsymbol{b}}\Vert\boldsymbol{b} + \text{sg}[\boldsymbol{F}-\boldsymbol{b}] - \boldsymbol{Q}\Vert^2 = \boldsymbol{F} - \boldsymbol{Q}\end{equation}
所以用梯度下降(SGD)来更新$\boldsymbol{b}$就是
\begin{equation}\boldsymbol{b}\leftarrow \boldsymbol{b} - \alpha (\boldsymbol{F} - \boldsymbol{Q})\end{equation}
这里$\alpha$是$\boldsymbol{b}$的学习率。不过Loss-Free最终选择的更新规则略有不同,它选择的是符号梯度下降(SignSGD):
\begin{equation}\boldsymbol{b}\leftarrow \boldsymbol{b} - \alpha \mathop{\text{sign}}(\boldsymbol{F} - \boldsymbol{Q})\label{eq:aux-loss-free}\end{equation}
这个结果其实也很好理解,就是如果$F_i$比$1/n$大,那么就调小一点$b_i$,否则就增大一点$b_i$。
一脉相承 #
原论文在介绍Loss-Free时,并没有上述Aux Loss的推导过程,而是直接给出式$\eqref{eq:aux-loss-free}$的更新规则,给人的感觉是给$\boldsymbol{b}$“手搓”了梯度$\mathop{\text{sign}}(\boldsymbol{F} - \boldsymbol{Q})$,这也是它Loss-Free这个名字的来源。
然而,从本文给出的推导可以看出,更新规则$\eqref{eq:aux-loss-free}$也完全可以从Aux Loss视角得到,两者是一脉相承的。看起来Loss-Free最直接的好处是不用调Aux Loss权重了,但它实际上也有个学习率参数$\alpha$要调,尽管原论文已经帮我们搜好$\alpha=0.001$这个默认值,但不可否认这个超参数是存在的。
在笔者看来,Loss-Free的本质创新并不是没有Aux Loss,而是隔离了Aux Loss和LM Loss的优化参数,从而达到了负载均衡和模型能力两不误的效果。其中最关键一步,是留意到“一个偏置项足以达到负载均衡”这一事实,然后就让Aux Loss只优化新引入的偏置$\boldsymbol{b}$,而LM Loss则优化剩余参数,让Aux Loss对LM Loss的负面作用降到最低。
相比之下,常规的Aux Loss方案需要全体参数来促进负载均衡,而LM Loss优化的也是全体参数,两者的优化方向可能并不完全兼容,因此想找到一个最优的平衡点相对来说就更为困难。所以,Loss-Free基于“一个偏置项足以达到负载均衡”将两个Loss的优化参数隔离开来,是负载均衡问题的一个绝妙的解决办法。
相关细节 #
尽管Loss-Free已经足够简单明了,但是在使用的时候还要稍微注意一些细节。
首先,对于每个Batch的数据,我们应当先根据LM Loss来更新模型参数,然后再根据式$\eqref{eq:aux-loss-free}$来更新$\boldsymbol{b}$。这是因为$\boldsymbol{b}$的更新依赖于全体Token的统计信息$\boldsymbol{F}$,先更新$\boldsymbol{b}$再更新模型其余参数的话,原则上会有泄漏未来信息的风险。虽然直观看来就一个向量$\boldsymbol{b}$泄漏不了多少信息,但这个风险终归是存在的,因此要尽量去规避它。
其次,刚才我们说原论文已经调好$\alpha=0.001$,但这个结果可能跟原论文用Sigmoid作为Router $\boldsymbol{\rho}$激活函数的选择是绑定的。原因也不难想,经过Sigmoid后,每个$\rho_i$相对比较独立,并且都在$(0,1)$内,$\alpha=0.001$相当于说每一步的更新幅度约为千分之一,如果换Softmax、ReLU或者其他激活函数,那么就可能需要重调$\alpha$了。
针对这个问题,笔者建议的做法是结构Gate和Bias所用的激活函数,即
\begin{equation}\boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho} + \boldsymbol{b}} \rho_i \boldsymbol{e}_i\qquad\to\qquad \boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho}^{(\sigma)} + \boldsymbol{b}} \rho_i^{(h)} \boldsymbol{e}_i\end{equation}
其中$\boldsymbol{\rho}^{(\sigma)} = \sigma(\boldsymbol{x}\boldsymbol{W}^{(R)}), \boldsymbol{\rho}^{(h)} = h(\boldsymbol{x}\boldsymbol{W}^{(R)})$,$\sigma(\cdot)$是Sigmoid函数,$h(\cdot)$是任意单调且值域非负的函数,说白了就是加上$\boldsymbol{b}$的是Sigmoid激活的打分,这样我们就可以复用$\alpha=0.001$,至于乘上Expert的Gate,我们可以用其他激活函数,只要它的单调性跟Sigmoid一致就行。
此外,由于更新规则$\eqref{eq:aux-loss-free}$加了$\text{sign}$函数,因此有可能训出绝对值大于1的$b_i$,整体绝对值还可能越来越大,这些都是正常的,对模型效果不会有影响。实际上$\boldsymbol{b}$有一个冗余的自由度,因为全体$b_i$都加上同一个常数后,$\mathop{\text{argtop}}_k \boldsymbol{\rho} + \boldsymbol{b}$的结果不变。这个额外的自由度我们可以用来做其他好玩的事情(且听下回分解)。
延伸思考 #
除了MoE的负载均衡之外,Loss-Free的思想还可以应用到很多类似问题,比如VQ-VQE的编码表坍缩(Codebook Collapse),就可以用同样思路解决,而且相比之前介绍的“旋转技巧”、“线性变换技巧”显得更自然和普适。事实上,本文开篇的评价“Loss-Free潜在的学术影响力可能远超其他工作”,正是基于Loss-Free的普适性考虑的。
抛开具体的应用背景,从数学上来看,Loss-Free的贡献可以理解为给出了用梯度下降来求解指派问题的方法。一个经典的线性指派问题可以表示为:
\begin{equation}\min_f \sum_{i=1}^n c_{i, f(i)}\end{equation}
其中$c_{i,j}$是给定的成本函数,$f$是$\{1,2,\cdots,n\}$到自身的双射。放到本文的背景下,$c_{i,j}$不就相当于$n$个Token、$n$个Expert的打分,所求$f$不就是一个负载均衡的分配方案?求解此类问题的一般想法是在满足约束条件的空间里搜索尽可能优的解,而Loss-Free则反过来,先构建一个最优但不一定满足约束条件的解:
\begin{equation}f(i) = \mathop{\text{argmin}}_j c_{i,j}\end{equation}
这个解在分数上肯定是最优的,但不一定满足双射的条件,这里不满足双射就等价于负载不均衡。于是我们引入偏置
\begin{equation}f(i) = \mathop{\text{argmin}}_j c_{i,j} + b_j\end{equation}
$b_j$初始化为零,然后根据式$\eqref{eq:aux-loss-free}$来更新,更新规则说白了就是哪个$j$出现出现次数多,那减少相应的$b_j$,反之增加,直到出现双射为止。
文章小结 #
本文介绍了MoE负载均衡问题的Loss-Free方法,它由DeepSeek提出,其核心在于通过引入一个简单的偏置项来实现负载均衡。本文进一步思考了它与Aux Loss的联系,以及它在类似数学问题上的应用潜力。
转载到请包括本文地址:https://kexue.fm/archives/10757
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Mar. 05, 2025). 《MoE环游记:3、换个思路来分配 》[Blog post]. Retrieved from https://kexue.fm/archives/10757
@online{kexuefm-10757,
title={MoE环游记:3、换个思路来分配},
author={苏剑林},
year={2025},
month={Mar},
url={\url{https://kexue.fm/archives/10757}},
}
March 5th, 2025
想过做vq-vae,不过我不会视觉hh,能不能等苏神搓一个
March 5th, 2025
也许是我没有理解,但之前的方法Aux loss的话会使得$\rho$本身是均衡的吧
这篇文章介绍的方法,似乎会使得$\rho+b$是均衡的,$\rho$本身仍然可能会不均衡
但计算时依然会选择$\rho$做为权重,因此虽然专家可能被选择了,但权重很小,这样能起到均衡的意义吗?也许训练的时候每个专家都“干活”了,但可能并没有起到“贡献”?
“因此虽然专家可能被选择了,但权重很小”, 但是还有LLM的cross entropy loss,要想ce loss 小,你说的这件事就不应该普遍出现。不然就是没训练好
原来如此,感谢!
反正训练时就已经选出了这个Expert了,如果模型觉得有必要,那么自然就会将相应的$\rho_i$调大,否则就是没必要的。这点选择相信优化器就行,不用太担心。
原来如此,感谢!
March 5th, 2025
苏神你好,我对这个偏置项的物理意义的理解是,它是由这个负载均衡问题进行多槽在线匹配 (multi slot online matching)建模后得到的0/1整数规划问题的对偶变量。我最近做了一个follow的工作(https://arxiv.org/pdf/2502.15451,目前只有算法分析还没做完实验),欢迎交流。
欢迎作者莅临!之前关注到你这篇paper了。我有一个顾虑是,你这篇paper主要强调的点是能更快达到均衡,我之前也尝试在单个batch内执行多次$\boldsymbol{b}$的更新,以达到类似的效果(即尽快达到均衡),但我个人测试发现过早达到均衡并不一定有利于模型最终表现,此外就是整数规划的解法,拓展性似乎不如SGD好。
感谢您的回复和中肯的建议。我近期更新了论文,补充了LLM实测实验(专家16取4和64取8,基于Minimind),实验结果显示求解整数规划是可以做到在整个预训练过程中从第一步到最后一步,均实现每个 MoE 层中每个专家都保持负载均衡状态,同时经过训练的 MoE 模型也表现良好(更低的Perplexity)。附上项目代码供参考:https://github.com/sunyuanLLM/bip_routing_algorithm。
恭喜!同时感谢后续反馈,正在学习中。
March 5th, 2025
Loss-Free的本质创新并不是没有Aux Loss,而是隔离了Aux Loss和LM Loss的优化参数,从而达到了负载均衡和模型能力两不误的效果。
---
请教一下苏神,这里的“隔离”应该如何理解?为什么加了b之后就隔离了两个loss的优化参数呢?
是不是因为训练b的梯度更新抵消了F-Q的影响,从而降低了因为负载均衡优化对LM loss优化的影响。可以这么理解吗?
平衡信息F只出现在b的更新公式中。所以平衡不平衡这件事,只会影响b这个参数,不会影响其他参数。其他参数更新的梯度只会从最后的cross-entropy传过来。
但是原始的aux loss,cross-entropy和balance loss是加在一起的。这个值会同时影响所有参数。
当然这里只说直接影响。间接地话,F也是会对LLM其他参数有影响的,你选的专家不一样,当然对每个专家对训练过程是不一样的。但是直接的影响是隔离里
当然反过来,可以这么做的原因是他确实是可以隔离的,因为F是不是平衡,只通过b就可以解决了
现在是人为规定,Aux Loss只用来更新$\boldsymbol{b}$,LM Loss只用来更新剩余参数,这不就是很明显的隔离了嘛。常规的Aux Loss没有这个约定,两个Loss都更新所有参数。
March 5th, 2025
“首先,对于每个Batch的数据,我们应当先根据LM Loss来更新模型参数,然后再根据式(7)
来更新b” 感觉没有说清楚。
原文是说:we update the biases based on the historical balance condition, since utilizing the load information of the current sequence will break the causal constraint of language modeling, leading to leakage of the information of future tokens
我的理解是当前step的llm权重更新用当前batch的数据来。但是bias的更新根据上一个batch的数据来?
不是上一个batch,是所有i-1个batch,除了当前batch
原论文这段话我也不确定它要表达的准确意思是什么,但从原论文的Algorithm 1来看,应该是按照我说的理解。
另外阁下是否可以考虑换个昵称,这个昵称似乎对交流并不那么友好。
我又看了一下 你这个描述确实是符合alg1的。可能他强调的是在第$k$轮更新$b_i$的时候,$c_i$和$\bar c_i$要用$k-1$轮时候的统计数据吧。之前还有一句:
To be specific, for each bi, we keep monitoring its corresponding expert load on the previous batch. If an expert has a heavy load on the previous batch, we will reduce its bias.
哈哈 改名了
理论上来讲,直接用当前batch的数据,在LM Loss更新之后去更新bias,不会有泄漏问题,我们其实也是这样做的,并没有发现什么问题,所以就这样写了。
March 5th, 2025
感觉aux loss是让语言模型本身的训练掺和到balance到训练来了,具体体现就是F/P的梯度直接指导了LLM的训练。这样是不是可能导致为了平衡而平衡。
比如,之所以某一个token给某个专家较大的权重,根本原因还是这个专家活干得好。那为了平衡而平衡的话,那这个专家有可能为了降低自己的工作压力 故意降低自己的能力。
aux loss free,就是说专家该怎么干就怎么干,一切以干好活为最高原则。平衡的事我另外搞一个机制强行保证平衡。
“可能为了降低自己的工作压力,故意降低自己的能力”,这个类比很形象!受教了
March 6th, 2025
感觉这个很像,对长尾分布的数据打伪标签的时候,为了避免集中在某个类别,因此给伪标签减去一个和类别占比成正比的惩罚因子
有点像 https://kexue.fm/archives/7615 的思路吧,当时确实想过类似策略,但不如直接用SGD来更新Bias直接。
March 6th, 2025
您好!我想探讨的是,为什么DeepSeek-V3会选择Sigmoid来作为Gate分数的激活呢?通常意义来说,softmax不是更倾向于去选择"argmax"的操作嘛!期待回复
看源代码的话,gating的分数其实是sigmoid+topk+l1norm,会比softmax平缓很多。如果用softmax,会把专家间的分数差异拉的很大。可能出现一两个专家的weight很大,其他的分数几乎为0.但是用sigmoid+l1norm就会让专家分数更趋于平均。
比如[1,2,3,4], softmax是[0.0321, 0.0871, 0.2369, 0.6439]
前两个值几乎为0
sigmoid+l1norm为[0.2061, 0.2484, 0.2686, 0.2769]
大概作者是像让分数小但是被选中的专家也不要分数被压制的太厉害吧。
其实就像你说的softmax倾向于argmax,但是模型不希望只选最佳的那个专家,比如deepseek v3 671B要选top8个专家。如果用softmax,可能虽然选了8个专家,但是他们的权重可能是top1的那个专家占98%,其他7个都接近0。这不是作者希望的,他可能希望既然选了8个就尽量都用上,而不失选了8个实际有7个的贡献都忽略不计,毕竟算力是付出了8个专家的算力
感谢您解惑!非常清晰!
因为$b_i$是独立于$s_i$的计算的,如果不用sigmoid也独立计算,而专用softmax,就可能让本来同样的$x_1,x_2, b_1,b_2$的情况下既有可能$s_1+b_1>s_2+b_2$也有可能$s_1+b_1 < s_2+b_2$,而这显然是不理想的。让我们举一个例子来说明,假设$b_1=0.15, b_2=0$,
情况一 $x = [2, 2.5, 3]$, $softmax(x) = [0.18632372 0.30719589 0.50648039]$, 这样 $s_1 + b_1 = 0.336 > s_2 + b_2 = 0.307$ 1,2之间应该选1
情况二 $x = [2, 2.5, 2]$, $softmax(x) = [0.27406862 0.45186276 0.27406862]$, 这样 $s_1 + b_1 = 0.424 < s_2 + b_2 = 0.451$, 1,2之间应该选2
感谢回复!但是我认为本质上使用sigmoid的原因还是楼上和楼下苏老师所说,希望topk之间能尽可能避免one-hot
sigmoid和softmax的差距应该是sigmoid激活和指数exp激活的区别,因为softmax是exp激活+$l_1$norm,所以其实和sigmoid思路保持一致,在指数exp激活后添加bias,就不会有您所表示的问题。(流程相当于 激活 -> 加偏差 -> topK -> $l_1$norm)
Sigmoid一来可以避免选出来的Expert之间出现恶性竞争,二来也更配合Loss-Free。
Softmax的问题是此消彼长,输出结果会比较容易接近one hot,这就容易导致选出了Top-$k$个Expert,但实际上只有Top-1个Expert在工作,这是其一;然后,Softmax接近one hot时,Top-1以外的打分就会很接近于0,并且它们之间差别不大,这时候我们用$\alpha=0.001$就比较难调平衡,因为它们之间的差异可能都小于0.001。
感谢苏老师解惑!
March 10th, 2025
苏神,想请加一下 Deepseek V3 中使用的 seq level aux loss 计算应该带上 bias 吧(一个是 score,一个是 topk 专家的 indices),论文公式里面都没有带 bias,但显然这是不均衡的。
DeepSeek V3的Aux Loss是想促进$\rho_i$自身的平衡,而不是加强Loss-Free的平衡,所以不用加$b_i$。
所以 indices 也要按照 node limited routing 的方式针对为 $\rho$ 单独计算而非主线 forward 中使用的 $\rho + b$
按我的理解是的,Aux Loss起保底作用,用$\boldsymbol{rho}$不要飞得太夸张,在此基础上Loss-Free更容易控制balance。
March 19th, 2025
想请教下苏神,为什么V3论文里说Aux-Loss是seq-level,Loss-free是batch-level?我看都是在整个batch上优化P/b。无论是 $PF$ 还是 $\text{sign}(F-1/n)$,都是在batch上取平均。
DeepSeek-V3似乎没公开训练源码?所以在哪个level做平均似乎是取决于他们的意愿?