指数梯度下降 + 元学习 = 自适应学习率
By 苏剑林 | 2022-03-03 | 29375位读者 |前两天刷到了Google的一篇论文《Step-size Adaptation Using Exponentiated Gradient Updates》,在其中学到了一些新的概念,所以在此记录分享一下。主要的内容有两个,一是非负优化的指数梯度下降,二是基于元学习思想的学习率调整算法,两者都颇有意思,有兴趣的读者也可以了解一下。
指数梯度下降 #
梯度下降大家可能听说得多了,指的是对于无约束函数$\mathcal{L}(\boldsymbol{\theta})$的最小化,我们用如下格式进行更新:
\begin{equation}\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta}_t)\end{equation}
其中$\eta$是学习率。然而很多任务并非总是无约束的,对于最简单的非负约束,我们可以改为如下格式更新:
\begin{equation}\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t \odot \exp\left(- \eta\nabla_{\boldsymbol{\theta}}\mathcal{L}(\boldsymbol{\theta}_t)\right)\label{eq:egd}\end{equation}
这里的$\odot$是逐位对应相乘(Hadamard积)。容易看到,只要初始化的$\boldsymbol{\theta}_0$是非负的,那么在整个更新过程中$\boldsymbol{\theta}_t$都会保持非负,这就是用于非负约束优化的“指数梯度下降”。
怎么理解这个“指数梯度下降”呢?也不难,转化为无约束的情形进行推导就行了。如果$\boldsymbol{\theta}$是非负的,那么$\boldsymbol{\varphi}=\log\boldsymbol{\theta}$就是可正可负的了,因此可以设$\boldsymbol{\theta}=e^{\boldsymbol{\varphi}}$转化为关于$\boldsymbol{\varphi}$的无约束优化问题,继而就可以用梯度下降解决:
\begin{equation}\boldsymbol{\varphi}_{t+1} = \boldsymbol{\varphi}_t - \eta\nabla_{\boldsymbol{\varphi}}\mathcal{L}(e^{\boldsymbol{\varphi}_t}) = \boldsymbol{\varphi}_t - \eta e^{\boldsymbol{\varphi}_t}\odot\nabla_{e^{\boldsymbol{\varphi}}}\mathcal{L}(e^{\boldsymbol{\varphi}_t})\end{equation}
我们认为梯度的$e^{\boldsymbol{\varphi}_t}\odot$这部分只起到了调节学习率的作用,所以它不是本质重要的,我们将它舍去得到
\begin{equation}\boldsymbol{\varphi}_{t+1} = \boldsymbol{\varphi}_t - \eta \nabla_{e^{\boldsymbol{\varphi}}}\mathcal{L}(e^{\boldsymbol{\varphi}_t})\end{equation}
两边取指数得
\begin{equation}e^{\boldsymbol{\varphi}_{t+1}} = e^{\boldsymbol{\varphi}_t}\odot\exp\left( - \eta \nabla_{e^{\boldsymbol{\varphi}}}\mathcal{L}(e^{\boldsymbol{\varphi}_t})\right)\end{equation}
换回$\boldsymbol{\theta}=e^{\boldsymbol{\varphi}}$就得到式$\eqref{eq:egd}$。
元学习调学习率 #
对于元学习(Meta Learning),可能多数读者都跟笔者一样听得多,但几乎没接触过。简单来说,普通机器学习跟元学习的关系,就像是数学中“函数”跟“泛函”的关系,泛函是“函数的函数”,元学习则是“学习如何学习(Learning How to Learn)”,也就是说它是关于“学习”本身的方法论,比如接下来要介绍的,就是“用梯度下降去调整梯度下降”。
我们从一般的梯度下降出发,记目标函数$\mathcal{L}$的梯度为$\boldsymbol{g}$,那么更新公式为
\begin{equation}\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta\boldsymbol{g}_t\end{equation}
我们希望给每个分量都调节一下学习率,所以我们引入跟参数一样大小的非负变量$\boldsymbol{\nu}$,修改更新公式为
\begin{equation}\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta\boldsymbol{\nu}_{t+1}\odot\boldsymbol{g}_t\label{eq:update}\end{equation}
那么,$\boldsymbol{\nu}$要按照什么规则迭代呢?记住我们最终的目的是最小化$\mathcal{L}$,所以$\boldsymbol{\nu}$的更新规则应该也要是梯度下降,而这里$\boldsymbol{\nu}$要求是非负的,所以我们用指数梯度下降:
\begin{equation}\boldsymbol{\nu}_{t+1} = \boldsymbol{\nu}_t \odot\exp\left(- \gamma\nabla_{\boldsymbol{\nu}_t}\mathcal{L}\right)\label{eq:update-nu}\end{equation}
注意$\mathcal{L}$本来只是$\boldsymbol{\theta}$的函数,但根据$\eqref{eq:update}$,在$t$时刻我们有$\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta\boldsymbol{\nu}_t\odot\boldsymbol{g}_{t-1}$,所以根据链式法则有
\begin{equation}\nabla_{\boldsymbol{\nu}_t}\mathcal{L} = -\eta\boldsymbol{g}_{t-1} \odot\nabla_{\boldsymbol{\theta}_t}\mathcal{L}= -\eta\boldsymbol{g}_{t-1} \odot\boldsymbol{g}_t\end{equation}
代入到$\nu$的更新公式$\eqref{eq:update-nu}$,得到
\begin{equation}\boldsymbol{\nu}_{t+1} = \boldsymbol{\nu}_t \odot\exp\left( \gamma\eta\boldsymbol{g}_{t-1} \odot\boldsymbol{g}_t\right)\end{equation}
将$\gamma\eta$合成一个参数$\gamma$,于是整个模型的更新公式是:
\begin{equation}\begin{aligned}&\boldsymbol{\nu}_{t+1} = \boldsymbol{\nu}_t \odot\exp\left( \gamma\boldsymbol{g}_{t-1} \odot\boldsymbol{g}_t\right) \\
&\boldsymbol{\theta}_{t+1} = \boldsymbol{\theta}_t - \eta\boldsymbol{\nu}_{t+1}\odot\boldsymbol{g}_t\end{aligned}\end{equation}
如果$\boldsymbol{\nu}$初始化为全1,那么将有
\begin{equation}\boldsymbol{\nu}_{t+1} = \exp\left(\gamma\sum_{k=1}^t\boldsymbol{g}_{k-1} \odot\boldsymbol{g}_k\right)\end{equation}
可以看到,该方法的学习率调节思路是:如果某分量相邻两步的梯度经常同号,那么对应项的累加结果就是正的,意味着我们可以适当扩大一下学习率;如果相邻两步的梯度经常异号,那么对应项的累加结果很可能是负的,意味着我们可以适当缩小一下学习率。
注意这跟Adam调学习率的思想是不一样的,Adam调节学习率的思想是如果某个分量的梯度长时间很小,那么就意味着该参数可能没学好,所以尝试放大它的学习率。两者也算是各有各的道理吧。
简单做个小结 #
本文主要对“指数梯度下降”和“元学习调学习率”两个概念做了简单笔记,“指数梯度下降”是非负约束优化的一个简单有效的方案,而“元学习调学习率”则是元学习的一个简单易懂的应用。其中在介绍“元学习调学习率”时笔者做了一些简化,相比原论文的形式更为简单一些,但思想是一致的。
转载到请包括本文地址:https://kexue.fm/archives/8968
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Mar. 03, 2022). 《指数梯度下降 + 元学习 = 自适应学习率 》[Blog post]. Retrieved from https://kexue.fm/archives/8968
@online{kexuefm-8968,
title={指数梯度下降 + 元学习 = 自适应学习率},
author={苏剑林},
year={2022},
month={Mar},
url={\url{https://kexue.fm/archives/8968}},
}
March 8th, 2022
居然看得懂。。最后把v初始化为1得出公式(12),再由该公式得出结论:“如果某分量相邻两步的梯度经常同号,那么对应项的累加结果就是正的,意味着我们可以适当扩大一下学习率;...”真直观。
March 8th, 2022
公式(3)这里可能有点小笔误,应该是$\begin{equation}\boldsymbol{\varphi}_{t+1} = \boldsymbol{\varphi}_t - \eta\nabla_{\boldsymbol{\varphi}}\mathcal{L}({\boldsymbol{\varphi}_t}) \end{equation}$
这里没有笔误,就是$\mathcal{L}(e^{\boldsymbol{\varphi}_t})$,不是$\mathcal{L}(\boldsymbol{\varphi}_t)$。
别忘了这里是接着上面讲的,上面已经讲了$\mathcal{L}$是$\boldsymbol{\theta}$的函数,而$\boldsymbol{\theta}=e^{\boldsymbol{\varphi}}$,所以显然是$\mathcal{L}(e^{\boldsymbol{\varphi}_t})$,不是随便记的。
为什么我感觉苏神写错了?
那就试着修正你的感觉。
如果把目标函数看成L(e^x),即自变量是x而不是e^x,那就没问题了。
“看成”的意思是“本来不应该这样子,但可以强制视为这样子”。$L(e^x)$的自变量显然是$x$,不需要“看成”;如果你认为自变量是$e^x$,那才需要“看成”。
December 31st, 2023
大佬好,提供两个相关文献哈,Online Learning Rate Adaptation with Hypergradient Descent 和 Convergence Analysis of an Adaptive Method of Gradient Descent(Marinez的毕业论文)是最早两个对学习率进行学习的文章。
另外我个人在语言模型预训练中尝试了Adam+这个自动调节学习率的方法,结果并不成功,学习率会很快衰减到0,也就是前一步梯度和这步梯度的内积总是小于0,不知道大佬有没有实际用过这个方法?
感谢推荐。
很抱歉我没实测过这个方法,另外这个方法本质上是只适用于SGD(还要不带动量),其他优化器可用性不知道怎样~