书接上文,在上一篇文章《重温SSM(一):线性系统和HiPPO矩阵》中,我们详细讨论了HiPPO逼近框架其HiPPO矩阵的推导,其原理是通过正交函数基来动态地逼近一个实时更新的函数,其投影系数的动力学正好是一个线性系统,而如果以正交多项式为基,那么线性系统的核心矩阵我们可以解析地求解出来,该矩阵就称为HiPPO矩阵。

当然,上一篇文章侧重于HiPPO矩阵的推导,并没有对它的性质做进一步分析,此外诸如“如何离散化以应用于实际数据”、“除了多项式基外其他基是否也可以解析求解”等问题也没有详细讨论到。接下来我们将补充探讨相关问题。

离散格式 #

假设读者已经阅读并理解上一篇文章的内容,那么这里我们就不再进行过多的铺垫。在上一篇文章中,我们推导出了两类线性ODE系统,分别是:
\begin{align}
&\text{HiPPO-LegT:}\quad x'(t) = Ax(t) + Bu(t) \label{eq:legt-ode}\\[5pt]
&\text{HiPPO-LegS:}\quad x'(t) = \frac{A}{t}x(t) + \frac{B}{t}u(t) \label{eq:legs-ode}\end{align}
其中$A,B$是与时间$t$无关的常数矩阵,HiPPO矩阵主要指矩阵$A$。在这一节中,我们讨论这两个ODE的离散化。

输入转换 #

在实际场景中,输入的数据点是离散的序列$u_0,u_1,u_2,\cdots,u_k,\cdots$,比如流式输入的音频信号、文本向量等,我们希望用如上的ODE系统来实时记忆这些离散点。为此,我们先定义
\begin{equation}u(t) = u_k,\quad \text{如果} t\in[k\epsilon, (k + 1)\epsilon)\end{equation}
其中$\epsilon$就是离散化的步长。该定义也就是说在区间$[k\epsilon, (k + 1)\epsilon)$内,$u(t)$是一个常数函数,其值等于$u_k$。很明显这样定义出来的$u(t)$无损原本$u_k$序列的信息,因此记忆$u(t)$就相当于记忆$u_k$序列。

从$u_k$变换到$u(t)$,可以使得输入信号重新变回连续区间上的函数,方便后面进行积分等运算,此外在离散化的区间内保持为常数,也能够简化离散化后的格式。

LegT版本 #

