从JL引理看熵不变性Attention
By 苏剑林 | 2023-04-10 | 30074位读者 |在《从熵不变性看Attention的Scale操作》、《熵不变性Softmax的一个快速推导》中笔者提出了熵不变性Softmax,简单来说就是往Softmax之前的Attention矩阵多乘上一个$\log n$,理论上有助于增强长度外推性,其中$n$是序列长度。$\log n$这个因子让笔者联系到了JL引理(Johnson-Lindenstrauss引理),因为JL引理告诉我们编码$n$个向量只需要$\mathcal{O}(\log n)$的维度就行了,大家都是$\log n$,这两者有没有什么关联呢?
熵不变性 #
我们知道,熵是不确定性的度量,用在注意力机制中,我们将它作为“集中注意力的程度”。所谓熵不变性,指的是不管序列长度$n$是多少,我们都要将注意力集中在关键的几个token上,而不要太过分散。为此,我们提出的熵不变性Attention形式为
\begin{equation}Attention(Q,K,V) = softmax\left(\frac{\log_{512} n}{\sqrt{d}}QK^{\top}\right)V\label{eq:core}\end{equation}
这里$Q,K\in\mathbb{R}^{n\times d}$。跟常规的Attention相比,就是scale的因子多了个$\log_{512} n$,其中底数取512,是假设我们所有的超参数(比如$d$)都是为训练长度512调好的。当然,即便你计划中的预训练长度不是512,底数也可以直接无脑取512,结果基本不会有什么影响。
这个形式的原理也很直观,当$n$增大时,意味着有更多的token去平摊了注意力,导致注意力不集中,此时我们乘上一个关于$n$单调递增的因子,softmax之后它实际上就相当于原来概率的幂运算,由于概率都小于1,所以概率越小幂运算之后会变得更小,这样注意力重新变得集中起来。至于这个因子为什么是对数的形式,那就需要看开头文章的推导过程了。
JL引理 #
JL引理,全称“Johnson-Lindenstrauss引理”,是关于向量嵌入的一个重要结论,简单来说它就是告诉我们“要塞下$n$个向量,只需$\mathcal{O}(\log n)$维空间”(这里的$\log$没有写出底数,默认都是以自然对数$e$为底),详细介绍可以参考《让人惊叹的Johnson-Lindenstrauss引理:理论篇》。
有意思的是,早在笔者知道JL引理之前,就在《最小熵原理(六):词向量的维度应该怎么选择?》推导过同样的、甚至更具体的结果——嵌入$n$个词向量,大致上需要$8\log n$维空间就行了。这个估计跟实际使用的维度很接近,比如$n$等于10万时,$8\log n$算出来大概是92,而我们经常用的词向量维度也是一两百维这个量级。
另外,JL引理还可以用来解释注意力机制的多头性。如果代入$n=512$,那么$8\log n\approx 50$,这跟Attention的Q、K常用的投影维度(也就是key_size,BERT里边是64,参考这里)很接近,这就告诉我们,如果序列长度时512,那么算Attention的Q、K的维度在50这个量级就够了,没必要用全部的hidden_size(BERT base是768),省下来的维度可以转而用来做多头注意力。
更多相关讨论可以参考《关于维度公式“n > 8.33 log N”的可用性分析》、《让人惊叹的Johnson-Lindenstrauss引理:应用篇》。
联系起来 #
现在,我们就可以尝试JL引理跟熵不变性Attention联系起来了。
我们将Q、K的key_size记为$d$,那么JL引理告诉我们,$d$的最佳选择应该是$d_n=\lambda \log n$,这里的$\lambda$是比例常数,具体是多少不重要。也就是说,理想情况下,$d$应该随着$n$的变化而变化,但很显然这样的设计并不容易实现,也不利于计算的并行化,所以实际情况下我们都只能使用固定的$d$。
假设我们选定了一个固定的$d$,并且假设这个$d$是为训练长度512设计的,那么我们可以得出$d = \lambda \log 512$,也就是$\lambda = \frac{d}{\log 512}$,以及
\begin{equation}d_n = \frac{d}{\log 512}\log n=d\log_{512} n\end{equation}
对于$n\neq 512$,理想情况下应该用$d_n$维的投影维度,但实际用了$d$维,根据内积的定义$\langle q,k\rangle = \sum\limits_{i=1}^d q_i k_i$,求和的项数正好等于维度数$d$,也就是说,理想情况下应该是$d_n$项求和,但实际上变为了$d$项求和,那么直觉上来看,如果每一项的贡献接近,那么我们将结果乘以$\frac{d_n}{d}$后,能够让结果更接近$d_n$项求和的理想情况,所以我们就得出,应当往$\langle q,k\rangle$中乘上因子
\begin{equation}\frac{d_n}{d} = \log_{512} n\end{equation}
来弥补实际情况与理想情况的差距。而常规的Scaled-Dot Attention乘上$\log_{512} n$后,正好是熵不变性Attention,也就是式$\eqref{eq:core}$。
这样,我们就将JL引理跟熵不变性Attention联系了起来。注意这只是个直观的、定性的理解过程,很难从定量角度将它进一步严格化,事实上也没有必要进一步定量化了,因为JL引理本身更多也只是一个定性的结论。
文章小结 #
本文构建了JL引理与熵不变性Attention之间的一个简单联系。
转载到请包括本文地址:https://kexue.fm/archives/9588
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Apr. 10, 2023). 《从JL引理看熵不变性Attention 》[Blog post]. Retrieved from https://kexue.fm/archives/9588
@online{kexuefm-9588,
title={从JL引理看熵不变性Attention},
author={苏剑林},
year={2023},
month={Apr},
url={\url{https://kexue.fm/archives/9588}},
}
April 11th, 2023
6666
April 18th, 2023
为什么GPT之类的LLM一般习惯于一步到位取和inner_dim一样的word_emb呢?出于工程上的方便优化么?
这个没有太多的工作去考证过它的优异性,不过早年ALBERT的结果显示对token embedding进行低秩分解基本不会损失性能。
感谢!我自己在尝试visual embedding接入LLM的时候,因为维数差异较大,纠结了很久。最后发现bridge这部分对于最终结果影响远没有数据大。
数据为王。
May 20th, 2023
「所谓熵不变性,指的是不管序列长度 $n$ 是多少,我们都要将注意力集中在关键的几个 token 上,而不要太过分散」
這樣做到目的,是不是想要得到一個 low-rank 的 attention map?
不是low-rank,而是sparse。
既然如此,一個簡單的方法是可以使用 $\mathrm{Sparsemax}$。
Sparsemax及其后来改进我都略有关注,存在复杂度略高的问题。其实现在主要的问题不是如何实现sparse,而是要sparsity不变性(不随着长度的变化而变化)
June 8th, 2023
苏神好:)
当n增大时,意味着有更多的token去平摊了注意力,导致注意力不集中。这句真是醍醐灌顶。
我觉得主要还是文本token的注意力存在局部性。所以当n增大,新增的token必然会稀释局部的注意力。
我在一个无向图场景下测了下熵不变性attention,表现还有所下降了。我是采样做的context。预测使用完整的context,并不会因为n增大而稀释注意力,因为新增的部分和原来的是同分布的,并不像文本那样(新增的都是远端token)。不知道我的理解是不是对。
“我在一个无向图场景下测了下熵不变性attention”,是训练阶段就加入,还是后处理加入呢?
独立同分布情况下,可能确实不进行log n缩放比较好,这种情况下有点像In-Context Learning,需要平均多个结果来做出准确预测,而不是单独关注某个case。
训练的时候我抽的固定大小的context。
我这两天又想了一下,应该还是有一个被稀释的问题,不过和seq场景不一样。被稀释的是self attention的self,我猜测自身的attention肯定会被稀释(至少和其它样本不一样),相当于是一个大小为1的窗口邻域(只包含它自身)。我试试把Attention矩阵的diag扣掉看看(用mask=~torch.eye(n, dtype=bool))
是因为你的场景下self特别重要吗?但即便如此,似乎也没法解释log n缩放会下降的问题。