重温SSM(一):线性系统和HiPPO矩阵
By 苏剑林 | 2024-05-24 | 41530位读者 |前几天,笔者看了几篇介绍SSM(State Space Model)的文章,才发现原来自己从未认真了解过SSM,于是打算认真去学习一下SSM的相关内容,顺便开了这个新坑,记录一下学习所得。
SSM的概念由来已久,但这里我们特指深度学习中的SSM,一般认为其开篇之作是2021年的S4,不算太老,而SSM最新最火的变体大概是去年的Mamba。当然,当我们谈到SSM时,也可能泛指一切线性RNN模型,这样RWKV、RetNet还有此前我们在《Google新作试图“复活”RNN:RNN能否再次辉煌?》介绍过的LRU都可以归入此类。不少SSM变体致力于成为Transformer的竞争者,尽管笔者并不认为有完全替代的可能性,但SSM本身优雅的数学性质也值得学习一番。
尽管我们说SSM起源于S4,但在S4之前,SSM有一篇非常强大的奠基之作《HiPPO: Recurrent Memory with Optimal Polynomial Projections》(简称HiPPO),所以本文从HiPPO开始说起。
基本形式 #
先插句题外话,上面提到的SSM代表作HiPPO、S4、Mamba的一作都是Albert Gu,他还有很多篇SSM相关的作品,毫不夸张地说,这些工作筑起了SSM大厦的基础。不论SSM前景如何,这种坚持不懈地钻研同一个课题的精神都值得我们由衷地敬佩。
言归正传。对于事先已经对SSM有所了解的读者,想必知道SSM建模所用的是线性ODE系统:
\begin{equation}\begin{aligned}
x'(t) =&\, A x(t) + B u(t) \\
y(t) =&\, C x(t) + D u(t)
\end{aligned}\label{eq:ode}\end{equation}
其中$u(t)\in\mathbb{R}^{d_i}, x(t)\in\mathbb{R}^{d}, y(t)\in\mathbb{R}^{d_o}, A\in\mathbb{R}^{d\times d}, B\in\mathbb{R}^{d\times d_i}, C\in\mathbb{R}^{d_o\times d}, D\in\mathbb{R}^{d_o\times d_i}
$。当然我们也可以将它离散化,那么就变成一个线性RNN模型,这部分我们在后面的文章再展开。不管离散化与否,其关键词都是“线性”,那么马上就有一个很自然的问题:为什么是线性系统?线性系统够了吗?
我们可以从两个角度回答这个问题:线性系统既足够简单,也足够复杂。简单是指从理论上来说,线性化往往是复杂系统的一个最基本近似,所以线性系统通常都是无法绕开的一个基本点;复杂是指即便如此简单的系统,也可以拟合异常复杂的函数,为了理解这一点,我们只需要考虑一个$\mathbb{R}^4$的简单例子:
\begin{equation}
x'(t) =\begin{pmatrix} 1 & 0 & 0 & 0 \\
0 & -1 & 0 & 0 \\
0 & 0 & 0 & 1 \\
0 & 0 & -1 & 0
\end{pmatrix}x(t)\end{equation}
这个例子的基本解是$x(t) = (e^t, e^{-t}, \sin t, \cos t)$。这意味着什么呢?意味着只要$d$足够大,该线性系统就可以通过指数函数和三角函数的组合来拟合足够复杂的函数,而我们知道拟合能力很强的傅里叶级数也只不过是三角函数的组合,如果在加上指数函数显然就更强了,因此可以想象线性系统也有足够复杂的拟合能力。
当然,这些解释某种意义上都是“马后炮”。HiPPO给出的结果更加本质:当我们试图用正交基去逼近一个动态更新的函数时,其结果就是如上的线性系统。这意味着,HiPPO不仅告诉我们线性系统可以逼近足够复杂的函数,还告诉我们怎么去逼近,甚至近似程度如何。
有限压缩 #
接下来,我们都只考虑$d_i=1$的特殊情形,$d_i > 1$只不过是$d_i=1$时的平行推广。此时,$u(t)$的输出是一个标量,进一步地,作为开头我们先假设$t\in[0, 1]$,HiPPO的目标是:用一个有限维的向量来储存这一段$u(t)$的信息。
看上去这是一个不大可能的需求,因为$t\in[0,1]$意味着$u(t)$可能相当于无限个点组成的向量,压缩到一个有限维的向量可能严重失真。不过,如果我们对$u(t)$做一些假设,并且允许一些损失,那么这个压缩是有可能做到的,并且大多数读者都已经尝试过。比如,当$u(t)$在某点$n+1$阶可导的,它对应的$n$阶泰勒展开式往往是$u(t)$的良好近似,于是我们可以只储存展开式的$n+1$个系数来作为$u(t)$的近似表征,这就成功将$u(t)$压缩为一个$n+1$维向量。
当然,对于实际遇到的数据来说,“$n+1$阶可导”这种条件可谓极其苛刻,我们通常更愿意使用在平方可积条件下的正交函数基展开,比如傅里叶(Fourier)级数,它的系数计算公式为
\begin{equation}c_n = \int_0^1 u(t) e^{-2i\pi n t}dt \label{eq:fourier-coef-1}\end{equation}
这时候取一个足够大的整数$N$,只保留$|n|\leq N$的系数,那么就将$u(t)$压缩为一个$2N + 1$维的向量了。
接下来,问题难度就要升级了。刚才我们说$t\in[0,1]$,这是一个静态的区间,而实际中$u(t)$代表的是持续采集的信号,所以它是不断有新数据进入的,比如现在我们近似了$[0,1]$区间的数据,马上就有$[1,2]$的数据进来,你需要更新逼近结果来试图记忆整个$[0,2]$区间,接下来是$[0,3]$、$[0,4]$等等,这我们称为“在线函数逼近”。而上面的傅里叶系数公式$\eqref{eq:fourier-coef-1}$,只适用于区间$[0,1]$,因此需要将它进行推广。
为此,我们设$t\in[0,T]$,$s\mapsto t_{\leq T}(s)$是$[0,1]$到$[0,T]$的一个映射,那么$u(t_{\leq T}(s))$作为$s$的函数时,它的定义区间就是$[0,1]$,于是就可以复用式$\eqref{eq:fourier-coef-1}$:
\begin{equation}c_n(T) = \int_0^1 u(t_{\leq T}(s)) e^{-2i\pi n s}ds \label{eq:fourier-coef-2}\end{equation}
这里我们已经给系数加了标记$(T)$,以表明此时的系数会随着$T$的变化而变化。
线性初现 #
能将$[0,1]$映射到$[0,T]$的函数有无穷多,而最终结果也因$t_{\leq T}(s)$而异,一些比较直观且相对简单的选择如下:
1、$t_{\leq T}(s) = sT$,即将$[0,1]$均匀地映射到$[0,T]$;
2、注意$t_{\leq T}(s)$并不必须是满射,所以像$t_{\leq T}(s)=s + T - 1$也是允许的,这意味着只保留了最邻近窗口$[T-1,T]$的信息,丢掉了更早的部分,更一般地有$t_{\leq T}(s)=sw + T - w$,其中$w$是一个常数,这意味着$T-w$前的信息被丢掉了;
3、也可以选择非均匀映射,比如$t_{\leq T}(s) = T\sqrt{s}$,它同样是$[0,1]$到$[0,T]$的满射,但$s=1/4$时就映射到$T/2$了,这意味着我们虽然关注全局的历史,但同时更侧重于$T$时刻附近的信息。
现在我们以$t_{\leq T}(s)=sw + T - w$为例,代入式$\eqref{eq:fourier-coef-2}$得到
\begin{equation}c_n(T) = \int_0^1 u(sw + T - w) e^{-2i\pi n s}ds\end{equation}
现在我们两边求关于$T$的导数:
\begin{equation}\begin{aligned}
\frac{d}{dT}c_n(T) =&\, \int_0^1 u'(sw + T - w) e^{-2i\pi n s}ds \\
=&\, \left.\frac{1}{w} u(sw + T - w) e^{-2i\pi n s}\right|_{s=0}^{s=1} + \frac{2i\pi n}{w}\int_0^1 u(sw + T - w) e^{-2i\pi n s}ds \\
=&\, \frac{1}{w} u(T) - \frac{1}{w} u(T-w) + \frac{2i\pi n}{w} c_n(T) \\
\end{aligned}\label{eq:fourier-dc}\end{equation}
其中第二个等号我们用了分部积分公式。由于我们只保留了$|n|\leq N$的系数,所以根据傅立叶级数的公式,可以认为如下是$u(sw + T - w)$的一个良好近似:
\begin{equation}u(sw + T - w) \approx \sum_{k=-N}^{k=N} c_k(T) e^{2i\pi k s}ds\end{equation}
那么$u(T - w) = u(sw + T - w)|_{s=0}\approx \sum\limits_{k=-N}^{k=N} c_k(T)$,代入式$\eqref{eq:fourier-dc}$得:
\begin{equation}\frac{d}{dT}c_n(T) \approx \frac{1}{w} u(T) - \frac{1}{w} \sum_{k=-N}^{k=N} c_k(T) + \frac{2i\pi n}{w} c_n(T)\end{equation}
将$T$换成$t$,然后所有的$c_n(t)$堆在一起记为$x(t) = (c_{-N},c_{-(N-1)},\cdots,c_0,\cdots,c_{N-1},c_N)$,并且不区分$\approx$和$=$,那么就可以写出
\begin{equation}x'(t) = Ax(t) + Bu(t),\quad A_{n,k} = \left\{\begin{array}{l}(2i\pi n - 1)/w, &k=n \\ -1/w,&k\neq n\end{array}\right.,\quad B_n = 1/w\end{equation}
这就出现了如式$\eqref{eq:ode}$所示的线性ODE系统。即当我们试图用傅里叶级数去记忆一个实时函数的最邻近窗口内的状态时,结果自然而言地导致了一个线性ODE系统。
一般框架 #
当然,目前只是选择了一个特殊的$t_{\leq T}(s)$,换一个$t_{\leq T}(s)$就不一定有这么简单的结果了。此外,傅里叶级数的结论是在复数范围内的,进一步实数化也可以,但形式会变得复杂起来。所以,我们要将上一节的过程推广成一个一般化的框架,从而得到更一般、更简单的纯实数结论。
设$t\in[a,b]$,并且有目标函数$u(t)$和函数基$\{g_n(t)\}_{n=0}^N$,我们希望有后者的线性组合来逼近前者,目标是最小化$L_2$距离:
\begin{equation}\mathop{\text{argmin}}_{c_1,\cdots,c_N}\int_a^b \left[u(t) - \sum_{n=0}^N c_n g_n(t)\right]^2 dt\end{equation}
这里我们主要在实数范围内考虑,所以方括号直接平方就行,不用取模。更一般化的目标函数还可以再加个权重函数$\rho(t)$,但我们这里就不考虑了,毕竟HiPPO的主要结论其实也没考虑这个权重函数。
对目标函数展开,得到
\begin{equation}\int_a^b u^2(t) dt - 2\sum_{n=0}^N c_n \int_a^b u(t) g_n(t)dt + \sum_{m=0}^N\sum_{n=0}^N c_m c_n \int_a^b g_m(t) g_n(t) dt\end{equation}
这里我们只考虑标准正交函数基,其定义为$\int_a^b g_m(t) g_n(t) dt = \delta_{m,n}$,$\delta_{m,n}$是克罗内克δ函数,此时上式可以简化成
\begin{equation}\int_a^b u^2(t) dt - 2\sum_{n=0}^N c_n \int_a^b u(t) g_n(t)dt + \sum_{n=0}^N c_n^2 \end{equation}
这只是一个关于$c_n$的二次函数,它的最小值是有解析解的:
\begin{equation}c^*_n = \int_a^b u(t) g_n(t)dt\end{equation}
这也被称为$u(t)$与$g_n(t)$的内积,它是有限维向量空间的内积到函数空间的平行推广。简单起见,在不至于混淆的情况下,我们默认$c_n$就是$c^*_n$。
接下来的处理跟上一节是一样的,我们要对一般的$t\in[0, T]$考虑$u(t)$的近似,那么找一个$[a,b]$到$[0,T]$的映射$s\mapsto t_{\leq T}(s)$,然后计算系数
\begin{equation}c_n(T) = \int_a^b u(t_{\leq T}(s)) g_n(s) ds\end{equation}
同样是两边求$T$的导数,然后用分部积分法
\begin{equation}\scriptsize\begin{aligned}
\frac{d}{dT}c_n(T) =&\, \int_a^b u'(t_{\leq T}(s)) \frac{\partial t_{\leq T}(s)}{\partial T} g_n(s) ds = \int_a^b \left(\frac{\partial t_{\leq T}(s)}{\partial T}\left/\frac{\partial t_{\leq T}(s)}{\partial s}\right.\right) g_n(s) d u(t_{\leq T}(s)) \\
=&\,\left.u(t_{\leq T}(s))\left(\frac{\partial t_{\leq T}(s)}{\partial T}\left/\frac{\partial t_{\leq T}(s)}{\partial s}\right.\right) g_n(s)\right|_{s=a}^{s=b} - \int_a^b u(t_{\leq T}(s)) \,d\left[\left(\frac{\partial t_{\leq T}(s)}{\partial T}\left/\frac{\partial t_{\leq T}(s)}{\partial s}\right.\right) g_n(s)\right]
\end{aligned}\label{eq:hippo-base}\end{equation}
请勒让德 #
接下来的计算,就依赖于$g_n(t)$和$t_{\leq T}(s)$的具体形式了。HiPPO的全称是High-order Polynomial Projection Operators,第一个P正是多项式(Polynomial)的首字母,所以HiPPO的关键是选取多项式为基。现在我们请出继傅里叶之后又一位大牛——勒让德(Legendre),接下来我们要选取的函数基正是以他命名的“勒让德多项式”。
勒让德多项式$p_n(t)$是关于$t$的$n$次函数,定义域为$[-1,1]$,满足
\begin{equation}\int_{-1}^1 p_m(t) p_n(t) dt = \frac{2}{2n+1}\delta_{m,n}\end{equation}
所以$p_n(t)$之间只是正交,还不是标准(平分积分为1),$g_n(t)=\sqrt{\frac{2n+1}{2}} p_n(t)$才是标准正交基。
当我们对函数基$\{1,t,t^2,\cdots, t^n\}$执行施密特正交化时,其结果正是勒让德多项式。相比傅里叶基,勒让德多项式的好处是它是纯粹定义在实数空间中的,并且多项式的形式能够有助于简化部分$t_{\leq T}(s)$的推导过程,这一点我们后面就可以看到。勒让德多项式有很多不同的定义和性质,这里我们不一一展开,有兴趣的读者自行看链接中维基百科介绍即可。
接下来我们用到两个递归公式来推导一个恒等式,这两个递归公式是
\begin{align}
p_{n+1}'(t) - p_{n-1}'(t) = (2n+1)p_n(t) \label{eq:leg-r1}\\[5pt]
p_{n+1}'(t) = (n + 1)p_n(t) + t p_n'(t) \label{eq:leg-r2}\\
\end{align}
由第一个公式$\eqref{eq:leg-r1}$迭代得到:
\begin{equation}\begin{aligned}
p_{n+1}'(t) =&\, (2n+1)p_n(t) + (2n-3)p_{n-2}(t) + (2n-7)p_{n-4}(t) + \cdots \\
=&\, \sum_{k=0}^n (2k+1) \chi_{n-k} p_k(t)
\end{aligned}\label{eq:leg-dot}\end{equation}
其中当$k$是偶数时$\chi_k=1$否则$\chi_k=0$。代入第二个公式$\eqref{eq:leg-r2}$得到
\begin{equation}t p_n'(t) = n p_n(t) + (2n-3)p_{n-2}(t) + (2n-7)p_{n-4}(t) + \cdots\end{equation}
继而有
\begin{equation}\begin{aligned}
(t+1) p_n'(t) =&\, n p_n(t) + (2n-1)p_{n-1}(t) + (2n-3)p_{n-2}(t) + \cdots\\
=&\,-(n+1) p_n(t) + \sum_{k=0}^n (2k + 1) p_k(t)
\end{aligned}\label{eq:leg-dot-t1}\end{equation}
这些就是等会要用到的恒等式。此外,勒让德多项式满足$p_n(1)=1,p_n(-1)=(-1)^n$,这个边界值后面也会用到。
正如$n$维空间中不止有一组正交基也一样,正交多项式也不止有勒让德多项式一种,比如还有切比雪夫(Chebyshev)多项式,如果算上加权的目标函数(即$\rho(t)\not\equiv 1$),还有拉盖尔多项式等,这些在原论文中都有提及,但HiPPO的主要结论还是基于勒让德多项式展开的,所以剩余部分这里也不展开讨论了。
邻近窗口 #
完成准备工作后,我们就可以代入具体的$t_{\leq T}(s)$进行计算了,计算过程跟傅里叶级数的例子大同小异,只不过基函数换成了勒让德多项式构造的标准正交基$g_n(t)=\sqrt{\frac{2n+1}{2}} p_n(t)$。作为第一个例子,我们同样先考虑只保留最邻近窗口的信息,此时$t_{\leq T}(s) = (s + 1)w / 2 + T - w$将$[-1,1]$映射到$[T-w,T]$,原论文将这种情形称为“LegT(Translated Legendre)”。
直接代入式$\eqref{eq:hippo-base}$,马上得到
\begin{equation}\small\frac{d}{dT}c_n(T) = \frac{\sqrt{2(2n+1)}}{w}\left[u(T) - (-1)^n u(T-w)\right] - \frac{2}{w}\int_{-1}^1 u((s + 1)w / 2 + T - w) g_n'(s) ds\end{equation}
我们首先处理$u(T-w)$项,跟傅里叶级数那里同样的思路,我们截断$n\leq N$作为$u((s + 1)w / 2 + T - w)$的一个近似:
\begin{equation}u((s + 1)w / 2 + T - w)\approx \sum_{k=0}^N c_k(T)g_k(s)\end{equation}
从而有$u(T-w)\approx \sum\limits_{k=0}^N c_k(T)g_k(-1) = \sum\limits_{k=0}^N (-1)^k c_k(T) \sqrt{\frac{2k+1}{2}}$。接着,利用式$\eqref{eq:leg-dot}$得到
\begin{equation}\begin{aligned}
&\,\int_{-1}^1 u((s + 1)w / 2 + T - w) g_n'(s) ds \\
=&\,\int_{-1}^1 u((s + 1)w / 2 + T - w) \sqrt{\frac{2n+1}{2}} p_n'(s) ds \\
=&\, \int_{-1}^1 u((s + 1)w / 2 + T - w)\sqrt{\frac{2n+1}{2}}\left[\sum_{k=0}^{n-1} (2k+1) \chi_{n-1-k} p_k(s)\right]ds \\
=&\, \int_{-1}^1 u((s + 1)w / 2 + T - w)\sqrt{\frac{2n+1}{2}}\left[\sum_{k=0}^{n-1} \sqrt{2(2k+1)} \chi_{n-1-k} g_k(s)\right]ds \\
=&\, \sqrt{2n+1}\sum_{k=0}^{n-1} \sqrt{2k+1} \chi_{n-1-k} c_k(T)
\end{aligned}\end{equation}
将这些结果整合起来,就有
\begin{equation}\begin{aligned}
\frac{d}{dT}c_n(T) \approx &\, \frac{\sqrt{2(2n+1)}}{w}u(T) - \frac{\sqrt{2(2n+1)}}{w} (-1)^n \overbrace{\sum\limits_{k=0}^N (-1)^k c_k(T) \sqrt{\frac{2k+1}{2}}}^{u(T-w)} \\
&\quad- \frac{2}{w}\overbrace{\sqrt{2n+1}\sum_{k=0}^{n-1} \sqrt{2k+1} \chi_{n-1-k} c_k(T)}^{\int_{-1}^1 u((s + 1)w / 2 + T - w) g_n'(s) ds} \\[12pt]
= &\, \frac{\sqrt{2(2n+1)}}{w}u(T) - \frac{\sqrt{2n+1}}{w} \sum\limits_{k=0}^N (-1)^{n-k} c_k(T) \sqrt{2k+1} \\
&\quad- \frac{2}{w}\sqrt{2n+1}\sum_{k=0}^{n-1} \sqrt{2k+1} \chi_{n-1-k} c_k(T) \\[12pt]
= &\, \frac{\sqrt{2(2n+1)}}{w}u(T) - \frac{\sqrt{2n+1}}{w} \sum\limits_{k=n}^N (-1)^{n-k} c_k(T) \sqrt{2k+1} \\
&\quad- \frac{\sqrt{2n+1}}{w}\sum_{k=0}^{n-1} \sqrt{2k+1} \underbrace{\left(2\chi_{n-1-k} + (-1)^{n-k}\right)}_{\equiv 1}c_k(T) \\
\end{aligned}\label{eq:leg-t}\end{equation}
再次地,将$T$换回$t$,并将所有的$c_n(t)$堆在一起记为$x(t) = (c_0,c_1,\cdots,c_N)$,那么根据上式可以写出
\begin{equation}\begin{aligned}
x'(t) =&\, Ax(t) + Bu(t)\\[8pt]
\quad A_{n,k} =&\, -\frac{1}{w}\left\{\begin{array}{l}\sqrt{(2n+1)(2k+1)}, &k < n \\ (-1)^{n-k}\sqrt{(2n+1)(2k+1)}, &k \geq n\end{array}\right.\\[8pt]
B_n =&\, \frac{1}{w}\sqrt{2(2n+1)}
\end{aligned}\label{eq:leg-t-hippo-1}\end{equation}
我们还可以给每个$c_n(T)$都引入一个缩放因子,来使得上述结果更一般化。比如我们设$c_n(T) = \lambda_n \tilde{c}_n(T)$,代入式$\eqref{eq:leg-t}$整理得
\begin{equation}\begin{aligned}
\frac{d}{dt}\tilde{c}_n(T) \approx &\, \frac{\sqrt{2(2n+1)}}{w\lambda_n}u(T) - \frac{\sqrt{2n+1}}{w} \sum\limits_{k=n}^N (-1)^{n-k} \tilde{c}_k(T) \frac{\lambda_k\sqrt{2k+1}}{\lambda_n} \\
&\quad- \frac{\sqrt{2n+1}}{w}\sum_{k=0}^{n-1} \frac{\lambda_k\sqrt{2k+1}}{\lambda_n} \tilde{c}_k(T) \\
\end{aligned}\end{equation}
如果取$\lambda_n = \sqrt{2}$,那么$A$不变,$B_n = \frac{1}{w}\sqrt{2n+1}$,这就对齐了原论文的结果,如果取$\lambda_n = \frac{2}{\sqrt{2n+1}}$,那么就得到了Legendre Memory Units中的结果
\begin{equation}\begin{aligned}
x'(t) =&\, Ax(t) + Bu(t)\\[8pt]
\quad A_{n,k} =&\, -\frac{1}{w}\left\{\begin{array}{l}2n+1, &k < n \\ (-1)^{n-k}(2n+1), &k \geq n\end{array}\right.\\[8pt]
B_n =&\, \frac{1}{w}(2n+1)
\end{aligned}\label{eq:leg-t-hippo-2}\end{equation}
这些形式在理论上都是等价的,但可能存在不同的数值稳定性。比如一般来说当$u(t)$的性态不是特别糟糕时,我们可以预期$n$越大,$|c_n|$的值就相对越小,这样直接用$c_n$的话$x(t)$向量的每个分量的尺度就不大对等,这样的系统在实际计算时容易出现数值稳定问题,而取$\lambda_n = \frac{2}{\sqrt{2n+1}}$改用$\tilde{c}_n$的话意味着数值小的分量会被适当放大,可能有助于缓解多尺度问题从而使得数值计算更稳定。
整个区间 #
现在我们继续计算另一个例子:$t_{\leq T}(s) = (s + 1)T / 2$,它将$[-1,1]$均匀映射到$[0,T]$,这意味着我们没有舍弃任何历史信息,并且平等地对待所有历史,原论文将这种情形称为“LegS(Scaled Legendre)”。
同样地,通过代入式$\eqref{eq:hippo-base}$得到
\begin{equation}\frac{d}{dT}c_n(T) = \frac{\sqrt{2(2n+1)}}{T}u(T) - \frac{1}{T}\int_{-1}^1 u((s + 1)T / 2) \left[g_n(s) + (s+1) g_n'(s)\right] ds\end{equation}
利用公式$\eqref{eq:leg-dot-t1}$得到
\begin{equation}\begin{aligned}
&\,\int_{-1}^1 u((s + 1)T / 2) \left[g_n(s) + (s+1) g_n'(s)\right] ds \\
=&\,c_n(T) + \int_{-1}^1 u((s + 1)T / 2) (s+1) g_n'(s) ds \\
=&\, c_n(T) + \int_{-1}^1 u((s + 1)T / 2)(s+1) \sqrt{\frac{2n+1}{2}} p_n'(s) \\
=&\, c_n(T) + \int_{-1}^1 u((s + 1)T / 2)\sqrt{\frac{2n+1}{2}}\left[-(n+1) p_n(s) + \sum_{k=0}^n (2k + 1) p_k(s)\right] ds \\
=&\, c_n(T) + \int_{-1}^1 u((s + 1)T / 2)\left[-(n+1) g_n(s) + \sum_{k=0}^n \sqrt{(2n+1)(2k + 1)} g_k(s)\right] ds \\
=&\, -n c_n(T) + \sum_{k=0}^n \sqrt{(2n+1)(2k + 1)} c_k(T) \\
\end{aligned}\end{equation}
于是有
\begin{equation}\frac{d}{dT}c_n(T) = \frac{\sqrt{2(2n+1)}}{T}u(T) - \frac{1}{T}\left(-n c_n(T) + \sum_{k=0}^n \sqrt{(2n+1)(2k + 1)} c_k(T)\right)\label{eq:leg-s}\end{equation}
将$T$换回$t$,将所有的$c_n(t)$堆在一起记为$x(t) = (c_0,c_1,\cdots,c_N)$,那么根据上式可以写出
\begin{equation}\begin{aligned}
x'(t) =&\, \frac{A}{t}x(t) + \frac{B}{t}u(t)\\[8pt]
\quad A_{n,k} =&\, -\left\{\begin{array}{l}\sqrt{(2n+1)(2k+1)}, &k < n \\ n+1, &k = n \\
0, &k > n\end{array}\right.\\[8pt]
B_n =&\, \sqrt{2(2n+1)}
\end{aligned}\label{eq:leg-s-hippo}\end{equation}
引入缩放因子来一般化结果也是可行的:设$c_n(T) = \lambda_n \tilde{c}_n(T)$,代入式$\eqref{eq:leg-t}$整理得
\begin{equation}\frac{d}{dT}\tilde{c}_n(T) = \frac{\sqrt{2(2n+1)}}{T\lambda_n}u(T) - \frac{1}{T}\left(-n \tilde{c}_n(T) + \sum_{k=0}^n \frac{\sqrt{(2n+1)(2k + 1)}\lambda_k}{\lambda_n} \tilde{c}_k(T)\right)\end{equation}
取$\lambda_n=\sqrt{2}$就可以让$A$不变,$B$变为$B_n = \sqrt{2n+1}$,就对齐了原论文的结果。如果取$\lambda_n=\sqrt{\frac{2}{2n+1}}$,就可以像上一节LegT的结果一样去掉根号
\begin{equation}\begin{aligned}
x'(t) =&\, \frac{A}{t}x(t) + \frac{B}{t}u(t)\\[8pt]
\quad A_{n,k} =&\, -\left\{\begin{array}{l}2n+1, &k < n \\ n+1, &k = n \\
0, &k > n\end{array}\right.\\[8pt]
B_n =&\, 2n+1
\end{aligned}\label{eq:leg-s-hippo-2}\end{equation}
但原论文没有考虑这种情况,原因不详。
延伸思考 #
回顾Leg-S的整个推导,我们可以发现其中关键一步是将$(s+1) g_n'(s)$拆成$g_0(s),g_1(s),\cdots,g_n(s)$的线性组合,对于正交多项式来说,$(s+1) g_n'(s)$是一个$n$次多项式,所以这种拆分必然可以精确成立,但如果是傅立叶级数的情况,$g_n(s)$是指数函数,此时类似的拆分做不到了,至少不能精确地做到,所以可以说选取正交多项式为基的根本目的是简化后面推导。
特别要指出的是,HiPPO是一个自下而上的框架,它并没有一开始就假设系统必须是线性的,而是从正交基逼近的角度反过来推出其系数的动力学满足一个线性ODE系统,这样一来我们就可以确信,只要认可所做的假设,那么线性ODE系统的能力就是足够的,而不用去担心线性系统的能力限制了你的发挥。
当然,HiPPO对于每一个解所做的假设及其物理含义也很清晰,所以对于重用了HiPPO矩阵的SSM,它怎么储存历史、能储存多少历史,从背后的HiPPO假设就一清二楚。比如LegT就是只保留$w$大小的最邻近窗口信息,如果你用了LegT的HiPPO矩阵,那么就类似于一个Sliding Window Attention;而LegS理论上可以捕捉全部历史,但这有个分辨率问题,因为$x(t)$的维度代表了拟合的阶数,它是一个固定值,用同阶的函数基去拟合另一个函数,肯定是区间越小越准确,区间越大误差也越大,这就好比为了一次性看完一幅大图,那么我们必须站得更远,从而看到的细节越少。
诸如RWKV、LRU等模型,并没有重用HiPPO矩阵,而是改为可训练的矩阵,原则上具有更多的可能性来突破瓶颈,但从前面的分析大致上可以感知到,不同矩阵的线性ODE只是函数基不同,但本质上可能都只是有限阶函数基逼近的系数动力学。既然如此,分辨率与记忆长度就依然不可兼得,想要记忆更长的输入并且保持效果不变,那就只能增加整个模型的体量(即相当于增加hidden_size),这大概是所有线性系统的特性。
文章小结 #
本文以尽可能简单的方式重复了《HiPPO: Recurrent Memory with Optimal Polynomial Projections》(简称HiPPO)的主要推导。HiPPO通过适当的记忆假设,自下而上地导出了线性ODE系统,并且针对勒让德多项式的情形求出了相应的解析解(HiPPO矩阵),其结果被后来诸多SSM(State Space Model)使用,可谓是SSM的重要奠基之作。
转载到请包括本文地址:https://kexue.fm/archives/10114
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (May. 24, 2024). 《重温SSM(一):线性系统和HiPPO矩阵 》[Blog post]. Retrieved from https://kexue.fm/archives/10114
@online{kexuefm-10114,
title={重温SSM(一):线性系统和HiPPO矩阵},
author={苏剑林},
year={2024},
month={May},
url={\url{https://kexue.fm/archives/10114}},
}
June 19th, 2024
感谢苏神的分享!
有一个疑问,(8)是如何推导到9的呢?什么叫做堆在一起呢?
感谢解答!
就是定义$x(t) = (c_{-N},c_{-(N-1)},\cdots,c_0,\cdots,c_{N-1},c_N)$的意思。
July 1st, 2024
公式6,第一行,u' 应该是关于 T 的导数吧,第二行应用分布积分时, u'的含义又变成了关于s的导数,这块没懂,求苏神解答
一般情况下,$f'()$是指对括号内的整体求导,比如$f'(g(x)) = \frac{df(g(x))}{dg(x)}$。
至于“第二行应用分布积分时, u'的含义又变成了关于s的导数”这是不存在的。
July 2nd, 2024
搞懂了,u'时关于T的导数,但同时也是关于s的函数,其原函数为1/w * u(sw+t-w)
@苏剑林|comment-24660
没太理解,等式两边都是关于T求导。为什么后面的u是关于整个sw+t-w的导数呢?
August 12th, 2024
校友你好,有个疑惑,原paper中在[0,T]上的逼近是直接通过[0,T]上的正交基得到,但是需要有定义一个测度和相应的权重函数w。在第26页底部,需要对w求关于t的偏导,我截了图在这里:https://app.gemoo.com/share/image-annotation/681494723035095040?codeId=Pak7lVdLWOrxN&origin=screenshotuploader&card=681494722674384896
我不理解底下这个偏导的结果,感觉w在边界上(x=t, x=t-theta)甚至不是连续的,为什么会有这个偏导呢?这个偏导是常规意义的那种吗
你截图的权重函数是$\omega(t,x)=\frac{1}{\theta}\mathbb{I}_{[t-\theta,t]}$,这个权重函数理解为:被积函数乘以$\frac{1}{\theta}$,然后积分改为$\int_{t-\theta}^t$。即
$$\int \omega(t,x) f(x) dx = \frac{1}{\theta}\int_{t-\theta}^t f(x)dx$$
如果 $\omega(t, x)$ 直接出现在积分中,我确实是这么理解的。但是后面对$\frac{d}{dt}c_n(t)$的推导中,积分里的是 $\frac{\partial}{\partial t}\omega(t, x)$,可以看后面一页:https://app.gemoo.com/share/home?codeId=M01rVpnZxppnN
因此作者先给出 $\frac{\partial}{\partial t}\omega(t, x) = \frac{1}{\theta}\delta_t - \frac{1}{\theta}\delta_{t-\theta}$,这个偏导带入积分中就变成了在 $t$ 和 $t-\theta$ 这两点的point mass了。但是我无法理解为什么这个偏导是这样的,会不会在RS积分的意义下有什么特殊性。。。
根据
$$\int \omega(t,x) f(x) dx = \frac{1}{\theta}\int_{t-\theta}^t f(x)dx$$
我们有
$$\int \frac{\partial}{\partial t}\omega(t,x) f(x) dx = \frac{\partial}{\partial t}\int \omega(t,x) f(x) dx = \frac{1}{\theta}\frac{\partial}{\partial t}\int_{t-\theta}^t f(x)dx$$
假设$f(x)$的原函数是$F(x)$,那么
$$\int_{t-\theta}^t f(x)dx = F(t) - F(t-\theta)$$
所以
$$\frac{\partial}{\partial t}\int_{t-\theta}^t f(x)dx = F'(t) - F'(t-\theta) = f(t) - f(t-\theta)$$
即
$$\int \frac{\partial}{\partial t}\omega(t,x) f(x) dx = \frac{1}{\theta}[f(t) - f(t-\theta)] = \frac{1}{\theta}\int [\delta(x-t) - \delta(x-(t-\theta))]f(x)dx$$
这里的$\delta(x-t),\delta(x-(t-\theta))$是狄拉克delta函数,即前述的$\delta_t,\delta_{t-\theta}$,所以由此得
$$\frac{\partial}{\partial t}\omega(t,x) = \frac{1}{\theta}(\delta_t-\delta_{t-\theta})$$
这是比较严谨的求法。
简单一点也可以,注意到狄拉克delta函数的积分其实是阶跃函数:
$$\int_{-\infty}^t \delta(x) dx = H(t) = \left\{\begin{aligned}1,t\geq 0 \\ 0, t < 0\end{aligned}\right.$$
也就是阶跃函数的导数是狄拉克函数,而前述$\omega(t,x)$可以等价地写成
$$\omega(t,x) = \frac{1}{\theta}[H(x - (t-\theta)) + 1 - H(x - t)]$$
感谢苏神,我懂了。我这是补习了一下。。感觉这里的一个关键点/困惑点就在于狄拉克函数不是一个正常意义的函数,否则他的积分应该是0。但是在狄拉克函数是按照一种极限意义来定义的,所以他的积分是阶跃函数,既有等于1的部分。
嗯嗯,从普通函数的视角去看待狄拉克函数确实是这样子。
August 12th, 2024
式(6)需要拆成两步去看,注意分部积分要保证里面也是ds项,才能和外面的ds套。
第一步:先把dT操作掉:左式=du/dT=u'(sw+T-w)=u(sw+T-w),dT就没了。
第二步:dT没了再对ds用分部积分,此时左式u(sw+T-w)看作是1/w*u(sw+T-w)的导数,因此直接对s做分部积分就行了。
September 24th, 2024
苏老师您好,读了一些关于把RNN看作ODE的系统的一些文章模型之后,我有一个疑惑,在数学模型里面,我们估计一个和时间有关的函数的时候,首先都会对函数的光滑性质做一些假设,然后设计数值估算欢迎整个函数的方法。因为从数学转到像语言模型这一类问题,我不太理解,如果Hidden state$h_t$对应$x(t)$,在语言模型的序列研究里面,他们有类似函数连续性质,和ODE系统稳定性质的定义吗?跟学计算机的同学交流,我无法给他们解释两个如果$h(t)$是连续的,那么中间的是什么。
像SSM那样,基于ODE离散化来推导一个RNN,纯粹是用一个中间推导过程,当它用于离散对象如自然语言的建模时,建模完成后它就是一个纯粹的离散模型了,脱离了ODE原本的连续性质,所以也就不存在中间状态的问题。
November 28th, 2024
苏老师您好,能否请教一下最后那个An,k矩阵,这里的n和N具体是什么含义。我个人的理解是,n是多项式的最大阶数,N是指多项式的数量,但似乎不太对
具体来说就是,在式(26)中,我能够理解可以把所有的ck(t)堆在一起变成x(t),一共有N个ck(t)。但前面的导数不是cn(t)的导数吗?为什么可以直接写成x(t)的导数,在我的理解里,cn(t)的导数不应该只是x'(t)中的一部分么。
将所有$C_0(t),c_1(t),c_2(t),\cdots,c_N(t)$的导数堆在一起,不就是$x'(t)$了嘛
对于多项式基,$N+1$是基函数的个数(或者说$N$才是多项式最大阶数),也即RNN的state size,至于$n,k$只是一个序号,没有特殊含义。
November 28th, 2024
苏老师您好!我对这个地方有些疑问:“其中当k是偶数时χk=1,否则χk=0。”
如果依照公式,n选取为奇数,例如5,就会有
$$
p'_6(t)=11p_5(t)+7p_3(t)+\dots
$$
这里是否应该假设n也为偶数来凑成最终的式子?
另外在边界处还会剩余一个$p_1'$,这个应当借用公式消去
$$
p'_{n+1}(t)=(n+1)p_n(t)+tp_n'(t)\to p_1'(t)=p_1(t)+0
$$
没看懂你的疑问,式$\eqref{eq:leg-dot}$的成立是不区分奇偶的啊(即对奇偶的$n$都成立)。