我们先以LegT型ODE$\eqref{eq:legt-ode}$为例,将它两端积分
\begin{equation}x(t+\epsilon) - x(t) = A\int_t^{t+\epsilon} x(s)ds + B\int_t^{t+\epsilon}u(s)ds\end{equation}
其中$t=k\epsilon$。根据$u(t)$的定义,它在$[t, t + \epsilon)$区间内恒为$u_k$,于是$u(s)$的积分可以直接算出来:
\begin{equation}x(t+\epsilon) - x(t) = A\int_t^{t+\epsilon} x(s)ds + \epsilon B u_k\end{equation}
接下来的结果,就取决于我们如何近似$x(s)$的积分了。假如我们认为在$[t, t + \epsilon)$区间内$x(s)$近似恒等于$x(t)$,那么就得到前向欧拉格式
\begin{equation}x(t+\epsilon) - x(t) = \epsilon A x(t) + \epsilon B u_k \quad\Rightarrow\quad x(t+\epsilon) = (I + \epsilon A)x(t) + \epsilon B u_k\end{equation}
我们认为在$[t, t + \epsilon)$区间内$x(s)$近似恒等于$x(t+\epsilon)$,那么就得到后向欧拉格式
\begin{equation}x(t+\epsilon) - x(t) = \epsilon A x(t+\epsilon) + \epsilon B u_k \quad\Rightarrow\quad x(t+\epsilon) = (I - \epsilon A)^{-1}(x(t) + \epsilon B u_k)\end{equation}
前后向欧拉都具有相同的理论精度,但后向通常会有更好的数值稳定性。如果要更准确一些,那么认为在$[t, t + \epsilon)$区间内$x(s)$近似恒等于$\frac{1}{2}[x(t) + x(t+\epsilon)]$,那么得到双线性形式:
\begin{equation}\begin{gathered}
x(t+\epsilon) - x(t) = \frac{1}{2}\epsilon A [x(t) + x(t+\epsilon)] + \epsilon B u_k \\
\Downarrow \\
x(t+\epsilon) = (I - \epsilon A/2)^{-1}[(I + \epsilon A/2) x(t) + \epsilon B u_k]
\end{gathered}\end{equation}
这也等价于先用前向欧拉走半步,再用后向欧拉走半步。更一般地,我们还可以认为在$[t, t + \epsilon)$区间内$x(s)$近似恒等于$\alpha x(t) + (1 - \alpha) x(t+\epsilon)$,其中$\alpha\in[0,1]$,这就不进一步展开了。事实上,我们也可以完全不做近似,因为结合式$\eqref{eq:legt-ode}$以及在区间$[t,t+\epsilon)$中$u(s)$是常数$u_k$,我们完全可以用“常数变易法”来精确求解出来,结果是
\begin{equation}x(t+\epsilon) = e^{\epsilon A} x(t) + A^{-1} (e^{\epsilon A} - I) B u_k\label{eq:legt-ode-sol}\end{equation}
这里的矩阵指数按照级数来定义,可以参考《恒等式 det(exp(A)) = exp(Tr(A)) 赏析》

LegS版本 #

现在轮到LegS型ODE了,它的思路跟LegT型基本一致,结果也大同小异。首先将式$\eqref{eq:legs-ode}$两端积分得到
\begin{equation}x(t+\epsilon) - x(t) = A\int_t^{t+\epsilon} \frac{x(s)}{s}ds + B\int_t^{t+\epsilon}\frac{u(s)}{s}ds\end{equation}
根据$u(t)$定义,第二项积分的$u(s)$在$[t,t+\epsilon)$恒为$u_k$,所以它相当于$1/s$的积分,可以直接积分出来得$\ln\frac{t+\epsilon}{t}$,当然直接换为一阶近似$\frac{\epsilon}{t}$也无妨,因为本身$u_k$到$u(t)$的变换有很大自由度,这点误差无所谓。至于第一项积分,我们直接采用精度更高的中点近似,得到
\begin{equation}\begin{gathered}
x(t+\epsilon) - x(t) = \frac{1}{2}\epsilon A\left(\frac{x(t)}{t}+\frac{x(t+\epsilon)}{t+\epsilon}\right) + \frac{\epsilon}{t} B u_k \\[5pt]
\Downarrow \\[5pt]
x(t+\epsilon) = \left(I - \frac{\epsilon A}{2(t+\epsilon)}\right)^{-1}\left[\left(I + \frac{\epsilon A}{2t}\right)x(t) + \frac{\epsilon}{t} B u_k\right]
\end{gathered}\label{eq:legs-ode-bilinear}\end{equation}
事实上,式$\eqref{eq:legs-ode}$也可以精确求解,只需要留意到它等价于
\begin{equation}Ax(t) + Bu(t) = t x'(t) = \frac{d}{d\ln t} x(t)\end{equation}
这意味着只需要做变量代换$\tau = \ln t$,那么LegS型ODE就可以转化为LegT型ODE:
\begin{equation}\frac{d}{d\tau} x(e^{\tau}) = Ax(e^{\tau}) + Bu(e^{\tau})\end{equation}
利用式$\eqref{eq:legt-ode-sol}$得到(由于变量代换,时间间隔由$\epsilon$变成$\ln(t+\epsilon) - \ln t$)
\begin{equation}x(t+\epsilon) = e^{(\ln(t+\epsilon) - \ln t) A} x(t) + A^{-1} \big(e^{(\ln(t+\epsilon) - \ln t) A} - I\big) B u_k\label{eq:legs-ode-sol}\end{equation}
然而,上式虽然是精确解,但不如同为精确解的式$\eqref{eq:legt-ode-sol}$好用,因为式$\eqref{eq:legt-ode-sol}$的指数矩阵部分是$e^{\epsilon A}$,跟时间$t$无关,所以一次性计算完就可以了。但上式中$t$在矩阵指数里边,意味着在迭代过程中需要反复计算矩阵指数,对计算并不友好,所以LegS型ODE我们一般只会用式$\eqref{eq:legs-ode-bilinear}$来离散化。

