从Hessian近似看自适应学习率优化器
By 苏剑林 | 2024-11-29 | 1536位读者 |这几天在重温去年的Meta的一篇论文《A Theory on Adam Instability in Large-Scale Machine Learning》,里边给出了看待Adam等自适应学习率优化器的新视角:它指出梯度平方的滑动平均某种程度上近似于在估计Hessian矩阵的平方,从而Adam、RMSprop等优化器实际上近似于二阶的Newton法。
这个角度颇为新颖,而且表面上跟以往的一些Hessian近似有明显的差异,因此值得我们去学习和思考一番。
牛顿下降 #
设损失函数为$\mathcal{L}(\boldsymbol{\theta})$,其中待优化参数为$\boldsymbol{\theta}$,我们的优化目标是
\begin{equation}\boldsymbol{\theta}^* = \mathop{\text{argmin}}_{\boldsymbol{\theta}} \mathcal{L}(\boldsymbol{\theta})\label{eq:loss}\end{equation}
假设$\boldsymbol{\theta}$的当前值是$\boldsymbol{\theta}_t$,Newton法通过将损失函数展开到二阶来寻求$\boldsymbol{\theta}_{t+1}$:
\begin{equation}\mathcal{L}(\boldsymbol{\theta})\approx \mathcal{L}(\boldsymbol{\theta}_t) + \boldsymbol{g}_t^{\top}(\boldsymbol{\theta} - \boldsymbol{\theta}_t) + \frac{1}{2}(\boldsymbol{\theta} - \boldsymbol{\theta}_t)^{\top}\boldsymbol{\mathcal{H}}_t(\boldsymbol{\theta} - \boldsymbol{\theta}_t)\end{equation}
其中$\boldsymbol{g}_t = \nabla_{\boldsymbol{\theta}_t}\mathcal{L}(\boldsymbol{\theta}_t)$是梯度、 $\boldsymbol{\mathcal{H}}_t=\nabla_{\boldsymbol{\theta}_t}^2\mathcal{L}(\boldsymbol{\theta}_t)$是Hessian矩阵。假定Hessian矩阵的正定性,那么上式右端就存在唯一的最小值$\boldsymbol{\theta}_t - \boldsymbol{\mathcal{H}}_t^{-1}\boldsymbol{g}_t$,Newton法将它作为下一步的$\boldsymbol{\theta}_{t+1}$:
\begin{equation}\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t-\boldsymbol{\mathcal{H}}_t^{-1}\boldsymbol{g}_t = \boldsymbol{\theta}_t - (\nabla_{\boldsymbol{\theta}_t}^2\mathcal{L})^{-1} \nabla_{\boldsymbol{\theta}_t}\mathcal{L}\end{equation}
注意上式没有额外的学习率参数,因此Newton法天生就是自适应学习率算法。当然,由于Hessian矩阵的复杂度正比于参数量的平方,所以在深度学习中完整的Newton法基本上只有理论价值了,真要想应用Newton法,要对Hessian矩阵做比较大的简化假设,比如对角矩阵或者低秩矩阵。
在Newton法视角下,SGD就是假设了$\boldsymbol{\mathcal{H}}_t=\eta_t^{-1}\boldsymbol{I}$,而Adam则是假设$\boldsymbol{\mathcal{H}}_t=\eta_t^{-1}\text{diag}(\sqrt{\hat{\boldsymbol{v}}_t} + \epsilon)$,其中
\begin{equation}\text{Adam}:=\left\{\begin{aligned}
&\boldsymbol{m}_t = \beta_1 \boldsymbol{m}_{t-1} + \left(1 - \beta_1\right) \boldsymbol{g}_t\\
&\boldsymbol{v}_t = \beta_2 \boldsymbol{v}_{t-1} + \left(1 - \beta_2\right) \boldsymbol{g}_t\otimes\boldsymbol{g}_t\\
&\hat{\boldsymbol{m}}_t = \boldsymbol{m}_t\left/\left(1 - \beta_1^t\right)\right.\\
&\hat{\boldsymbol{v}}_t = \boldsymbol{v}_t\left/\left(1 - \beta_2^t\right)\right.\\
&\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta_t \hat{\boldsymbol{m}}_t\left/\left(\sqrt{\hat{\boldsymbol{v}}_t} + \epsilon\right)\right.
\end{aligned}\right.\end{equation}
接下来我们想要证明的是,$\eta_t^{-1}\text{diag}(\sqrt{\hat{\boldsymbol{v}}_t})$是$\boldsymbol{\mathcal{H}}_t$的一个更好的近似。
梯度近似 #
证明的要点是利用梯度的一阶近似:
\begin{equation}\boldsymbol{g}_{\boldsymbol{\theta}} \approx \boldsymbol{g}_{\boldsymbol{\theta}^*} + \boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}^*}(\boldsymbol{\theta} - \boldsymbol{\theta}^*)\end{equation}
其中$\boldsymbol{g}_{\boldsymbol{\theta}^*}$和$\boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}^*}$表明我们在$\boldsymbol{\theta}=\boldsymbol{\theta}^*$处展开,这里的$\boldsymbol{\theta}^*$就是我们要寻找的目标$\eqref{eq:loss}$,在此处模型的梯度为零,从而上式可以简化成
\begin{equation}\boldsymbol{g}_{\boldsymbol{\theta}} \approx \boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}^*}(\boldsymbol{\theta} - \boldsymbol{\theta}^*)\end{equation}
于是
\begin{equation}\boldsymbol{g}_{\boldsymbol{\theta}}\boldsymbol{g}_{\boldsymbol{\theta}}^{\top} \approx \boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}^*}(\boldsymbol{\theta} - \boldsymbol{\theta}^*)(\boldsymbol{\theta} - \boldsymbol{\theta}^*)^{\top}\boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}^*}^{\top}\end{equation}
假设训练进入“正轨”后,模型将会长期处于围绕着$\boldsymbol{\theta}^*$“打转”、缓慢且螺旋地收敛的状态,那么一定程度上我们可以将$\boldsymbol{\theta} - \boldsymbol{\theta}^*$视为正态分布$\mathcal{N}(\boldsymbol{0},\sigma^2\boldsymbol{I})$的随机变量,那么
\begin{equation}\mathbb{E}[\boldsymbol{g}_{\boldsymbol{\theta}}\boldsymbol{g}_{\boldsymbol{\theta}}^{\top}] \approx \boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}^*}\mathbb{E}[(\boldsymbol{\theta} - \boldsymbol{\theta}^*)(\boldsymbol{\theta} - \boldsymbol{\theta}^*)^{\top}]\boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}^*}^{\top} = \sigma^2\boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}^*}\boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}^*}^{\top}\label{eq:hessian-2}\end{equation}
假设Hessian矩阵是对角阵,那么上式我们可以只保留对角线元素
\begin{equation}\text{diag}(\mathbb{E}[\boldsymbol{g}_{\boldsymbol{\theta}}\otimes\boldsymbol{g}_{\boldsymbol{\theta}}]) \approx \sigma^2\boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}^*}^2\quad\Rightarrow\quad \boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}^*} = \frac{1}{\sigma}\text{diag}(\sqrt{\mathbb{E}[\boldsymbol{g}_{\boldsymbol{\theta}}\otimes\boldsymbol{g}_{\boldsymbol{\theta}}]})\end{equation}
是不是有点相似了?Adam的$\hat{\boldsymbol{v}}_t$是对梯度平方的滑动平均,它可以看作在近似$\mathbb{E}[\boldsymbol{g}_{\boldsymbol{\theta}}\otimes\boldsymbol{g}_{\boldsymbol{\theta}}]$。最后我们再假设$\boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}_t}$相比$\boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}^*}$变化不大,就可以得到$\eta_t^{-1}\text{diag}(\sqrt{\hat{\boldsymbol{v}}_t})$是$\boldsymbol{\mathcal{H}}_t$近似的结论。
这也可以解释为什么Adam的$\beta_2$通常都大于$\beta_1$。为了更准确地估计Hessian,$\hat{\boldsymbol{v}}_t$的滑动平均应该尽可能“长期”(接近均匀平均),所以$\beta_2$应该要很接近于1;而动量$\hat{\boldsymbol{m}}_t$是梯度的滑动平均,如果梯度的平均过于长期的话,那么结果将会接近$\boldsymbol{g}_{\boldsymbol{\theta}^*}=\boldsymbol{0}$,这反而不好,因此动量的滑动平均要更局部些。
相关工作 #
对于比较了解Hessian矩阵理论的读者,看到上述结论后的第一反应也许不是熟悉而是疑惑,这是因为Hessian矩阵的一个经典近似是Jacobi矩阵(类似梯度)的外积,而这里的Hessian近似则是梯度外积的平方根,两者差了个根号。
具体来说,我们以平方误差损失为例
\begin{equation}\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{2}\mathbb{E}_{(\boldsymbol{x},\boldsymbol{y})\sim\mathcal{D}}[\Vert \boldsymbol{y} - \boldsymbol{f}_{\boldsymbol{\theta}}(\boldsymbol{x})\Vert^2]\label{eq:loss-2}\end{equation}
我们在$\boldsymbol{\theta}_t$处展开,有$\boldsymbol{f}_{\boldsymbol{\theta}}(\boldsymbol{x})\approx \boldsymbol{f}_{\boldsymbol{\theta}_t}(\boldsymbol{x}) + \boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}_t}^{\top} (\boldsymbol{\theta} - \boldsymbol{\theta}_t)$,其中$\boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}_t}=\nabla_{\boldsymbol{\theta}_t} \boldsymbol{f}_{\boldsymbol{\theta}_t}(\boldsymbol{x})$是Jacobi矩阵,代入上式得到
\begin{equation}\mathcal{L}(\boldsymbol{\theta}) \approx \frac{1}{2}\mathbb{E}_{(\boldsymbol{x},\boldsymbol{y})\sim\mathcal{D}}[\Vert \boldsymbol{y} - \boldsymbol{f}_{\boldsymbol{\theta}_t}(\boldsymbol{x}) - \boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}_t}^{\top} (\boldsymbol{\theta} - \boldsymbol{\theta}_t)\Vert^2]\end{equation}
经过简化后的上式只是关于$\boldsymbol{\theta}$的二次型,因此可以直接写出它的Hessian矩阵,结果是
\begin{equation}\boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}_t} \approx \mathbb{E}_{(\boldsymbol{x},\boldsymbol{y})\sim\mathcal{D}}[\boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}_t}\boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}_t}^{\top}]\end{equation}
这就是基于Jacobi矩阵外积的Hessian近似,它是“Gauss–Newton法”的理论基础。当然,$\boldsymbol{\mathcal{J}}$还不是$\boldsymbol{g}$,我们要试图将结果跟$\mathcal{g}$联系起来。对式$\eqref{eq:loss-2}$直接求导得
\begin{equation}\boldsymbol{g}_{\boldsymbol{\theta}} = \mathbb{E}_{(\boldsymbol{x},\boldsymbol{y})\sim\mathcal{D}}[\boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}}(\boldsymbol{y} - \boldsymbol{f}_{\boldsymbol{\theta}}(\boldsymbol{x}))]\end{equation}
于是
\begin{equation}\begin{aligned}
\boldsymbol{g}_{\boldsymbol{\theta}} \boldsymbol{g}_{\boldsymbol{\theta}}^{\top} =&\, \big(\mathbb{E}_{(\boldsymbol{x},\boldsymbol{y})\sim\mathcal{D}}[\boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}}(\boldsymbol{y} - \boldsymbol{f}_{\boldsymbol{\theta}}(\boldsymbol{x}))]\big)\big(\mathbb{E}_{(\boldsymbol{x},\boldsymbol{y})\sim\mathcal{D}}[\boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}}(\boldsymbol{y} - \boldsymbol{f}_{\boldsymbol{\theta}}(\boldsymbol{x}))]\big)^{\top} \\[5pt]
=&\, \big(\mathbb{E}_{(\boldsymbol{x},\boldsymbol{y})\sim\mathcal{D}}[\boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}}(\boldsymbol{y} - \boldsymbol{f}_{\boldsymbol{\theta}}(\boldsymbol{x}))]\big)\big(\mathbb{E}_{(\boldsymbol{x},\boldsymbol{y})\sim\mathcal{D}}[(\boldsymbol{y} - \boldsymbol{f}_{\boldsymbol{\theta}}(\boldsymbol{x}))^{\top}\boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}}^{\top}]\big) \\[5pt]
\approx&\, \mathbb{E}_{(\boldsymbol{x},\boldsymbol{y})\sim\mathcal{D}}\big[\boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}}(\boldsymbol{y} - \boldsymbol{f}_{\boldsymbol{\theta}}(\boldsymbol{x}))(\boldsymbol{y} - \boldsymbol{f}_{\boldsymbol{\theta}}(\boldsymbol{x}))^{\top}\boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}}^{\top}\big] \\[5pt]
\approx&\, \mathbb{E}_{(\boldsymbol{x},\boldsymbol{y})\sim\mathcal{D}}\Big[\boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}}\mathbb{E}_{(\boldsymbol{x},\boldsymbol{y})\sim\mathcal{D}}\big[(\boldsymbol{y} - \boldsymbol{f}_{\boldsymbol{\theta}}(\boldsymbol{x}))(\boldsymbol{y} - \boldsymbol{f}_{\boldsymbol{\theta}}(\boldsymbol{x}))^{\top}\big]\boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}}^{\top}\Big] \\[5pt]
\end{aligned}\end{equation}
这里两个约等号其实没有太多道理,可以勉强看成是平均场近似,而$\boldsymbol{y} - \boldsymbol{f}_{\boldsymbol{\theta}}(\boldsymbol{x})$是回归预测的残差,我们通常假设它服从$\mathcal{N}(\boldsymbol{0},\sigma^2\boldsymbol{I})$,因此有
\begin{equation}\boldsymbol{g}_{\boldsymbol{\theta}} \boldsymbol{g}_{\boldsymbol{\theta}}^{\top} \approx \sigma^2\mathbb{E}_{(\boldsymbol{x},\boldsymbol{y})\sim\mathcal{D}}\big[\boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}}\boldsymbol{\mathcal{J}}_{\boldsymbol{\theta}}^{\top}\big] \approx \sigma^2 \boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}_t}\label{eq:hessian-t}\end{equation}
这就揭示了$\boldsymbol{\mathcal{H}}_{\boldsymbol{\theta}_t}$与$\boldsymbol{g}_{\boldsymbol{\theta}} \boldsymbol{g}_{\boldsymbol{\theta}}^{\top}$的联系。对比上一节的式$\eqref{eq:hessian-2}$,可以发现表面上刚好差了个平方。
看推导过程,两个结果似乎都没明显错误,那怎么理解这种不一致性呢?我们可以这样理解:式$\eqref{eq:hessian-t}$给出的是$t$时刻的Hessian近似,属于“瞬时近似”,而式$\eqref{eq:hessian-2}$则是时间步的“长期平均”结果,长期的平均作用抵销了一部分强度(但理论上也会使得估计更准确),从而需要多开一个平方根。
类似的效应也出现在《生成扩散模型漫谈(五):一般框架之SDE篇》介绍的SDE中,SDE的噪声项强度需要比非噪声项高半阶,同样是因为噪声项在长期平均之下会抵消,所以噪声需要更高阶才能在最终的结果中体现出噪声的作用。
文章小结 #
本文介绍了从Newton法和Hessian近似看待Adam等自适应学习率优化器的一个视角,并讨论了Hessian近似的相关结果。
转载到请包括本文地址:https://kexue.fm/archives/10588
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Nov. 29, 2024). 《从Hessian近似看自适应学习率优化器 》[Blog post]. Retrieved from https://kexue.fm/archives/10588
@online{kexuefm-10588,
title={从Hessian近似看自适应学习率优化器},
author={苏剑林},
year={2024},
month={Nov},
url={\url{https://kexue.fm/archives/10588}},
}
最近评论