梯度视角下的LoRA:简介、分析、猜测及推广
By 苏剑林 | 2023-04-17 | 73573位读者 |随着ChatGPT及其平替的火热,各种参数高效(Parameter-Efficient)的微调方法也“水涨船高”,其中最流行的方案之一就是本文的主角LoRA了,它出自论文《LoRA: Low-Rank Adaptation of Large Language Models》。LoRA方法上比较简单直接,而且也有不少现成实现,不管是理解还是使用都很容易上手,所以本身也没太多值得细写的地方了。
然而,直接实现LoRA需要修改网络结构,这略微麻烦了些,同时LoRA给笔者的感觉是很像之前的优化器AdaFactor,所以笔者的问题是:能否从优化器角度来分析和实现LoRA呢?本文就围绕此主题展开讨论。
方法简介 #
以往的一些结果(比如《Exploring Aniversal Intrinsic Task Subspace via Prompt Tuning》)显示,尽管预训练模型的参数量很大,但每个下游任务对应的本征维度(Intrinsic Dimension)并不大,换句话说,理论上我们可以微调非常小的参数量,就能在下游任务取得不错的效果。
LoRA借鉴了上述结果,提出对于预训练的参数矩阵$W_0\in\mathbb{R}^{n\times m}$,我们不去直接微调$W_0$,而是对增量做低秩分解假设:
\begin{equation}W = W_0 + A B,\qquad A\in\mathbb{R}^{n\times r},B\in\mathbb{R}^{r\times m}\end{equation}
其中$A,B$之一用全零初始化,$W_0$固定不变,优化器只优化$A,B$。由于本征维度很小的结论,所以$r$我们可以取得很小,常见的是$r=8$,极端情况下我们甚至可以取$1$。所以说,LoRA是一种参数高效的微调方法,至少被优化的参数量大大降低了。
用MathJax直接画了个示意图:
$$\style{display: inline-block; width: 24ex; padding: 10ex 0; border: 1px solid #6C8EBF; background-color: #DAE8FC}{W_0\in\mathbb{R}^{n\times m}} \quad + \quad \style{display: inline-block; width: 8ex; padding: 10ex 0; border: 1px solid #D79B00; background-color: #FFE6CC}{A\in\mathbb{R}^{n\times r}}\quad\times\quad \style{display: inline-block; width: 24ex; padding: 3ex 0; border: 1px solid #D79B00; background-color: #FFE6CC}{B\in\mathbb{R}^{r\times m}}$$
梯度分析 #
正如《Ladder Side-Tuning:预训练模型的“过墙梯”》所提到的,很多参数高效的微调实际上只是降低了显存需求,并没有降低计算量。那么LoRA是否例外呢?它在显存和计算量方面的效率如何呢?下面我们来分析一下。
首先,我们知道训练模型所消耗的显存来源包括模型参数、模型梯度、模型激活值、优化器状态四部份,LoRA通过低秩分解降低了模型参数量,那么梯度和优化器状态也会随之降低,因此节省的显存是很明显的。那它能否节省计算量呢?
这取决于LoRA的实现方式,不同的实现方式计算梯度的复杂度不一样。LoRA的两种等效实现如下:
\begin{align}Y =&\, XW = X(W_0 + AB) \label{eq:lora-1}\\[5pt]
Y =&\, XW_0 + XAB = XW_0 + ZB \label{eq:lora-2}\end{align}
其中$X\in\mathbb{R}^{b\times n}$是模型输入,$Z=XA\in\mathbb{R}^{b\times r}$是中间输出。针对实现$\eqref{eq:lora-1}$,我们有
\begin{equation}\frac{\partial \mathcal{L}}{\partial A} = \frac{\partial \mathcal{L}}{\partial W} B^{\top} = \left(X^{\top}\frac{\partial \mathcal{L}}{\partial Y}\right) B^{\top},\quad \frac{\partial \mathcal{L}}{\partial B} = A^{\top}\frac{\partial \mathcal{L}}{\partial W} = A^{\top}\left(X^{\top}\frac{\partial \mathcal{L}}{\partial Y}\right)\label{eq:grad-1}\end{equation}
$\mathcal{L}$是损失函数。很明显,这种实现导致的后果是需要算完整梯度$\frac{\partial \mathcal{L}}{\partial W}\in\mathbb{R}^{n\times m}$,然后才能算$A,B$的梯度,这意味着它比不LoRA还慢,也费显存。对于实现$\eqref{eq:lora-2}$,我们则有
\begin{equation}\frac{\partial \mathcal{L}}{\partial A} = X^{\top}\frac{\partial \mathcal{L}}{\partial Z} = X^{\top}\left(\frac{\partial \mathcal{L}}{\partial Y} B^{\top}\right),\quad \frac{\partial \mathcal{L}}{\partial B} = Z^{\top}\frac{\partial \mathcal{L}}{\partial Y} = (XA)^{\top}\frac{\partial \mathcal{L}}{\partial Y}\label{eq:grad-2}\end{equation}
此时的$Z,\frac{\partial \mathcal{L}}{\partial Z}\in\mathbb{R}^{b\times r}$,相比完整的梯度显然省了不少,计算复杂度也明显降低。所以,LoRA想要节省显存和计算最大化,关键是按照$\eqref{eq:lora-2}$而不是$\eqref{eq:lora-1}$来实现。
(注:关于矩阵计算梯度,我们可以根据链式法则和输出形状来“凑”,比如$\frac{\partial \mathcal{L}}{\partial A}$,根据链式法则我们知道它必然是$\frac{\partial \mathcal{L}}{\partial W}$和$B$以某种方式相乘,我们约定$\frac{\partial \mathcal{L}}{\partial A}$的形状跟$A$一致,即$n\times r$,想要用$\frac{\partial \mathcal{L}}{\partial W}$和$B$凑出一个$n\times r$的结果来,那就只有$\frac{\partial \mathcal{L}}{\partial W} B^{\top}$了。)
其他原因 #
除了低秩分解带来的好处外,如下几点也是LoRA能节省显存和提速的原因:
1、只更新了部分参数:比如LoRA原论文就选择只更新Self Attention的参数,实际使用时我们还可以选择只更新部分层的参数;
2、减少了通信时间:由于更新的参数量变少了,所以(尤其是多卡训练时)要传输的数据量也变少了,从而减少了传输时间;
3、采用了各种低精度加速技术,如FP16、FP8或者INT8量化等。
当然,这三部分原因确实能加快训练速度,但它们并不是LoRA所独有的,事实上几乎都有参数高效方法都具有这些特点。LoRA的突出优点是它的低秩分解很直观,在不少场景下跟全量微调的效果一致,以及在预测阶段可以直接把$W_0,A,B$合并成单个矩阵从而不增加推理成本。
优化视角 #
梯度$\eqref{eq:grad-1}$还告诉了我们如何从优化器角度来实现LoRA。优化器可以直接获取到全量梯度$\frac{\partial \mathcal{L}}{\partial W}$,然后我们只需要按照公式$\eqref{eq:grad-1}$对梯度进行投影,就得到$A,B$的梯度,接着就可以按照常规的优化器实现$A,B$的更新了。
假如优化器是SGD,那么就是
\begin{equation}\begin{aligned}
A_{t+1} =&\, A_t - \eta\frac{\partial \mathcal{L}}{\partial W_t} B_t^{\top},\quad B_{t+1} = B_t - \eta A_t^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\\[5pt]
W_{t+1} =&\, W_0 + A_{t+1} B_{t+1} = W_t + (A_{t+1} B_{t+1} - A_t B_t)
\end{aligned}\end{equation}
如果是Adam之类的带滑动变量的优化器,则只需要滑动投影后的梯度,因此是降低了优化器的参数量,节省了一定的显存。模型越大,这部分参数所占的显存比例也就越大。
LoRA约定$A$或$B$之一使用全零初始化,这是为了保证初始状态模型跟预训练一致,但同时也带来了不对称问题(一个全零,一个非全零)。事实上,$A,B$都使用非全零初始化也是可以的,只需要事先将预训练权重减去$A_0 B_0$就行了,或者等价地说,将$W$参数化为
\begin{equation}W = W_0 - A_0 B_0 + A B\end{equation}
这样同时保持了初始状态一致,同时允许$A,B$都用非全零初始化,增强了对称性。
随机投影 #
如果我们将SGD场景下的更新量$A_{t+1} B_{t+1} - A_t B_t$展开,结果将是
\begin{equation}- \eta\left(\frac{\partial \mathcal{L}}{\partial W_t} B_t^{\top} B_t + A_t A_t^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\right) + \eta^2 \frac{\partial \mathcal{L}}{\partial W_t} B_t^{\top} A_t^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\end{equation}
假设$\eta^2$项是可以忽略的高阶项,那么就剩下
\begin{equation}- \eta\left(\frac{\partial \mathcal{L}}{\partial W_t} B_t^{\top} B_t + A_t A_t^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\right)\end{equation}
从这个角度来看,相比全量微调的SGD,LoRA就是用括号中的结果替代了全量的梯度$\frac{\partial \mathcal{L}}{\partial W_t}$。
简单起见,接下来我们只关心$r=1$的情形,留意到在上式中,$t$时刻的投影向量$A_t,B_t$是依赖于$t$的,如果我们将它们换成不依赖于$t$的随机向量(每步训练都重新随机生成),那么会发生什么呢?我们考虑$u,v\sim\mathcal{N}(0,1)$,其中$u\in\mathbb{R}^{m\times 1}, v\in\mathbb{R}^{1\times n}$,那么更新量就变为
\begin{equation}- \eta\left(\frac{\partial \mathcal{L}}{\partial W_t} v^{\top} v + u u^{\top}\frac{\partial \mathcal{L}}{\partial W_t}\right)\end{equation}
可以证明的是
\begin{equation}\mathbb{E}_{u\sim \mathcal{N}(0,1)}[u u^{\top}] = I_{n\times n},\quad \mathbb{E}_{v\sim \mathcal{N}(0,1)}[v^{\top} v] = I_{m\times m}\end{equation}
这里的$I_{n\times n},I_{m\times m}$分别指$n\times n,m\times m$的单位矩阵。因此,跟“零阶梯度”类似,在平均意义下,这种每步都重新初始化的LoRA事实上等价于满秩的SGD。然而,真要按照这个方式实现的话,其速度甚至可能比满秩的SGD都要慢,所以它的目的不是提速,而是希望能缓解灾难遗忘问题——通过对单个(batch)样本使用低秩矩阵(而不是满秩)更新量的方式,减少对整个模型权重的影响。当然,这只是猜测,实际效果如何,笔者还没有实验过。
一个变体 #
同样还是先只考虑$r=1$的情形,LoRA相当于假设了$\Delta w_{i,j} = u_i v_j$,我们能不能做其他低秩分解假设呢?比如$\Delta w_{i,j} = u_i + v_j$?写成矩阵形式就是
\begin{equation}W = W_0 + A \mathbb{1}_{1\times m} + \mathbb{1}_{n\times 1} B,\qquad A\in\mathbb{R}^{n\times 1},B\in\mathbb{R}^{1\times m}\end{equation}
其中$\mathbb{1}_{1\times m},\mathbb{1}_{n\times 1}$分别指$1\times m,n\times 1$的全1矩阵。容易求出它的梯度是:
\begin{equation}\frac{\partial \mathcal{L}}{\partial A} = \frac{\partial \mathcal{L}}{\partial W} \mathbb{1}_{m\times 1},\quad \frac{\partial \mathcal{L}}{\partial B} = \mathbb{1}_{1\times n}\frac{\partial \mathcal{L}}{\partial W}\end{equation}
其实就是原本梯度的行求和与列求和。相比原版LoRA,这个加性分解有两个优点:1、加比乘计算量更低,梯度形式也更简单;2、$AB$的秩一定是1,但是$A \mathbb{1}_{1\times m} + \mathbb{1}_{n\times 1} B$的秩可能是2,如果秩代表了模型能力的话,那也就是说同样的参数量,加性的表达能力可能还更强。至于具体效果如何,后面笔者用到LoRA的时候,再做对比实验吧。
那么,加性分解能不能推广到$r > 1$的情形呢?自然是可以的,但稍微有些技巧。这里约定$m,n$都能被$r$整除,那么我们只需要将参数化方式改为
\begin{equation}W = W_0 + A I_{r(1\times m/r)} + I_{r(n/r\times 1)} B,\qquad A\in\mathbb{R}^{n\times r},B\in\mathbb{R}^{r\times m}\end{equation}
这里的$I_{r(1\times m/r)}$、$I_{r(n/r\times 1)}$分别指$1\times m/r$、$n/r\times 1$的分块矩阵,每一块则是$r\times r$的单位阵。这个形式说白了,就是分别将$A$、$B$看成是$n/r\times 1$、$1\times m/r$的分块矩阵,然后套用$r=1$的思路来操作。
文章小结 #
本文介绍了从梯度角度来理解LoRA,除了基本的介绍外,还包含了笔者的一些猜测和推广,供读者参考。
转载到请包括本文地址:https://kexue.fm/archives/9590
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Apr. 17, 2023). 《梯度视角下的LoRA:简介、分析、猜测及推广 》[Blog post]. Retrieved from https://kexue.fm/archives/9590
@online{kexuefm-9590,
title={梯度视角下的LoRA:简介、分析、猜测及推广},
author={苏剑林},
year={2023},
month={Apr},
url={\url{https://kexue.fm/archives/9590}},
}
April 17th, 2023
请问有没有计划在bert4keras上实现lora的计划?
主要是bert4keras还没适配真正大的模型,LoRA的意义似乎不大。
April 17th, 2023
计算开销为什么没有减少呢?文中对于 $W$ 的偏导在实际使用中其实应该对应着某层 activation, 即 $(W_{0;l}+A_lB_l)X_l$, 这部分中间量确实是 LoRA 和普通 fine-tuning 都要计算的,但是 LoRA 应该不需要通过对上一层 $W_{0;l+1}$ 的偏导计算它而只需要其参数值,其中节约的是对于所有 $W_{0;l}$ 参数的求导?此外,文中随机投影的部分的结论我个人感觉从另一个角度看更好理解。假设收敛,当将优化问题写作关于 $W_t$ 且 $t\to\infty$ 的问题而非关于 $W$ 时,如果不像 LoRA 那样保持 $W_0$ 不变,而是使 $W_t = W_{t-1} + A_{t-1}B_{t-1}$,那么 $W_t$ 的搜索空间是由 $\sum_tA_tB_t$ span 成的空间,当 $t\to\infty$ 时等价于 $W_0$ 所在空间,也即所找到的 optimal solution 与原秩问题一致。也就是说,如果不保持 $W_0$ 固定,而是低秩分解参数并累积更新,这其实是一种以时间换取 per iteration 计算量和内存开销的方法,在极端情况下即便 $r=1$ 也应该能得到和直接训练 $W$ 等价的效果。据说有人在常规训练 (即小模型 training from scratch) 上 (可能披着 federated learning 的皮) 尝试过这种方法,但是实际效果并不如直接对 $W$ 优化,个人怀疑是计算精度导致的问题。可以说涉及到低秩分解的方法中,直接将 $W$ 用低秩分解近似、LoRA, 乃至刚才所说的以低秩分解近似 fine-tuning 应该可以算是一种光谱,LoRA 很容易让人觉得它是在做参数的低秩近似,但其实并不是,因为说到底优化结束后的 $W^*=W_0+A^*B^*$ 并不是从低秩空间中搜索得到的,所以其价值恰恰来源于保证更新不远离初始化位置。由此我又想到了一些其他的 LoRA 变体,既然 LoRA 其实并不是在做低秩近似,那么 LoRA 参数也就没有必要和 $W$ 一一对应了,而可以考虑作为一种更复杂的 skip connection. 比如将 attention 中的 Q,K,B 视作一个整体,以一个 match 它输入输出维度的 AB 代替它,乃至将多层 attention 用低秩方法 skip connect 起来,这样可以进一步降低内存和计算开销。
关于计算量,我说的是求梯度的理论计算量不会减少,但实际的训练速度会提高,提高的原因文章也简单列举了。
$W_t = W_{t-1} + A_{t-1}B_{t-1}$这个形式之中,$A_{t-1}B_{t-1}$不看出显式的增量形式。当然,这个不是太重要,主要是“低秩分解”本身没有问题,但是不宜用复杂度太高的低秩分解,否则是得不偿失的,而本文的随机投影,算是最低成本的低秩分解了。事后来看,我倒是更看好文末提出的加性低秩分解,更少随机性,以及更加直观。
April 19th, 2023
我感覺這似乎和 matrix completion 有關聯。
April 22nd, 2023
即使不考虑文中说的几点原因,我觉得梯度的计算量还是少了的。我的理解是 $A^T\frac{\partial L}{\partial W}$中的$\frac{\partial L}{\partial W}$不用先算 (因为占显存),可以利用结合律先算 $A^T X$ (X是input)
嗯嗯,这点倒是有可能,受教了。
April 23rd, 2023
FEDPARA(arXiv:2108.06098)论文里有用两个矩阵的Hadamard积来做分解的方法,用两个$m \times R$和$R \times n$的矩阵可以达到最高$R^2$的秩。
这个思路也很有亮点,学习了。不过计算量上倒是比加性要大些。
April 25th, 2023
我想问一下,在计算机视觉中,用transformer结构搭建AE中的编解码器,有什么不好的地方吗?
我不熟悉,不知道有什么不好或者好的地方~
感谢苏老师的回复,看了您的文章真是受益匪浅!
April 30th, 2023
[...]苏剑林. (Apr. 17, 2023). 《梯度视角下的LoRA:简介、分析、猜测及推广 》[Blog post]. Retrieved from https://kexue.fm/archives/9590[...]
April 30th, 2023
[...]苏剑林. (Apr. 17, 2023). 《梯度视角下的LoRA:简介、分析、猜测及推广 》[Blog post]. Retrieved from https://kexue.fm/archives/9590[...]
May 1st, 2023
[...]苏剑林. (Apr. 17, 2023). 《梯度视角下的LoRA:简介、分析、猜测及推广 》[Blog post]. Retrieved from https://kexue.fm/archives/9590 4.https://github.com/huggingface/blog/blob/main/notebooks/HuggingFace_int8_demo.ipynb[...]
May 6th, 2023
你好苏神,我想问个问题,就是我读了LST那篇文章,但是那篇文章里只说了他的方法相比原来方法减少内存占用,但是没有提可以加速训练,并且我看实验图上也没有这样的实验。请问一下是具体在哪儿提到了?麻烦指教下
LST能加速不是很显然的嘛?再不济自己测一下也可以呀。