优良性质 #

接下来,LegS是我们的重点关注对象。重点关注LegS的原因并不难猜,因为从推导的假设来看,它是目前求解出来的唯一一个能够记忆整个历史的ODE系统,这对于很多场景如多轮对话来说至关重要。此外,它还有其他的一些比较良好且实用的性质。

尺度等变 #

比如,LegS的离散化格式$\eqref{eq:legs-ode-bilinear}$是步长无关的,我们只需要将$t=k\epsilon$代入里边,并记$x(k\epsilon)=x_k$,就可以发现
\begin{equation}
x_{k+1} = \left(I - \frac{A}{2(k + 1)}\right)^{-1}\left[\left(I + \frac{A}{2k}\right)x_k + \frac{1}{k} B u_k\right]\end{equation}
步长$\epsilon$被自动地消去了,从而自然地减少了一个需要调的超参数,这对于炼丹人士显然是一个好消息。注意步长无关是LegS型ODE的一个固有性质,它跟具体的离散化方式并无直接关系,比如精确解$\eqref{eq:legs-ode-sol}$同样是步长无关的:
\begin{equation}x_{k+1} = e^{(\ln(k+1) - \ln k) A} x_k + A^{-1} \big(e^{(\ln(k+1) - \ln k) A} - I\big) B u_k\label{eq:legs-ode-sol-2}\end{equation}
其背后的原因,在于LegS型ODE满足“时间尺度等变性(Timescale equivariance)”——如果我们设$t=\lambda\tau$代入LegS型ODE,将得到
\begin{equation}Ax(\alpha\tau) + Bu(\alpha\tau) = (\alpha\tau)\times \frac{d}{d(\alpha\tau)} x(\alpha\tau) = \tau \frac{d}{d\tau}x(\alpha\tau)\end{equation}
这意味着,当我们将$u(t)$换成$u(\alpha t)$时,LegS的ODE形式并没有变化,而对应的解则是$x(t)$换成了$x(\alpha t)$。这个性质的直接后果就是:当我们选择更大的步长时,递归格式不需要发生变化,因为结果$x_k$的步长也会自动放大,这就是LegS型ODE离散化与步长无关的本质原因。

长尾衰减 #

