MoE环游记:3、换个思路来分配
By 苏剑林 | 2025-03-05 | 2058位读者 |这篇文章我们继续探讨MoE的负载均衡问题。在上一篇文章《MoE环游记:2、不患寡而患不均》中,我们主要讨论了通过Aux Loss来促进负载均衡的思路。Aux Loss固然简单直观,但它也有一个明显的缺点——权重不好调——调低了无法促进均衡,调高了容易损害LM Loss,所以业界一直有寻找替代方案的尝试。
本文要分享的是名为“Loss-Free”的方案,由DeepSeek在《Auxiliary-Loss-Free Load Balancing Strategy for Mixture-of-Experts》提出。和DeepSeek众多耀眼的开源作品相比,这篇论文也许不算起眼,但在笔者看来,它潜在的学术影响力可能远超其他工作,因为所提方法不仅简单有效,而且极具普适性,堪称经典。
方法大意 #
面对负载不均衡,Aux Loss的应对思路是通过额外的损失引导Router给出均衡的打分,而Loss-Free的想法则是换个新的分配思路,即不改变Router现有打分结果,而是改变$\mathop{\text{argtop}}_k \boldsymbol{\rho}$这个分配方式。
其实这个方向此前也有过一些努力。比如2021年Facebook提出了BASE Layer,将Expert的分配视为线性指派问题,即以负载均衡为约束条件,求在该约束之下Router总打分尽可能高的分配结果,这可以用匈牙利算法等来解决。但该方案需要知道全体Token的打分,所以对于自回归式LLM来说,它只适用于训练,推理还是只能用$\mathop{\text{argtop}}_k \boldsymbol{\rho}$,训练推理存在不一致性,并且由于目前求解算法的限制,它只适用于$k=1$的场景。
相比之下,Loss-Free的做法非常简单且有效,它留意到一个事实,即我们总可以引入一个偏置项$\boldsymbol{b}$,使得$\mathop{\text{argtop}}_k \boldsymbol{\rho} + \boldsymbol{b}$的分配是均衡的,所以它将MoE的形式改为
\begin{equation}\boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho}} \rho_i \boldsymbol{e}_i\qquad\to\qquad \boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho} + \boldsymbol{b}} \rho_i \boldsymbol{e}_i\end{equation}
这里的$\boldsymbol{b}$是输入无关的向量,由训练过程确定下来,训练完后它就保持不变,因此推理阶段也可以用,换言之训练和推理具有一致的形式。注意乘以$\boldsymbol{e}_i$的还是$\rho_i$而不是$\rho_i + b_i$,也就是说$\boldsymbol{b}$仅仅参与分配过程而不参与MoE的前向计算,所以我们对$\boldsymbol{b}$或$\boldsymbol{\rho} + \boldsymbol{b}$的正负性都没有特殊要求。
手搓梯度 #
怎么训练$\boldsymbol{b}$呢?我们知道,$\boldsymbol{b}$的优化方向自然是促进负载均衡,为此按照上一篇的记号,我们先定义$\boldsymbol{f}=[f_1,f_2,\cdots,f_n]$:
\begin{equation}f_i = \left\{\begin{aligned}1/k, \quad i\in \mathop{\text{argtop}}\nolimits_k \boldsymbol{\rho}+\boldsymbol{b} \\
0, \quad i\not\in \mathop{\text{argtop}}\nolimits_k \boldsymbol{\rho}+\boldsymbol{b}\end{aligned}\right.\end{equation}
以及$\boldsymbol{F}=\mathbb{E}[\boldsymbol{f}]$,这里的$\boldsymbol{F}$自然就是在$\boldsymbol{b}$偏置下Expert当前的负载分布了。借着我们定义均匀分布为$\boldsymbol{Q}=(1/n,1/n,\cdots,1/n)$,那么负载均衡就相当于最小化
\begin{equation}\mathcal{L}_{\text{aux}} = \frac{1}{2}\Vert\boldsymbol{F} - \boldsymbol{Q}\Vert^2 = \frac{1}{2}\sum_{i=1}^n (F_i - 1/n)^2\end{equation}
这个目标是不可导的,但有了上一篇的经验,我们知道STE(Straight-Through Estimator)可以解决这个问题。STE的关键是找一个可导且跟$\boldsymbol{F}$具有同增减趋势的量作为$\boldsymbol{F}$的光滑近似,这里我们的优化参数只有$\boldsymbol{b}$,而它正好具有我们期望的性质(增大$b_i$,$i$被选中的概率就更高,那么$F_i$就更大),所以答案就呼之欲出了:
\begin{equation}\mathcal{L}_{\text{aux}} = \frac{1}{2}\Vert\boldsymbol{b} + \text{sg}[\boldsymbol{F}-\boldsymbol{b}] - \boldsymbol{Q}\Vert^2 = \frac{1}{2}\sum_{i=1}^n (b_i + \text{sg}[F_i - b_i] - 1/n)^2\end{equation}
它的梯度是
\begin{equation}\nabla_{\boldsymbol{b}}\mathcal{L}_{\text{aux}} = \frac{1}{2}\nabla_{\boldsymbol{b}}\Vert\boldsymbol{b} + \text{sg}[\boldsymbol{F}-\boldsymbol{b}] - \boldsymbol{Q}\Vert^2 = \boldsymbol{F} - \boldsymbol{Q}\end{equation}
所以用梯度下降(SGD)来更新$\boldsymbol{b}$就是
\begin{equation}\boldsymbol{b}\leftarrow \boldsymbol{b} - \alpha (\boldsymbol{F} - \boldsymbol{Q})\end{equation}
这里$\alpha$是$\boldsymbol{b}$的学习率。不过Loss-Free最终选择的更新规则略有不同,它选择的是符号梯度下降(SignSGD):
\begin{equation}\boldsymbol{b}\leftarrow \boldsymbol{b} - \alpha \mathop{\text{sign}}(\boldsymbol{F} - \boldsymbol{Q})\label{eq:aux-loss-free}\end{equation}
一脉相承 #
原论文在介绍Loss-Free时,并没有上述Aux Loss的推导过程,而是直接给出式$\eqref{eq:aux-loss-free}$的更新规则,给人的感觉是给$\boldsymbol{b}$“手搓”了梯度$\mathop{\text{sign}}(\boldsymbol{F} - \boldsymbol{Q})$,这也是它Loss-Free这个名字的来源。
然而,从本文给出的推导可以看出,更新规则$\eqref{eq:aux-loss-free}$也完全可以从Aux Loss视角得到,两者是一脉相承的。看起来Loss-Free最直接的好处是不用调Aux Loss权重了,但它实际上也有个学习率参数$\alpha$要调,尽管原论文已经帮我们搜好$\alpha=0.001$这个默认值,但不可否认这个超参数是存在的。
在笔者看来,Loss-Free的本质创新并不是没有Aux Loss,而是隔离了Aux Loss和LM Loss的优化参数,从而达到了负载均衡和模型能力两不误的效果。其中最关键一步,是留意到“一个偏置项足以达到负载均衡”这一事实,然后就让Aux Loss只优化新引入的偏置$\boldsymbol{b}$,而LM Loss则优化剩余参数,让Aux Loss对LM Loss的负面作用降到最低。
相比之下,常规的Aux Loss方案需要全体参数来促进负载均衡,而LM Loss优化的也是全体参数,两者的优化方向可能并不完全兼容,因此想找到一个最优的平衡点相对来说就更为困难。所以,Loss-Free基于“一个偏置项足以达到负载均衡”将两个Loss的优化参数隔离开来,是负载均衡问题的一个绝妙的解决办法。
相关细节 #
尽管Loss-Free已经足够简单明了,但是在使用的时候还要稍微注意一些细节。
首先,对于每个Batch的数据,我们应当先根据LM Loss来更新模型参数,然后再根据式$\eqref{eq:aux-loss-free}$来更新$\boldsymbol{b}$。这是因为$\boldsymbol{b}$的更新依赖于全体Token的统计信息$\boldsymbol{F}$,先更新$\boldsymbol{b}$再更新模型其余参数的话,原则上会有泄漏未来信息的风险。虽然直观看来就一个向量$\boldsymbol{b}$泄漏不了多少信息,但这个风险终归是存在的,因此要尽量去规避它。
其次,刚才我们说原论文已经调好$\alpha=0.001$,但这个结果可能跟原论文用Sigmoid作为Router $\boldsymbol{\rho}$激活函数的选择是绑定的。原因也不难想,经过Sigmoid后,每个$\rho_i$相对比较独立,并且都在$(0,1)$内,$\alpha=0.001$相当于说每一步的更新幅度约为千分之一,如果换Softmax、ReLU或者其他激活函数,那么就可能需要重调$\alpha$了。
针对这个问题,笔者建议的做法是结构Gate和Bias所用的激活函数,即
\begin{equation}\boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho} + \boldsymbol{b}} \rho_i \boldsymbol{e}_i\qquad\to\qquad \boldsymbol{y} = \sum_{i\in \mathop{\text{argtop}}_k \boldsymbol{\rho}^{(\sigma)} + \boldsymbol{b}} \rho_i^{(h)} \boldsymbol{e}_i\end{equation}
其中$\boldsymbol{\rho}^{(\sigma)} = \sigma(\boldsymbol{x}\boldsymbol{W}^{(R)}), \boldsymbol{\rho}^{(h)} = h(\boldsymbol{x}\boldsymbol{W}^{(R)})$,$\sigma(\cdot)$是Sigmoid函数,$h(\cdot)$是任意单调且值域非负的函数,说白了就是加上$\boldsymbol{b}$的是Sigmoid激活的打分,这样我们就可以复用$\alpha=0.001$,至于乘上Expert的Gate,我们可以用其他激活函数,只要它的单调性跟Sigmoid一致就行。
此外,由于更新规则$\eqref{eq:aux-loss-free}$加了$\text{sign}$函数,因此有可能训出绝对值大于1的$b_i$,整体绝对值还可能越来越大,这些都是正常的,对模型效果不会有影响。实际上$\boldsymbol{b}$有一个冗余的自由度,因为全体$b_i$都加上同一个常数后,$\mathop{\text{argtop}}_k \boldsymbol{\rho} + \boldsymbol{b}$的结果不变。这个额外的自由度我们可以用来做其他好玩的事情(且听下回分解)。
延伸思考 #
除了MoE的负载均衡之外,Loss-Free的思想还可以应用到很多类似问题,比如VQ-VQE的编码表坍缩(Codebook Collapse),就可以用同样思路解决,而且相比之前介绍的“旋转技巧”、“线性变换技巧”显得更自然和普适。事实上,本文开篇的评价“Loss-Free潜在的学术影响力可能远超其他工作”,正是基于Loss-Free的普适性考虑的。
抛开具体的应用背景,从数学上来看,Loss-Free的贡献可以理解为给出了用梯度下降来求解指派问题的方法。一个经典的线性指派问题可以表示为:
\begin{equation}\min_f \sum_{i=1}^n c_{i, f(i)}\end{equation}
其中$c_{i,j}$是给定的成本函数,$f$是$\{1,2,\cdots,n\}$到自身的双射。放到本文的背景下,$c_{i,j}$不就相当于$n$个Token、$n$个Expert的打分,所求$f$不就是一个负载均衡的分配方案?求解此类问题的一般想法是在满足约束条件的空间里搜索尽可能优的解,而Loss-Free则反过来,先构建一个最优但不一定满足约束条件的解:
\begin{equation}f(i) = \mathop{\text{argmin}}_j c_{i,j}\end{equation}
这个解在分数上肯定是最优的,但不一定满足双射的条件,这里不满足双射就等价于负载不均衡。于是我们引入偏置
\begin{equation}f(i) = \mathop{\text{argmin}}_j c_{i,j} + b_j\end{equation}
$b_j$初始化为零,然后根据式$\eqref{eq:aux-loss-free}$来更新,更新规则说白了就是哪个$j$出现出现次数多,那减少相应的$b_j$,反之增加,直到出现双射为止。
文章小结 #
本文介绍了MoE负载均衡问题的Loss-Free方法,它由DeepSeek提出,其核心在于通过引入一个简单的偏置项来实现负载均衡。本文进一步思考了它与Aux Loss的联系,以及它在类似数学问题上的应用潜力。
转载到请包括本文地址:https://kexue.fm/archives/10757
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Mar. 05, 2025). 《MoE环游记:3、换个思路来分配 》[Blog post]. Retrieved from https://kexue.fm/archives/10757
@online{kexuefm-10757,
title={MoE环游记:3、换个思路来分配},
author={苏剑林},
year={2025},
month={Mar},
url={\url{https://kexue.fm/archives/10757}},
}
March 5th, 2025
想过做vq-vae,不过我不会视觉hh,能不能等苏神搓一个
March 5th, 2025
也许是我没有理解,但之前的方法Aux loss的话会使得$\rho$本身是均衡的吧
这篇文章介绍的方法,似乎会使得$\rho+b$是均衡的,$\rho$本身仍然可能会不均衡
但计算时依然会选择$\rho$做为权重,因此虽然专家可能被选择了,但权重很小,这样能起到均衡的意义吗?也许训练的时候每个专家都“干活”了,但可能并没有起到“贡献”?
“因此虽然专家可能被选择了,但权重很小”, 但是还有LLM的cross entropy loss,要想ce loss 小,你说的这件事就不应该普遍出现。不然就是没训练好
反正训练时就已经选出了这个Expert了,如果模型觉得有必要,那么自然就会将相应的$\rho_i$调大,否则就是没必要的。这点选择相信优化器就行,不用太担心。
March 5th, 2025
苏神你好,我对这个偏置项的物理意义的理解是,它是由这个负载均衡问题进行多槽在线匹配 (multi slot online matching)建模后得到的0/1整数规划问题的对偶变量。我最近做了一个follow的工作(https://arxiv.org/pdf/2502.15451,目前只有算法分析还没做完实验),欢迎交流。
欢迎作者莅临!之前关注到你这篇paper了。我有一个顾虑是,你这篇paper主要强调的点是能更快达到均衡,我之前也尝试在单个batch内执行多次$\boldsymbol{b}$的更新,以达到类似的效果(即尽快达到均衡),但我个人测试发现过早达到均衡并不一定有利于模型最终表现,此外就是整数规划的解法,拓展性似乎不如SGD好。
March 5th, 2025
Loss-Free的本质创新并不是没有Aux Loss,而是隔离了Aux Loss和LM Loss的优化参数,从而达到了负载均衡和模型能力两不误的效果。
---
请教一下苏神,这里的“隔离”应该如何理解?为什么加了b之后就隔离了两个loss的优化参数呢?
是不是因为训练b的梯度更新抵消了F-Q的影响,从而降低了因为负载均衡优化对LM loss优化的影响。可以这么理解吗?
平衡信息F只出现在b的更新公式中。所以平衡不平衡这件事,只会影响b这个参数,不会影响其他参数。其他参数更新的梯度只会从最后的cross-entropy传过来。
但是原始的aux loss,cross-entropy和balance loss是加在一起的。这个值会同时影响所有参数。
当然这里只说直接影响。间接地话,F也是会对LLM其他参数有影响的,你选的专家不一样,当然对每个专家对训练过程是不一样的。但是直接的影响是隔离里
当然反过来,可以这么做的原因是他确实是可以隔离的,因为F是不是平衡,只通过b就可以解决了
现在是人为规定,Aux Loss只用来更新$\boldsymbol{b}$,LM Loss只用来更新剩余参数,这不就是很明显的隔离了嘛。常规的Aux Loss没有这个约定,两个Loss都更新所有参数。
March 5th, 2025
“首先,对于每个Batch的数据,我们应当先根据LM Loss来更新模型参数,然后再根据式(7)
来更新b” 感觉没有说清楚。
原文是说:we update the biases based on the historical balance condition, since utilizing the load information of the current sequence will break the causal constraint of language modeling, leading to leakage of the information of future tokens
我的理解是当前step的llm权重更新用当前batch的数据来。但是bias的更新根据上一个batch的数据来?
不是上一个batch,是所有i-1个batch,除了当前batch
原论文这段话我也不确定它要表达的准确意思是什么,但从原论文的Algorithm 1来看,应该是按照我说的理解。
另外阁下是否可以考虑换个昵称,这个昵称似乎对交流并不那么友好。
March 5th, 2025
感觉aux loss是让语言模型本身的训练掺和到balance到训练来了,具体体现就是F/P的梯度直接指导了LLM的训练。这样是不是可能导致为了平衡而平衡。
比如,之所以某一个token给某个专家较大的权重,根本原因还是这个专家活干得好。那为了平衡而平衡的话,那这个专家有可能为了降低自己的工作压力 故意降低自己的能力。
aux loss free,就是说专家该怎么干就怎么干,一切以干好活为最高原则。平衡的事我另外搞一个机制强行保证平衡。
“可能为了降低自己的工作压力,故意降低自己的能力”,这个类比很形象!受教了