从谱范数梯度到新式权重衰减的思考
By 苏剑林 | 2024-12-25 | 8603位读者 |在文章《Muon优化器赏析:从向量到矩阵的本质跨越》中,我们介绍了一个名为“Muon”的新优化器,其中一个理解视角是作为谱范数正则下的最速梯度下降,这似乎揭示了矩阵参数的更本质的优化方向。众所周知,对于矩阵参数我们经常也会加权重衰减(Weight Decay),它可以理解为$F$范数平方的梯度,那么从Muon的视角看,通过谱范数平方的梯度来构建新的权重衰减,会不会能起到更好的效果呢?
那么问题来了,谱范数的梯度或者说导数长啥样呢?用它来设计的新权重衰减又是什么样的?接下来我们围绕这些问题展开。
基础回顾 #
谱范数(Spectral Norm),又称“$2$范数”,是最常用的矩阵范数之一,相比更简单的$F$范数(Frobenius Norm),它往往能揭示一些与矩阵乘法相关的更本质的信号,这是因为它定义上就跟矩阵乘法相关:对于矩阵参数$\boldsymbol{W}\in\mathbb{R}^{n\times m}$,它的谱范数定义为
\begin{equation}\Vert\boldsymbol{W}\Vert_2 \triangleq \max_{\Vert\boldsymbol{x}\Vert=1} \Vert\boldsymbol{W}\boldsymbol{x}\Vert\end{equation}
这里$\boldsymbol{x}\in\mathbb{R}^m$是列向量,右端的$\Vert\Vert$是向量的模长(欧氏范数)。换个角度看,谱范数就是使得下面不等式对$\forall \boldsymbol{x}\in\mathbb{R}^m$恒成立的最小常数$C$:
\begin{equation}\Vert\boldsymbol{W}\boldsymbol{x}\Vert \leq C\Vert\boldsymbol{x}\Vert\end{equation}
不难证明,当$C$取$F$范数$\Vert W\Vert_F$时,上式也是恒成立的,所以可以写出$\Vert \boldsymbol{W}\Vert_2\leq \Vert \boldsymbol{W}\Vert_F$(因为$\Vert \boldsymbol{W}\Vert_F$只是让上式恒成立的其中一个$C$,而$\Vert \boldsymbol{W}\Vert_2$则是最小的那个$C$)。这个结论也表明,如果我们想要控制输出的幅度,以谱范数作为正则项要比$F$范数更为精准。
早在6年前的《深度学习中的Lipschitz约束:泛化与生成模型》中,我们就讨论过谱范数,当时的应用场景有两个:一是WGAN对判别器明确提出了Lipschitz约束,而实现方式之一就是基于谱范数的归一化;二是有一些工作表明,谱范数作为正则项,相比$F$范数正则有更好的性能。
梯度推导 #
现在让我们进入正题,尝试推导谱范数的梯度$\nabla_{\boldsymbol{W}} \Vert\boldsymbol{W}\Vert_2$。我们知道,谱范数在数值上等于它的最大奇异值,对此我们在《低秩近似之路(二):SVD》的“矩阵范数”一节有过证明。这意味着,如果$\boldsymbol{W}$可以SVD为$\sum\limits_{i=1}^{\min(n,m)}\sigma_i \boldsymbol{u}_i\boldsymbol{v}_i^{\top}$,那么
\begin{equation}\Vert\boldsymbol{W}\Vert_2 = \sigma_1 = \boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1\end{equation}
其中$\sigma_1 \geq \sigma_2 \geq \cdots \geq \sigma_{\min(n,m)} \geq 0$是$\boldsymbol{W}$的奇异值。对两边求微分,我们得到
\begin{equation}d\Vert\boldsymbol{W}\Vert_2 = d\boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1 + \boldsymbol{u}_1^{\top}d\boldsymbol{W}\boldsymbol{v}_1 + \boldsymbol{u}_1^{\top}\boldsymbol{W}d\boldsymbol{v}_1\end{equation}
留意到
\begin{equation}d\boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1 = d\boldsymbol{u}_1^{\top}\sum_{i=1}^{\min(n,m)}\sigma_i \boldsymbol{u}_i\boldsymbol{v}_i^{\top}\boldsymbol{v}_1 = d\boldsymbol{u}_1^{\top}\sigma_1 \boldsymbol{u}_1 = \frac{1}{2}\sigma_1 d(\Vert\boldsymbol{u}_1\Vert^2)=0\end{equation}
同理$\boldsymbol{u}_1^{\top}\boldsymbol{W}d\boldsymbol{v}_1=0$,所以
\begin{equation}d\Vert\boldsymbol{W}\Vert_2 = \boldsymbol{u}_1^{\top}d\boldsymbol{W}\boldsymbol{v}_1 = \text{Tr}((\boldsymbol{u}_1 \boldsymbol{v}_1^{\top})^{\top} d\boldsymbol{W}) \quad\Rightarrow\quad \nabla_{\boldsymbol{W}}\Vert\boldsymbol{W}\Vert_2 = \boldsymbol{u}_1 \boldsymbol{v}_1^{\top}\end{equation}
注意,这个证明过程有一个关键条件是$\sigma_1 > \sigma_2$,因为如果$\sigma_1=\sigma_2$的话,$\Vert\boldsymbol{W}\Vert_2$既可以表示成$\boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1$又可以表示成$\boldsymbol{u}_2^{\top}\boldsymbol{W}\boldsymbol{v}_2$,用同样方法求出的梯度分别是$\boldsymbol{u}_1 \boldsymbol{v}_1^{\top}$和$\boldsymbol{u}_2 \boldsymbol{v}_2^{\top}$,结果不唯一意味着梯度不存在。当然,从实践角度看,两个数完全相等的概率是很小的,因此可以忽略这一点。
(注:这里的证明过程参考了Stack Exchange上的回答,但该回答里面没有证明$d\boldsymbol{u}_1^{\top}\boldsymbol{W}\boldsymbol{v}_1=0$和$\boldsymbol{u}_1^{\top}\boldsymbol{W}d\boldsymbol{v}_1=0$,这部分由笔者补充完整。)
权重衰减 #
根据这个结果以及链式法则,我们有
\begin{equation}\nabla_{\boldsymbol{W}}\left(\frac{1}{2}\Vert\boldsymbol{W}\Vert_2^2\right) = \Vert\boldsymbol{W}\Vert_2\nabla_{\boldsymbol{W}}\Vert\boldsymbol{W}\Vert_2 = \sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top}\label{eq:grad-2-2}\end{equation}
对比$F$范数下的结果:
\begin{equation}\nabla_{\boldsymbol{W}}\left(\frac{1}{2}\Vert\boldsymbol{W}\Vert_F^2\right) = \boldsymbol{W} = \sum_{i=1}^{\min(n,m)}\sigma_i \boldsymbol{u}_i \boldsymbol{v}_i^{\top}\end{equation}
这样对比着看就很清晰了:$F$范数平方作为正则项所得出的权重衰减,同时惩罚全体奇异值;而谱范数平方对应的权重衰减,只惩罚最大奇异值。如果我们目的是压缩输出的大小,那么压缩最大奇异值是“刚刚好”的做法,压缩全体奇异值虽然可能达到相近的目的,但同时也可能压缩参数的表达能力。
根据“Eckart-Young-Mirsky定理”,式$\eqref{eq:grad-2-2}$最右侧的结果还有一个含义,就是$\boldsymbol{W}$矩阵的“最优1秩近似”。也就是说,谱范数的权重衰减将每一步减去它自身的操作,改为每一步减去它的最优1秩近似,弱化了惩罚力度,当然某种程度上也让惩罚更加“直击本质”。
数值计算 #
对于实践来说,最关键的问题来了:怎么计算$\sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top}$呢?SVD当然是最简单直接的方案,但计算复杂度无疑也是最高的,我们必须找到更高效的计算途径。
不失一般性,设$n\geq m$。首先注意到
\begin{equation}\sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top} = \sum_{i=1}^m\sigma_i \boldsymbol{u}_i \boldsymbol{v}_i^{\top} \boldsymbol{v}_1 \boldsymbol{v}_1^{\top} = \boldsymbol{W}\boldsymbol{v}_1 \boldsymbol{v}_1^{\top}\end{equation}
由此可见计算$\sigma_1 \boldsymbol{u}_1 \boldsymbol{v}_1^{\top}$只需要知道$\boldsymbol{v}_1$,然后根据我们在《低秩近似之路(二):SVD》中的讨论,$\boldsymbol{v}_1$实际上是矩阵$\boldsymbol{W}^{\top}\boldsymbol{W}$的最大特征值对应的特征向量。这样一来,我们便将问题从一般矩阵$\boldsymbol{W}$的SVD转化成了实对称矩阵$\boldsymbol{W}^{\top}\boldsymbol{W}$的特征值分解,这其实已经降低复杂度了,因为特征值分解通常要比SVD明显快。
如果还觉得慢,那么我们就需要请出很多特征值分解算法背后的原理——“幂迭代(Power Iteration)”:
当$\sigma_1 > \sigma_2$时,迭代 \begin{equation}\boldsymbol{x}_{t+1} = \frac{\boldsymbol{W}^{\top}\boldsymbol{W}\boldsymbol{x}_t}{\Vert\boldsymbol{W}^{\top}\boldsymbol{W}\boldsymbol{x}_t\Vert}\end{equation} 以$(\sigma_2/\sigma_1)^{2t}$的速度收敛至$\boldsymbol{v}_1$。
幂迭代每步只需要算两次“矩阵-向量”乘法,复杂度是$\mathcal{O}(nm)$,$t$步迭代的总复杂度是$\mathcal{O}(tnm)$,非常理想,缺点是$\sigma_1,\sigma_2$接近时收敛会比较慢。但幂迭代的实际表现往往比理论想象更好用,早期很多工作甚至只迭代一次就得到不错的效果,因为$\sigma_1,\sigma_2$接近表明两者及其特征向量一定程度上可替换,而幂迭代即便没完全收敛,得到的也是两者特征向量的一个平均,这也完全够用了。
迭代证明 #
这一节我们来完成幂迭代的证明。不难看出,幂迭代可以等价地写成
\begin{equation}\lim_{t\to\infty} \frac{(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0}{\Vert(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0\Vert} = \boldsymbol{v}_1\end{equation}
为了证明这个极限,我们从$\boldsymbol{W}=\sum\limits_{i=1}^m\sigma_i \boldsymbol{u}_i\boldsymbol{v}_i^{\top}$出发,代入计算可得
\begin{equation}\boldsymbol{W}^{\top}\boldsymbol{W} = \sum_{i=1}^m\sigma_i^2 \boldsymbol{v}_i\boldsymbol{v}_i^{\top},\qquad(\boldsymbol{W}^{\top}\boldsymbol{W})^t = \sum_{i=1}^m\sigma_i^{2t} \boldsymbol{v}_i\boldsymbol{v}_i^{\top}\end{equation}
由于$\boldsymbol{v}_1,\boldsymbol{v}_2,\cdots,\boldsymbol{v}_m$是$\mathbb{R}^m$的一组标准正交基,所以$\boldsymbol{x}_0$可以写成$\sum\limits_{j=1}^m c_j \boldsymbol{v}_j$,于是我们有
\begin{equation}(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0 = \sum_{i=1}^m\sigma_i^{2t} \boldsymbol{v}_i\boldsymbol{v}_i^{\top}\sum_{j=1}^m c_j \boldsymbol{v}_j = \sum_{i=1}^m\sum_{j=1}^m c_j\sigma_i^{2t} \boldsymbol{v}_i\underbrace{\boldsymbol{v}_i^{\top} \boldsymbol{v}_j}_{=\delta_{i,j}} = \sum_{i=1}^m c_i\sigma_i^{2t} \boldsymbol{v}_i\end{equation}
以及
\begin{equation}\Vert(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0\Vert = \left\Vert \sum_{i=1}^m c_i\sigma_i^{2t} \boldsymbol{v}_i\right\Vert = \sqrt{\sum_{i=1}^m c_i^2\sigma_i^{4t}}\end{equation}
由于随机初始化的缘故,$c_1=0$的概率是非常小的,所以我们可以认为$c_1\neq 0$,那么
\begin{equation}\frac{(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0}{\Vert(\boldsymbol{W}^{\top}\boldsymbol{W})^t \boldsymbol{x}_0\Vert} = \frac{\sum\limits_{i=1}^m c_i\sigma_i^{2t} \boldsymbol{v}_i}{\sqrt{\sum\limits_{i=1}^m c_i^2\sigma_i^{4t}}} = \frac{\boldsymbol{v}_1 + \sum\limits_{i=2}^m (c_i/c_1)(\sigma_i/\sigma_1)^{2t} \boldsymbol{v}_i}{\sqrt{1 + \sum\limits_{i=2}^m (c_i/c_1)^2(\sigma_i/\sigma_1)^{4t}}}\end{equation}
当$\sigma_1 > \sigma_2$时,所有的$\sigma_i/\sigma_1(i\geq 2)$都小于1,因此当$t\to \infty$时对应项都变成了0,最后的极限是$\boldsymbol{v}_1$。
相关工作 #
最早提出谱范数正则的论文,应该是2017年的《Spectral Norm Regularization for Improving the Generalizability of Deep Learning》,里边对比了权重衰减、对抗训练、谱范数正则等方法,发现谱范数正则在泛化性能方面表现最好。
论文当时的做法,并不是像本文一样求出$\nabla_{\boldsymbol{W}}\Vert\boldsymbol{W}\Vert_2^2 = 2\sigma_1\boldsymbol{u}_1 \boldsymbol{v}_1^{\top}$,而是直接通过幂迭代来估计$\Vert\boldsymbol{W}\Vert_2$,然后将$\Vert\boldsymbol{W}\Vert_2^2$加权到损失函数中,让优化器自己去求梯度,这样做效率上稍差一些,并且也不好以权重衰减的形式跟优化器解耦开来。本文的做法相对来说更加灵活一些,允许我们像AdamW一样,将权重衰减独立于主损失函数的优化之外。
当然,从今天LLM的视角来看,当初的这些实验最大问题就是规模都太小了,很难有足够的说服力,不过鉴于谱范数的Muon优化器“珠玉在前”,笔者认为还是值得重新思考和尝试一下谱范数权重衰减。当然,不管是$F$范数还是谱范数的权重衰减,这些面向“泛化”的技术往往也有一些运气成份在里边,大家平常心期待就好。
个人在语言模型的初步实验结果显示,Loss层面可能会有微弱的提升(希望不是幻觉,当然再不济也没有出现变差的现象)。实验过程就是用幂迭代求出$\boldsymbol{v}_1$的近似值(初始化为全一向量,迭代10次),然后将原来的权重衰减$-\lambda \boldsymbol{W}$改为$-\lambda \boldsymbol{W}\boldsymbol{v}_1\boldsymbol{v}_1^{\top}$,$\lambda$的取值不做改变。
文章小结 #
本文推导了谱范数的梯度,由此导出了一种新的权重衰减,并分享了笔者对它的思考。
转载到请包括本文地址:https://kexue.fm/archives/10648
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Dec. 25, 2024). 《从谱范数梯度到新式权重衰减的思考 》[Blog post]. Retrieved from https://kexue.fm/archives/10648
@online{kexuefm-10648,
title={从谱范数梯度到新式权重衰减的思考},
author={苏剑林},
year={2024},
month={Dec},
url={\url{https://kexue.fm/archives/10648}},
}
最近评论