在之前的文章《当Batch Size增大时,学习率该如何随之变化?》《Adam的epsilon如何影响学习率的Scaling Law?》中,我们从理论上讨论了学习率随Batch Size的变化规律,其中比较经典的部分是由OpenAI提出的展开到二阶的分析。然而,当我们要处理非SGD优化器时,这套分析方法的计算过程往往会相当复杂,有种无从下手的感觉。

接下来的几篇文章,笔者将重新整理和思考上述文章中的相关细节,尝试简化其中的一些推导步骤,给出一条更通用、更轻盈的推导路径,并且探讨推广到Muon优化器的可能性。

方法大意 #

首先回顾一下之前的分析方法。在《当Batch Size增大时,学习率该如何随之变化?》中,我们介绍了多种分析学习率与Batch Size规律的思路,其中OpenAI在《An Empirical Model of Large-Batch Training》提出的二阶近似分析占了主要篇幅,本文也是沿用同样的思路。

接着需要引入一些记号。设损失函数为$\mathcal{L}(\boldsymbol{w})$,$\boldsymbol{w}\in\mathbb{R}^N$是参数向量,$\boldsymbol{g}$是它的梯度。注意理想的损失函数是在全体训练样本上算的期望,但实际我们只能采样一个Batch来算,这导致梯度也带有随机性,我们将单个样本的梯度记为$\tilde{\boldsymbol{g}}$,它的均值就是$\boldsymbol{g}$,而协方差矩阵记为$\boldsymbol{\Sigma}$;当Batch Size为$B$时,梯度记为$\tilde{\boldsymbol{g}}_B$,它的均值还是$\boldsymbol{g}$,但协方差矩阵变为$\boldsymbol{\Sigma}/B$。

进一步地,设当前学习率为$\eta$,更新向量为$\tilde{\boldsymbol{\varphi}}_B$,那么更新后的损失函数将是
\begin{equation}\begin{aligned}
\mathcal{L}(\boldsymbol{w} - \eta\tilde{\boldsymbol{\varphi}}_B) \approx&\, \mathcal{L}(\boldsymbol{w}) - \eta \tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2\tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{H}\tilde{\boldsymbol{\varphi}}_B \\
=&\, \mathcal{L}(\boldsymbol{w}) - \eta \tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2\newcommand{tr}{\mathop{\text{tr}}}\tr(\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}\boldsymbol{H})
\end{aligned}\end{equation}
右侧我们泰勒展开到了二阶,$\boldsymbol{H}$是Hessian矩阵,$\tr$是矩阵的迹,第二个等号用到了$\tr(\boldsymbol{A}\boldsymbol{B})=\tr(\boldsymbol{B}\boldsymbol{A})$这个恒等式。为了得到一个确定性的结果,我们对两边求期望:
\begin{equation}\mathbb{E}[\mathcal{L}(\boldsymbol{w} - \eta\tilde{\boldsymbol{\varphi}}_B)] \approx \mathcal{L}(\boldsymbol{w}) - \eta\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]^{\top}\boldsymbol{g} + \frac{1}{2}\eta^2 \tr(\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]\boldsymbol{H})\end{equation}
我们把右端看成是关于$\eta$的二次函数,并假设二次项系数是正的(更强的假设是$\boldsymbol{H}$矩阵是正定的),那么可以得到最小值点
\begin{equation}\eta^* \approx \frac{\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]^{\top}\boldsymbol{g}}{\tr(\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]\boldsymbol{H})}\end{equation}
这便是平均来说让损失函数下降最快的学习率,是学习率的理论最优解。我们要做的事情,就是针对具体的$\tilde{\boldsymbol{\varphi}}_B$算出$\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$和$\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]$,然后从上式析出它与Batch Size(即$B$)的关系。

热身练习 #

作为第一个例子,我们自然是考虑最简单的SGD,此时有$\tilde{\boldsymbol{\varphi}}_B=\tilde{\boldsymbol{g}}_B$,那么简单可得$\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]=\boldsymbol{g}$以及$\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]=\boldsymbol{g}\boldsymbol{g}^{\top} + \boldsymbol{\Sigma}/B$,于是有
\begin{equation}\eta^* \approx \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\tr((\boldsymbol{g}\boldsymbol{g}^{\top} + \boldsymbol{\Sigma}/B)\boldsymbol{H})} = \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g} + \tr(\boldsymbol{\Sigma}\boldsymbol{H})/B} = \frac{\eta_{\max}}{1 + \mathcal{B}_{\text{noise}}/B}\label{eq:eta-sgd}\end{equation}
其中
\begin{equation}\eta_{\max} = \frac{\boldsymbol{g}^{\top}\boldsymbol{g}}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}},\qquad\mathcal{B}_{\text{noise}} = \frac{\tr(\boldsymbol{\Sigma}\boldsymbol{H})}{\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}\end{equation}

