msign算子的Newton-Schulz迭代(下)
By 苏剑林 | 2025-06-05 | 4164位读者 |在上文《msign算子的Newton-Schulz迭代(上)》中,我们试图为$\mathop{\text{msign}}$算子寻找更好的Newton-Schulz迭代,以期在有限迭代步数内能达到尽可能高的近似程度,这一过程又可以转化为标量函数$\mathop{\text{sign}}(x)$寻找同样形式的多项式迭代。当时,我们的求解思路是用Adam优化器端到端地求一个局部最优解,虽然有效但稍显粗暴。
而在几天前,arXiv新出了一篇论文《The Polar Express: Optimal Matrix Sign Methods and Their Application to the Muon Algorithm》,作者运用了一系列精妙的数学结论,以优雅且硬核的方式给出了更漂亮的答案。本文让我们一起欣赏和学习一番这篇精彩的论文。
问题描述 #
相关背景和转化过程我们就不再重复了,直接给出我们要求解的问题是
\begin{equation}\mathop{\text{argmin}}_f d(f(x),1)\end{equation}
其中$f = f_T \circ \dots \circ f_2 \circ f_1$,$\circ$代表函数的复合,$f_t(x)$是关于$x$的奇多项式(只包含$x$的奇数次幂),而$d(f(x),1)$是衡量函数$f(x)$与$1$距离的某个指标。上一篇文章中,我们是在$[0,1]$内均匀选有限个点,取最大的$k$个$|f(x)-1|$的平均值作为指标。本文则直接取区间内的$|f(x)-1|$最大值作为指标,即
\begin{equation}\mathop{\text{argmin}}_f \max_{x\in[l,u]} |f(x) - 1| \label{eq:opt}\end{equation}
其中$[l,u]\subset [0,1]$。要注意,此时$u$可以直接取1,但$l$不能取0,因为$f(0)$始终是0,这意味着上式始终大于等于1,无法收敛,所以$l$只能选一个很接近于0的数。按照上一篇文章的分析,为了普适性,我们应该要照顾到$0.001$大小的奇异值,因此可以考虑$l=0.001$。
在开始分析之前,我们先简单解释一下论文标题中“Polar”一词的含义,它其实代表了矩阵的“极分解(Polar Decomposition)”:
极分解(Polar Decomposition) 对于一个方阵$\boldsymbol{M}\in\mathbb{R}^{n\times n}$,它的极分解为$\boldsymbol{M}=\boldsymbol{Q}\boldsymbol{S}$,其中$\boldsymbol{Q}$是一个正交矩阵,而$\boldsymbol{S}$是一个半正定矩阵。
而如果$\boldsymbol{M}$的SVD为$\boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$,那么正好有
\begin{equation}\boldsymbol{M} = \boldsymbol{U}\boldsymbol{\Sigma}\boldsymbol{V}^{\top} = (\boldsymbol{U}\boldsymbol{V}^{\top})(\boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top})\end{equation}
即$\boldsymbol{Q}=\boldsymbol{U}\boldsymbol{V}^{\top},\boldsymbol{S}=\boldsymbol{V}\boldsymbol{\Sigma}\boldsymbol{V}^{\top}$就是极分解的一个答案。而我们知道,当$\boldsymbol{M}$是满秩矩阵时,$\boldsymbol{U}\boldsymbol{V}^{\top}$正好是$\mathop{\text{msgin}}(\boldsymbol{M})$。这也就是为什么$\mathop{\text{msgin}}$会跟“Polar”联系起来了,因为求出它之后就可以得到矩阵的“Polar Decomposition”,换言之极分解的本质难度就是计算$\mathop{\text{msgin}}$,这跟Muon异曲同工。
贪心即可 #
言归正传。对于问题$\eqref{eq:opt}$,原论文得出的第一个结论,也应该是全论文最核心的结论是:它的贪心解正好是它的全局最优解!用公式来说,它是指问题$\eqref{eq:opt}$的求解,可以转化为:
\begin{equation}\begin{gathered}
f^* = f_T^* \circ \dots \circ f_2^* \circ f_1^* \\[12pt]
f_1^* = \mathop{\text{argmin}}_{f_1} \max_{x\in[l_1,u_1]} |f_1(x) - 1| \\
f_2^* = \mathop{\text{argmin}}_{f_2} \max_{x\in[l_2,u_2]} |f_2(x) - 1| \\
\vdots \\
f_T^* = \mathop{\text{argmin}}_{f_T} \max_{x\in[l_T,u_T]} |f_T(x) - 1| \\[24pt]
l_1 = l,\quad u_1 = u, \\[8pt]
l_{t+1} = \min_{x\in[l_t,u_t]} f_t^*(x),\quad u_{t+1} = \max_{x\in[l_t,u_t]} f_t^*(x)
\end{gathered}\end{equation}
这个结论我相信会出乎不少读者意料,笔者首次看到时也是相当惊讶,并为之拍案叫绝。它不仅大大降低了求解难度,将原本$T$步的复合函数求解,转化为$T=1$的单个多项式逐步求解,还允许我们将解逐步往前推,并一直保持最优性(即$T+1$步的最优解,只需要在$T$步最优解基础上多算一步,而不用从头计算)。
值得指出的是,这个结论是允许每个$f_t$取不同阶的(这里的“阶”指多项式最高次数),比如$f_1$可以是3阶的,$f_2$可以是5阶的,等等,但“贪心解正好是全局最优解”的结论依然不变。不过简单起见,下面我们还是让所有$f_t$都保持同阶,并且主要考虑3阶和5阶的结果。
上述结论的完整证明略微有点复杂,我们把它放到最后,先完成基于该结论的后续操作。
等值振荡 #
既然我们已经把原问题转化为求贪心解,那么现在只需要专注于求解
\begin{equation}\mathop{\text{argmin}}_{f_t} \max_{x\in[l_t,u_t]} |f_t(x) - 1| \label{eq:local}\end{equation}
为了求解上式,我们需要先了解在《等值振荡定理:最优多项式逼近的充要条件》介绍的关于奇多项式的“等值振荡定理(Equioscillation Theorem)”:
等值振荡定理-奇 设$f(x)$是不超过$2n+1$阶的奇多项式,$g(x)$是区间$[a,b]\subset (0,\infty)$上的连续函数,那么 \begin{equation}f^* = \mathop{\text{argmin}}_f \max_{x\in[a,b]} |f(x) - g(x)|\end{equation} 的充要条件是存在$a\leq x_0 < x_1 < \cdots < x_{n+1} \leq b$以及$\sigma\in\{0,1\}$,使得 \begin{equation}f^*(x_k) - g(x_k) = (-1)^{k+\sigma} \max_{x\in[a,b]} |f^*(x) - g(x)|\end{equation}
现在我们要求的是$f_t$,目标$g$则是恒等于1,等值振荡定理告诉我们$|f_t^*(x)-1|$在$[l_t,u_t]$至少$n+2$次达到最大误差(记为$\mathcal{E}$)。不难发现,$|f_t^*(x)-1|$的最大值点只能是边界点或者是$f_t^*(x)$的极值点,而一个$2n+1$阶奇多项式在$(0,\infty)$至多有$n$个极值点,所以为了“凑”够$n+2$个,我们“不得已”要把边界点补上,这就确定了$x_0 = l_t, x_{n+1}=u_t$,而$x_1,\cdots,x_n$则是$\frac{d}{dx}f_t^*(x)$的零点。
此外,由于目标函数是$1$,所以$f_t^*(x)$在$x=0$处的斜率大于零,所以$l_t$只能是$f_t^*(x)$的最小值点,因此$\sigma=1$。综合这些结果,这样我们实际上要解方程组:
\begin{equation}f_t(l_t) = 1 - \mathcal{E}, \quad f_t(u_t) = 1 + (-1)^n \mathcal{E},\quad f_t(x_k) = 1 + (-1)^{i+1}\mathcal{E}, \quad f_t'(x_i) = 0\end{equation}
其中$i=1,2,3,\cdots,n$。可以发现方程和未知数都是$2n+2$个,再补上$l_t < x_1 < \cdots < x_n < u_t$和$\mathcal{E} > 0$的约束条件,理论上可以把解确定下来。
解方程组 #
对于3阶奇多项式($n=1$),原论文给出了解析解,而对于5阶奇多项式($n=2$),原论文给出的是一个迭代求解算法,即先固定$x_1,x_2$求$a,b,c$,然后固定$f_t(x)$的$a,b,c$求$x_1,x_2$,反复迭代,这本质上是Remez算法的一个简化版。
不过,原论文的迭代依赖于求根公式来求$x_1,x_2$,这对于更大的$n$就不大容易操作了。所以这里笔者改变一下求解思路,先以$x_1,x_2,\cdots,x_n$来参数化$f_t'(x_i)$,即定义
\begin{equation}f_t'(x) = k(x^2-x_1^2)(x^2-x_2^2)\cdots (x^2-x_n^2)\end{equation}
然后有$f_t(x) = \int_0^x f_t'(x) dx$,这样我们就用$k$和$x_1,x_2,\cdots,x_n$表示出了$f_t(x)$,继而只需求解方程组
\begin{equation}f_t(l_t) = 1 - \mathcal{E}, \quad f_t(u_t) = 1 + (-1)^n \mathcal{E},\quad f_t(x_i) = 1 + (-1)^{i+1}\mathcal{E}\end{equation}
而避免了解方程$f_t'(x) = 0$。当$n=1$时,我们可以解得
\begin{equation}x_1 = \sqrt{\frac{l_t^2+l_t u_t + u_t^2}{3}},\quad k = -\frac{6}{l_t^2 u_t + l_t u_t^2 + 2x_1^3}\end{equation}
当$n > 1$时,我们可以直接交给Mathematica,例如$n=2$时:
df[x_] = k*(x^2 - x1^2) (x^2 - x2^2);
f[x_] = Integrate[df[x], {x, 0, x}];
sol = NSolve[{f[l] == 1 - e, f[x1] == 1 + e, f[x2] == 1 - e,
f[u] == 1 + e, l < x1 < x2 < u, e > 0} /. {l -> 0.001,
u -> 1}, {k, x1, x2, e}, Reals]
f[x] /. sol
有限精度 #
至此,我们似乎已经完成了原始问题的求解?理论上是的,但仅限于无限精度。实际计算时是有限精度的,尤其是Muon优化器用的是bfloat16,精度损失更严重,所以就带来了一些问题。
第一个问题,是每一个$f_t^*$理论上只对区间$[l_t,u_t]$负责,但有限精度下奇异值可能会偏离该区间。当$n$是偶数时(即$f_t^*$是5、9、...阶奇多项式),超出$u_t$就存在发散风险,因为此时的$f_t^*(x)$在$x > u_t$时是单调递增到正无穷的,稍有不慎就会随着迭代而发散。解决办法有两个,一是求$f_t^*$时把$[l_t,u_t]$预留得稍微宽松一些,二是保持区间不变,但求得$f_t^*$后输入要多除一个大于1的数。原论文用的是后者,把$f_t^*(x)$改为$f_t^*(x / 1.01)$。
第二个问题比较隐晦,我们用具体例子来介绍。设$n=2,l_1=0.001,u_1=1$,可以求得$f_1^*$是
\begin{equation}f_1^*(x) = 8.4703 x - 25.1081 x^3 + 18.6293 x^5\end{equation}
其中$x_1 = 0.3674, x_2 = 0.8208, \mathcal{E}=0.9915$。这个解有什么问题呢?根据等值振荡定理,我们知道$f_1^*(x_2) = 1-\mathcal{E} = 0.0085$,即它会把$0.8208$映射成$0.0085$。然而,我们的最终目标是将$(0,1]$的所有数都变成1,所以$f_1^*$会将一个已经很接近目标的$0.8208$映射到非常远离目标的$0.0085$。尽管后面$f_2^*,f_3^*,\cdots$理论上会逐渐把它拉回来,但在有限精度下反复将一个数缩小又放大,累积误差是很可观的。
当然,由等值振荡定理可知,这种振荡行为是无法避免的,我们只能寄望于最大误差$\mathcal{E}$不要太接近于1,从而减缓这种累积误差。不难看出,区间$[l_t,u_t]$越大,理论上就越难拟合,最大误差$\mathcal{E}$将会越近于1,所以论文引入了一个超参数$\lambda \in (0, 1)$,将优化区间从$[l_t,u_t]$改为$[\max(l_t, \lambda u_t),u_t]$,通过限制区间大小来保证$\mathcal{E}$不会太大。(需要提醒读者的是,论文在正文讲解中用的$\lambda$是$0.1$,但附录代码实际用的$\lambda$是$0.024$。)
但这样一来,原本的$l_t$,尤其是我们一开始设置的$l$,不就容易被忽略了?为了解决这个问题,论文引入了“Recenter”技巧,即如果优化区间是$[l_t,u_t]$,那么将会满足$f_t^*(l_t) + f_t^*(u_t) = 2$,而优化区间改为$[\max(l_t, \lambda u_t),u_t]$后就不一定满足了,这时候我们给$f_t^*$乘上$\gamma$,使得它满足这个等式
\begin{equation}\gamma f_t^*(l_t) + \gamma f_t^*(u_t) = 2\qquad \Rightarrow \qquad \gamma = \frac{2}{f_t^*(l_t) + f_t^*(u_t)}\end{equation}
这就把原本$l_t$考虑进去了。
参考代码 #
这是$n=2$时的Mathematica完整代码:
df[x_] = k*(x^2 - x1^2) (x^2 - x2^2);
f[x_] = Integrate[df[x], {x, 0, x}];
sol[l_, u_] :=
NSolve[{f[l] == 1 - e, f[x1] == 1 + e, f[x2] == 1 - e, f[u] == 1 + e,
l < x1 < x2 < u, e > 0, k > 0}, {k, x1, x2, e}]
ff[x_, l_, u_] = f[x]*2/(f[l] + f[u]) // Expand;
lt = 0.001; ut = 1; lambda = 0.02407327424182761;
While[1 - lt > 0.0001,
fff[x_] = ff[x, lt, ut] /. sol[Max[lt, lambda*ut], ut][[1]];
Print[fff[x]];
lt = fff[lt]; ut = 2 - lt]
结果如下($f_t(x) = a_t x + b_t x^3 + c_t x^5$):
\begin{array}{c|ccc}
\hline
t & a\times 1.01 & b\times 1.01^3 & c\times 1.01^5 \\
\hline
\quad 1\quad & 8.28721 & -23.5959 & 17.3004 \\
2 & 4.10706 & -2.94785 & 0.544843 \\
3 & 3.94869 & -2.9089 & 0.551819 \\
4 & 3.31842 & -2.48849 & 0.510049 \\
5 & 2.30065 & -1.6689 & 0.418807 \\
6 & 1.8913 & 1.268 & 0.376804 \\
7 & 1.875 & -1.25 & 0.375 \\
8 & 1.875 & -1.25 & 0.375 \\
\hline
\end{array}
注意这里给出的是还没有做$f_t^*(x / 1.01)$处理的结果,所以实际的$a, b, c$还要在该表基础上多除以$1.01$的$1,3,5$次方。没有直接给出除以$1.01$之后的结果,是因为除以$1.01$前的收敛值$1.875, -1.25, 0.375$($t \geq 7$)显得更简洁明了,便于观察和欣赏。
作者附录的代码整理如下:
import numpy as np
def optimal_quintic(l, u):
assert 0 <= l <= u
if 1 - 5e-6 <= l / u:
# Above this threshold, the equoscillating polynomials
# is numerically equal to...
return (15 / 8) / u, (-10 / 8) / (u**3), (3 / 8) / (u**5)
# This initialization becomes exact as l -> u
q = (3 * l + 1) / 4
r = (l + 3) / 4
E, old_E = np.inf, None
while not old_E or abs(old_E - E) > 1e-15:
old_E = E
LHS = np.array([
[l, l**3, l**5, 1],
[q, q**3, q**5, -1],
[r, r**3, r**5, 1],
[u, u**3, u**5, -1],
])
a, b, c, E = np.linalg.solve(LHS, np.ones(4))
q, r = np.sqrt(
(-3 * b + np.array([-1, 1]) * np.sqrt(9 * b**2 - 20 * a * c)) /
(10 * c)
)
return float(a), float(b), float(c)
def optimal_composition(l, num_iters, cushion=0.02407327424182761):
u = 1
coefficients = []
for _ in range(num_iters):
a, b, c = optimal_quintic(max(l, cushion * u), u)
# Due to cushioning , this may be centered around 1 with
# respect to 0.024*u, u. Recenter it around 1 with respect
# to l, u, meaning find c so that 1 - c*p(l) = c*p(u) - 1:
pl = a * l + b * l**3 + c * l**5
pu = a * u + b * u**3 + c * u**5
rescalar = 2 / (pl + pu)
a *= rescalar
b *= rescalar
c *= rescalar
# Optionally incorporate safety factor here :
# a /= 1.01; b /= 1.01**3; c /= 1.01**5
coefficients.append((a, b, c))
l = a * l + b * l**3 + c * l**5
u = 2 - l
return coefficients
print(*optimal_composition(1e-3, 10), sep="\n")
完成证明 #
最后一节,我们来补上“贪心解正好是它的全局最优解”的证明。
根据等值振荡定理,我们知道$f_t^*$的值域是$[l_{t+1},u_{t+1}]$,其中$l_{t+1}=f_t^*(l_t),u_{t+1}=2-l_{t+1}$,由此可知$T$步贪心解的最大误差是$\mathcal{E}_T = 1 - l_{T+1} = 1 - f_T^*(l_T)$,我们只需要证明$T$步全局最优解的最大误差也只能降到$1 - f_T^*(l_T)$,就可以得到“贪心解正好是它的全局最优解”这个结论。
证明的思路是数学归纳法。假设结论对于$t=1,2,\cdots,T-1$成立,那么$\hat{f} = f_{T-1}^*\circ \cdots \circ f_2^* \circ f_1^*$就是$T-1$步的全局最优解,值域为$[l_T, u_T]$,最大误差为$\mathcal{E}_{T-1}=1-l_T=u_T-1$。另一方面,设$\tilde{f} = \tilde{f}_{T-1}\circ \cdots \circ \tilde{f}_2 \circ \tilde{f}_1$为任意一个$T-1$步解,值域为$[a,b]$,令$c = \frac{2}{a+b}$,那么$c\tilde{f}$的值域则是$[ca,cb]$,显然$ca\leq 1, cb\geq 1$。根据归纳假设,我们有
\begin{equation}\begin{aligned}
1 - ca \geq \mathcal{E}_{T-1} \\
cb - 1 \geq \mathcal{E}_{T-1}
\end{aligned}\qquad\Rightarrow\qquad \frac{a}{b} \leq \frac{1 - \mathcal{E}_{T-1}}{1 + \mathcal{E}_{T-1}} = \frac{l_T}{u_T} \end{equation}
即任意一个$T-1$步解的值域的相对大小,都不小于$T-1$步最优解的值域$[l_T, u_T]$的相对大小。接着我们有
\begin{equation}\begin{aligned}
\min_{f_T} \max_{x\in[l,u]} |f_T(\tilde{f}(x)) - 1| =&\, \min_{f_T} \max_{x\in[a,b]} |f_T(x) - 1| \\
=&\, \min_{f_T} \max_{x\in[a/b,1]} |f_T(x) - 1| \\
\geq &\, \min_{f_T} \max_{x\in[l_T/u_T,1]} |f_T(x) - 1| \\
=&\, \min_{f_T} \max_{x\in[l_T,u_T]} |f_T(x) - 1| \\
=&\, \mathcal{E}_T
\end{aligned}\end{equation}
也就是说,你随便拿一个别的$T-1$步解来,最大误差也顶多只能跟贪心解一样小,所以贪心解的最大误差已经是全局最优的,这就完成了递归证明。上式的关键一步是
\begin{equation}\min_{f_T} \max_{x\in[a,b]} |f_T(x) - 1| = \min_{f_T} \max_{x\in[a/b,1]} |f_T(x) - 1|\end{equation}
这是因为我们总可以设$g_T(y) = f_T(b y)$,$g_T$依然能代表任意同阶的奇多项式,所以$g_T$和$f_T$都在同一函数空间中,因此记号可以换用,即
\begin{equation}\min_{f_T}\max_{x\in[a,b]} |f_T(x) - 1| = \min_{g_T}\max_{y\in[a/b,1]} |g_T(y) - 1|= \min_{f_T}\max_{x\in[a/b,1]} |f_T(x) - 1|\end{equation}
文章小结 #
本文介绍了为msign算子寻找更好的Newton-Schulz迭代的最新进展,它通过等值振荡定理和贪心转换,直接求出理论上的最优解,整个过程相当硬核,值得学习一波。
转载到请包括本文地址:https://kexue.fm/archives/10996
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Jun. 05, 2025). 《msign算子的Newton-Schulz迭代(下) 》[Blog post]. Retrieved from https://kexue.fm/archives/10996
@online{kexuefm-10996,
title={msign算子的Newton-Schulz迭代(下)},
author={苏剑林},
year={2025},
month={Jun},
url={\url{https://kexue.fm/archives/10996}},
}
June 6th, 2025
牛啊
June 13th, 2025
此外,由于目标函数是1,所以$f^∗_t(x)$在 $x$ 处的斜率大于零,所以$l_t$ 只能是$f^∗_t(x)$的最小值点.
此处不太理解,可否详细一点?多谢!
写漏了,是$f^∗_t(x)$在 $x=0$ 处的斜率大于零,如果$l_t$是最大值点,那么它左边是增的,右边是减的,就只能占用一个$f^∗_t(x)$的极值点(驻点)名额了,可是如果边界点都占用了一个极值点名额的话,那么就凑不够$n+2$个最大值点了。
针对$n=2$情况,有没有可能在$l_t$处是最大值$1+\epsilon$ ($\epsilon\ge0$),各个$x_k$ (where $k=1,2$) 依次是$1-\epsilon$,$1+\epsilon$ 最后在$u_t$处是最小值$1-\epsilon$?似乎也是可以的?
如果$l_t$处是最大值,那么$l_t$右侧就是下降对吧?但$x=0$处是单调递增的,这意味着$l_t$左侧要不已经出现了一个极大值点,要不然$l_t$本身就是一个极大值点,不管是哪种情况,都需要“占用”掉一个极值点名额,导致凑不够$n+2$个误差最大的点。
关键:对于光滑函数来说,单调递增和递减之间,必然有一个极大值点。
我可能错过了最重要的点,为啥在x=0处是单调递增的?
我们的目标是1,$f^∗_t(0)=0$,如果$x=0$处还单调递减,而$x=l_t$处大于零,那么不就表明$(0,l_t)$区间有一个极小值点,这同样要占去一个极值名额。