LegS型ODE的另一个优良性质是,它关于历史信号的记忆是多项式衰减(Polynomial decay)的,这比常规RNN的指数衰减更缓慢,从而理论上能记忆更长的历史,更不容易梯度消失。为了理解这一点,我们可以从精确解$\eqref{eq:legs-ode-sol-2}$出发,从式$\eqref{eq:legs-ode-sol-2}$可以看到,每递归一步,历史信息的衰减效应可以用矩阵指数$e^{(\ln(k+1) - \ln k) A}$来描述,那么从第$m$步递归到第$n$步,总的衰减效应是
\begin{equation}\prod_{k=m}^{n-1} e^{(\ln(k+1) - \ln k) A} = e^{(\ln n - \ln m) A}\end{equation}
回顾HiPPO-LegS中$A$的形式:
\begin{equation}A_{n,k} = -\left\{\begin{array}{l}\sqrt{(2n+1)(2k+1)}, &k < n \\ n+1, &k = n \\
0, &k > n\end{array}\right.\end{equation}
从定义可以看出,$A$是一个下三角阵,其对角线元素为$-1,-2,-3,\cdots$。我们知道,三角阵的对角线元素正好是它的特征值(参考Triangular matrix),由此可以看到一个$d\times d$大小的$A$矩阵,有$d$个不同的特征值$-1,-2,\cdots,-d$,这说明$A$矩阵是可对角化的,即存在可逆矩阵$P$,使得$A = P^{-1}\Lambda P$,其中$\Lambda = \text{diag}(-1,-2,\cdots,-d)$,于是我们有
\begin{equation}\begin{aligned}
e^{(\ln n - \ln m) A} =&\, e^{(\ln n - \ln m) P^{-1}\Lambda P} \\
=&\, P^{-1} e^{(\ln n - \ln m) \Lambda}P \\
=&\, P^{-1}\,\text{diag}(e^{-(\ln n - \ln m)},e^{-2(\ln n - \ln m)},\cdots,e^{-d(\ln n - \ln m)})\,P \\
=&\, P^{-1}\,\text{diag}\Big(\frac{m}{n},\frac{m^2}{n^2},\cdots,\frac{m^d}{n^d}\Big)\,P \\
\end{aligned}\end{equation}
可见,最终的衰减函数是$1/n$的$1,2,\cdots,d$次函数的线性组合,所以LegS型ODE关于历史记忆至多是多项式衰减的,比指数衰减更加长尾,因此理论上有更好的记忆力。

计算高效 #

最后,我们指出HiPPO-LegS的$A$矩阵是计算高效(Computational efficiency)的。具体来说,直接按照矩阵乘法的朴素实现的话,一个$d\times d$的矩阵乘以$d\times 1$的列向量,需要做$d^2$次乘法,但LegS的$A$矩阵与向量相乘则可以降低到$\mathcal{O}(d)$次,更进一步地,我们还可以证明离散化后的$\eqref{eq:legs-ode-bilinear}$也可以在$\mathcal{O}(d)$完成。

为了理解这一点,我们首先将HiPPO-LegS的$A$矩阵等价地改写成
\begin{equation}A_{n,k} = \left\{\begin{array}{l}n\delta_{n,k} - \sqrt{2n+1}\sqrt{2k+1}, &k \leq n \\ 0, &k > n\end{array}\right.\end{equation}
对于向量$v = [v_0,v_1,\cdots,v_{d-1}]$,我们有
\begin{equation}\begin{aligned}
(Av)_n = \sum_{k=0}^n A_{n,k}v_k =&\, \sum_{k=0}^n \left(n\delta_{n,k} - \sqrt{2n+1}\sqrt{2k+1}\right)v_k \\
=&\, n v_n -\sqrt{2n+1}\sum_{k=0}^n \sqrt{2k+1}v_k
\end{aligned}\end{equation}
这包含三种运算,第一项的$n v_n$是向量$[0,1,2,\cdots,d-1]$与$v$做逐位相乘运算,第二项的$\sqrt{2k+1}v_k$则是向量$[1,\sqrt{3},\sqrt{5},\cdots,\sqrt{2d-1}]$与$v$做逐位相乘,然后$\sum\limits_{k=0}^n$就是$\text{cumsum}$运算,最后乘以$\sqrt{2n+1}$就是再逐位相乘向量$[1,\sqrt{3},\sqrt{5},\cdots,\sqrt{2d-1}]$,每一步都可以在$\mathcal{O}(d)$内完成,因此总的复杂度是$\mathcal{O}(d)$的。

我们再来看$\eqref{eq:legs-ode-bilinear}$,它包含两步“矩阵-向量”乘法运算,一是$(I+\lambda A)v$,$\lambda$是任意实数,刚才我们已经证明了$Av$是计算高效的,自然$(I+\lambda A)v$也是;二是$(I-\lambda A)^{-1}v$,接下来我们将证明它也是计算高效的。这只需要留意到求$z=(I-\lambda A)^{-1}v$等价于解方程$v = (I-\lambda A)z$,利用上面给出的$Av$表达式,我们可以得到
\begin{equation}v_n = z_n - \lambda \left(n z_n - \sqrt{2n+1}\sum_{k=0}^n \sqrt{2k+1}z_k\right)\end{equation}
记$S_n = \sum\limits_{k=0}^n \sqrt{2k+1}z_k$,那么$z_n = \frac{S_n - S_{n-1}}{\sqrt{2n+1}}$,代入上式得
\begin{equation}v_n = \frac{S_n - S_{n-1}}{\sqrt{2n+1}} - \lambda \left(n \frac{S_n - S_{n-1}}{\sqrt{2n+1}} - \sqrt{2n+1}S_n\right)\end{equation}
整理得
\begin{equation}S_n = \frac{1 - \lambda n}{1+\lambda n + \lambda}S_{n-1} + \frac{\sqrt{2n+1}}{1+\lambda n + \lambda}v_n\end{equation}
这是一个标量的递归式,可以完全串行地计算,也可以利用Prefix Sum的相关算法并行计算(参考这里),计算复杂度为$\mathcal{O}(d)$或者$\mathcal{O}(d\log d)$,总之相比$\mathcal{O}(d^2)$都会更加高效。

傅立叶基 #

最后,我们以傅立叶基的一个推导收尾。在上一篇文章中,我们以傅立叶级数来引出了线性系统,但只推导了邻近窗口形式的结果,而后面的勒让德多项式基我们则推导了邻近窗口和完整区间两个版本(即LegT和LegS)。那么傅立叶基究竟能不能推导一个跟LegS相当的版本呢?其中会面临什么困难呢?下面我们对此进行探讨。

同样地,相关铺垫我们不再重复,按照上一节的记号,傅立叶基的系数为
\begin{equation}c_n(T) = \int_0^1 u(t_{\leq T}(s)) e^{-2i\pi n s}ds\end{equation}
跟LegS一样,为了记忆整个$[0,T]$区间的信号,我们需要一个$[0,1]\mapsto [0,T]$的映射,为此选取最简单的$t_{\leq T}(s)=sT$,代入后两边求导得到
\begin{equation}\frac{d}{dT}c_n(T) = \int_0^1 u'(sT) s e^{-2i\pi n s}ds\end{equation}
分部积分得到
\begin{equation}\begin{aligned}
\frac{d}{dT}c_n(T) =&\, \frac{1}{T}\int_0^1 s e^{-2i\pi n s}d u(sT) \\
=&\, \frac{1}{T} u(sT) s e^{-2i\pi n s}\big|_{s=0}^{s=1} - \frac{1}{T}\int_0^1 u(sT) d(s e^{-2i\pi n s})\\
=&\, \frac{1}{T} u(T) - \frac{1}{T}\int_0^1 u(sT) e^{-2i\pi n s} ds + \frac{2i\pi n}{T}\int_0^1 u(sT) s e^{-2i\pi n s} ds\\
=&\, \frac{1}{T} u(T) - \frac{1}{T}c_n(T) + \frac{2i\pi n}{T}\int_0^1 u(sT) s e^{-2i\pi n s} ds\\
\end{aligned}\end{equation}
上一篇文章我们提到,HiPPO选取勒让德多项式为基的重要原因之一是$(s+1)p_n'(t)$可以分解为$p_0(t),p_1(t),\cdots,p_n(t)$的线性组合,而傅里叶基的$s e^{-2i\pi n s}$则不能做到这一点。但事实上,如果允许误差的话,这个断论是不成立的,因为我们同样可以将$s$分解为傅里叶级数:
\begin{equation}s = \frac{1}{2} + \frac{i}{2\pi}\sum_{k\neq 0} \frac{1}{k} e^{2i\pi k s}\end{equation}
这里的求和有无限项,如果要截断为有限项的话,就会产生误差,但我们可以先不纠结这一点,直接往上代入得到
\begin{equation}\begin{aligned}
&\,\frac{2i\pi n}{T}\int_0^1 u(sT) s e^{-2i\pi n s} ds \\
=&\, \frac{2i\pi n}{T}\int_0^1 u(sT) \left(\frac{1}{2} + \frac{i}{2\pi}\sum_{k\neq 0} \frac{1}{k} e^{2i\pi k s}\right) e^{-2i\pi n s} ds \\
=&\, \frac{i\pi n}{T}\int_0^1 u(sT) e^{-2i\pi n s} ds - \frac{1}{T}\sum_{k\neq 0} \frac{n}{k}\int_0^1 u(sT) e^{-2i\pi (n - k) s} ds \\
=&\, \frac{i\pi n}{T}c_n(T) - \frac{1}{T}\sum_{k\neq 0} \frac{n}{k}c_{n-k}(T) \\
=&\, \frac{i\pi n}{T}c_n(T) - \frac{1}{T}\sum_{k\neq n} \frac{n}{n - k}c_k(T) \\
\end{aligned}\end{equation}
这样一来
\begin{equation}
\frac{d}{dT}c_n(T) = \frac{1}{T} u(T) + \frac{i\pi n - 1}{T}c_n(T) - \frac{1}{T}\sum_{k\neq n} \frac{n}{n - k}c_k(T)\end{equation}
所以可以写出
\begin{equation}\begin{aligned}
x'(t) =&\, \frac{A}{t}x(t) + \frac{B}{t}u(t)\\[8pt]
\quad A_{n,k} =&\, \left\{\begin{array}{l}-\frac{n}{n-k}, &k \neq n \\ i\pi n - 1, &k = n\end{array}\right.\\[8pt]
B_n =&\, 1
\end{aligned}\end{equation}
实际使用的时候,我们只需要截断$|n|,|k|\leq N$,就可以得到一个$(2N+1)\times (2N+1)$的矩阵。截断带来的误差其实是无所谓的,因为我们在推导HiPPO-LegT的时候同样引入了有限级数近似,那会我们同样也没考虑误差,或者反过来讲,对于特定的任务,我们会选择适当的规模(即$N$的大小),而这个“适当”的含义之一,就是截断带来误差对于该任务是可以忽略的。

对大多数人来说,傅立叶基的这个推导可能还更容易理解一些,因为勒让德多项式对很多读者来说都比较陌生,尤其是LegT、LegS推导过程中用到的几个恒等式,而对于傅立叶级数大多数读者应该或多或少都有所了解。不过,从结果上来看,傅立叶基的这个结果可能不如LegS实用,一来它引入了复数,这增加了实现的复杂度,二来它推导出的$A$矩阵不像LegS那样是个相对较淡的下三角阵,因此理论分析起来也更为复杂。所以,大家权当它是一道深化对HiPPO的理解的练习题就好。

文章小结 #

在这篇文章中,我们补充探讨了上一篇文章介绍的HiPPO的一些遗留问题,其中包括如何对ODE进行离散化、LegS型ODE的一些优良性质,以及利用傅立叶基记忆整个历史区间的结果推导(即LegS的傅立叶版本),以求获得对HiPPO的更全面理解。

转载到请包括本文地址:https://kexue.fm/archives/10137

更详细的转载事宜请参考:《科学空间FAQ》

如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。

如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!

如果您需要引用本文,请参考:

苏剑林. (Jun. 05, 2024). 《重温SSM(二):HiPPO的一些遗留问题 》[Blog post]. Retrieved from https://kexue.fm/archives/10137

@online{kexuefm-10137,
        title={重温SSM(二):HiPPO的一些遗留问题},
        author={苏剑林},
        year={2024},
        month={Jun},
        url={\url{https://kexue.fm/archives/10137}},
}