对于结果$\eqref{eq:eta-sgd}$,我们可以有多种解读方式。首先,它是一个单调递增但有上界的函数,上界为$\eta_{\max}$,这表明学习率不能无限增加,相比简单的线性律或者平方根律,它更符合我们的直觉认知;当$B \ll \mathcal{B}_{\text{noise}}$时,我们有
\begin{equation}\eta^* \approx \frac{\eta_{\max}}{1 + \mathcal{B}_{\text{noise}}/B} \approx \frac{\eta_{\max}}{\mathcal{B}_{\text{noise}}/B} = \eta_{\max} B / \mathcal{B}_{\text{noise}}\end{equation}
这表明在Batch Size比较小时,SGD的学习率与Batch Size确实呈线性关系,同时也暗示了$\mathcal{B}_{\text{noise}}$是一个关键统计量。不过$\mathcal{B}_{\text{noise}}$的定义依赖于Hessian矩阵$\boldsymbol{H}$,这在LLM中是几乎不可能精确计算的,所以实践中我们通常假设它是单位阵(的若干倍),得到一个简化的形式
\begin{equation}\mathcal{B}_{\text{simple}} = \frac{\tr(\boldsymbol{\Sigma})}{\boldsymbol{g}^{\top}\boldsymbol{g}}\end{equation}
该结果具有噪音强度($\tr(\boldsymbol{\Sigma})$)除以信号强度($\boldsymbol{g}^{\top}\boldsymbol{g}$)的形式,它其实就是信噪比的倒数,它表明信噪比越小,那么就需要更大的Batch Size才能用上相同的$\eta_{\max}$,这也跟我们的直觉认知相符。$\tr(\boldsymbol{\Sigma})$只依赖于$\boldsymbol{\Sigma}$的对角线元素,这表明我们只需要将每个参数独立地估计均值和方差,这在实践上是可行的。

数据效率 #

除了学习率与Batch Size的直接关系外,笔者认为由此衍生出来的关于训练数据量和训练步数的渐近关系,也是必须要学习的精彩部分。特别地,这个结论似乎比学习率的关系式$\eqref{eq:eta-sgd}$更为通用,因为后面我们将会看到,SignSGD也能得到同样形式的结论,但它的学习率规律并不是式$\eqref{eq:eta-sgd}$。

原论文对这部分的讨论比较复杂,下面的推导是经过笔者简化的。具体来说,我们将$\eta^*$代回到$\mathcal{L}(\boldsymbol{w} - \eta\tilde{\boldsymbol{g}}_B)$,将得到
\begin{equation}\overline{\Delta\mathcal{L}} = \mathcal{L}(\boldsymbol{w}) - \mathbb{E}[\mathcal{L}(\boldsymbol{w} - \eta^*\tilde{\boldsymbol{g}}_B)] \approx \frac{\Delta\mathcal{L}_{\max}}{1 + \mathcal{B}_{\text{noise}}/B}\end{equation}
其中$\Delta\mathcal{L}_{\max} = \frac{(\boldsymbol{g}^{\top}\boldsymbol{g})^2}{2\boldsymbol{g}^{\top}\boldsymbol{H}\boldsymbol{g}}$。怎么理解这个结果呢?首先,它是关于$B$的单调递增函数,当$B\to\infty$时等于$\Delta\mathcal{L}_{\max}$,换言之如果我们能开无穷大的Batch Size,那么每一步的损失下降量是$\Delta\mathcal{L}_{\max}$,此时所需的训练步数最少,记为$S_{\min}$。

如果Batch Size是有限值,那么每一步的损失下降量平均来说只有$\overline{\Delta\mathcal{L}}$,这意味着平均而言我们要花$1 + \mathcal{B}_{\text{noise}}/B$步,才能达到无穷大Batch Size时1步的下降量,于是为了达到相同的损失,我们要训练$S = (1 + \mathcal{B}_{\text{noise}}/B)S_{\min}$步。

由于Batch Size为$B$,所以很容易得出训练消耗的数据总量为$E = BS = (B + \mathcal{B}_{\text{noise}})S_{\min}$。从这个结果可以看出,增大Batch Size后,想要达到相同的效果,我们还需要适当增加数据量$E$;当$B\to 0$时,所需要的数据量最少,为$E_{\min} = \mathcal{B}_{\text{noise}}S_{\min}$。利用这些记号,我们可以写出
\begin{equation}\left(\frac{S}{S_{\min}} - 1\right)\left(\frac{E}{E_{\min}} - 1\right) = 1\end{equation}
这便是训练数据量和训练步数的经典关系式,它有两个参数$S_{\min},E_{\min}$,我们也可以通过实验搜索多个$(S,E)$来拟合上式,从而估计$S_{\min},E_{\min}$,进而可以估算$\mathcal{B}_{\text{noise}} = E_{\min} / S_{\min}$。更多分析细节请看回之前的文章《当Batch Size增大时,学习率该如何随之变化?》或OpenAI的原论文《An Empirical Model of Large-Batch Training》

困难分析 #

前面写了那么多,都还停留在SGD中。从计算角度看,SGD是平凡的,真正复杂的是$\tilde{\boldsymbol{\varphi}}_B$非线性地依赖于$\tilde{\boldsymbol{g}}_B$的情形,比如SignSGD对应于$\newcommand{sign}{\mathop{\text{sign}}}\tilde{\boldsymbol{\varphi}}_B=\sign(\tilde{\boldsymbol{g}}_B)$,在理论分析中它经常用作Adam的近似,更准确的近似则是考虑了$\epsilon$的SoftSignSGD,我们在《Adam的epsilon如何影响学习率的Scaling Law?》尝试过分析它。

在这些非线性场景下,$\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$和$\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]$的计算往往是相当困难的,即便我们将$\tilde{\boldsymbol{g}}_B$的分布假设为简单的正态分布也是如此(注意,在SGD的分析中,我们并不需要对它的分布形式做正态假设)。比如,在之前的文章中,对于$\tilde{\boldsymbol{\varphi}}_B=\sign(\tilde{\boldsymbol{g}}_B)$的SignSGD,为了计算$\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$,我们经历了如下步骤:

1、假设$\tilde{\boldsymbol{g}}_B$的分量相互独立,问题简化为单个分量$\tilde{\varphi}_B=\sign(\tilde{g}_B)$(没有加粗)的期望;

2、假设$\tilde{g}_B$(此时是一个标量)服从正态分布,那么就可以算出$\mathbb{E}[\tilde{\varphi}_B]$,答案要用$\newcommand{erf}{\mathop{\text{erf}}}\erf$函数来表示;

3、将$\erf$函数用$x/\sqrt{x^2+c}$形式的函数近似,简化结果。

也就是说,我们要经过一堆弯弯绕绕的步骤,才勉强算出一个可以分析下去的近似结果(这个过程首次出现在Tencent的论文《Surge Phenomenon in Optimal Learning Rate and Batch Size Scaling》),而且这已经算是简单的了,因为如果是SoftSignSGD,则更加复杂:

1、假设$\tilde{\boldsymbol{g}}_B$的分量相互独立,问题简化为单个分量$\tilde{\varphi}_B=\newcommand{softsign}{\mathop{\text{softsign}}}\softsign(\tilde{g}_B, \epsilon)$的期望;

2、将$\softsign$函数用分段线性函数近似,这样才能算出下面的积分;

3、假设$\tilde{g}_B$服从正态分布,结合第2步的近似,可以算出$\mathbb{E}[\tilde{\varphi}_B]$,答案是包含$\erf$的复杂函数;

4、将复杂函数用$x/\sqrt{x^2+c}$形式的函数近似,简化结果。

事情还没完。费那么大劲,加那么多假设,我们才堪堪算出$\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B]$,接着还要算$\mathbb{E}[\tilde{\boldsymbol{\varphi}}_B\tilde{\boldsymbol{\varphi}}_B^{\top}]$,这往往更加复杂(SignSGD是个例外,因为$\sign(x)^2$一定是1,所以反而简单了)。然而,计算的复杂性还是次要的,主要是这些步骤看上去没有任何能推广的规律,似乎只能具体问题具体分析的样子,这就让人觉得非常心累。

未完待续 #

为了避免文章过长,本文就先到这里了,主要先简单回顾一下现有的分析结果和计算困难。在下一篇文章中,笔者将会介绍自己为了降低推导过程中的心智负担所做的一些尝试。

转载到请包括本文地址:https://kexue.fm/archives/11260

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Sep. 01, 2025). 《重新思考学习率与Batch Size(一):现状 》[Blog post]. Retrieved from https://kexue.fm/archives/11260

@online{kexuefm-11260,
        title={重新思考学习率与Batch Size(一):现状},
        author={苏剑林},
        year={2025},
        month={Sep},
        url={\url{https://kexue.fm/archives/11260}},
}