Naive Bayes is all you need ?
By 苏剑林 | 2023-06-08 | 54244位读者 |很抱歉,起了这么个具有标题党特征的题目。在写完《NBCE:使用朴素贝叶斯扩展LLM的Context处理长度》之后,笔者就觉得朴素贝叶斯(Naive Bayes)跟Attention机制有很多相同的特征,后来再推导了一下发现,Attention机制其实可以看成是一种广义的、参数化的朴素贝叶斯。既然如此,“Attention is All You Need”不也就意味着“Naive Bayes is all you need”了?这就是本文标题的缘由。
接下来笔者将介绍自己的思考过程,分析如何从朴素贝叶斯角度来理解Attention机制。
朴素贝叶斯 #
本文主要考虑语言模型,它要建模的是p(xt|x1,⋯,xt−1)。根据贝叶斯公式,我们有
p(xt|x1,⋯,xt−1)=p(x1,⋯,xt−1|xt)p(xt)p(x1,⋯,xt−1)∝p(x1,⋯,xt−1|xt)p(xt)
根据独立假设p(x1,⋯,xt−1|xt)=t−1∏j=1p(xj|xt),我们有
p(xt|x1,⋯,xt−1)∝t−1∏j=1p(xj|xt)p(xt)
再次根据贝叶斯公式p(xj|xt)=p(xt|xj)p(xj)p(xt)∝p(xt|xj)p(xt),得到
p(xt|x1,⋯,xt−1)∝1[p(xt)]t−2t−1∏j=1p(xt|xj)
两边取对数得到
logp(xt|x1,⋯,xt−1)=t−1∑j=1logp(xt|xj)−(t−2)logp(xt)+常数
一般化结果 #
相同的推导我们在《NBCE:使用朴素贝叶斯扩展LLM的Context处理长度》也进行过,跟该文章一样,我们将上式一般化为:
logp(xt|x1,⋯,xt−1)=(1+β)P[logp(xt|xj)]−βlogp(xt)+常数
这里的β作为超参数来调,P是某种Pooling方式。接下来我们主要看β=0、以加权平均为Pooling的例子,即
logp(xt|x1,⋯,xt−1)=∑jat,jlogp(xt|xj)+常数
这里的at,j是xt−1与xj的函数。
可能有读者想问,这个一般化的式子还能算是朴素贝叶斯吗?笔者认为它可以作为广义的朴素贝叶斯来看待,因为朴素贝叶斯可以视为各个logp(xt|xj)的等权平均,这里则是换成了更一般化的加权平均。不过,将at,j选取为xt−1与xj的函数,突出了xt−1的地位,改善了朴素贝叶斯的无序性这一弊端。所以更准确来说,式(6)是2-gram语言模型与朴素贝叶斯的结合。
注意力初现 #
接下来,将logp(xt|xj)进一步参数化,我们就可以得见Attention的形式了。不难发现,p(xt|xj)实质上就是以前Word2Vec的Skip Gram模型,它的常规建模方式是“Embedding + 内积 + Softmax”,即
p(xt|xj)=ev(xj)⋅w(xt)Z(xj),Z(xj)=∑xt∈Vocabev(xj)⋅w(xt)
所以我们简单地认为
logp(xt|xj)=v(xj)⋅w(xt)+常数
代入到式(6),得到
logp(xt|x1,⋯,xt−1)=(∑jat,jv(xj))⋅w(xt)+常数
括号中的式子,我们将它单独拿出来,当作通常用特征融合运算,它其实就是常规的Attention。所以说,单层的Attention做语言模型,实则就是广义的朴素贝叶斯。
当然,这里我们还没有将at,j确定下来。上一节我们说at,j是xt−1与xj的函数,然后同时还要归一化(加权平均),所以比较简单的方式就是像Skip Gram一样“Embedding + 内积 + Softmax”:
at,j=eq(xt−1)⋅k(xj)Zt,Zt=t−1∑j=1eq(xt−1)⋅k(xj)
代入到式(9),就是目前最常用的Dot-Product Attention了。当然,这种方式不是唯一的,还有加性Attention等,选择Dot-Product的最主要原因是它可以在比较省显存的前提下实现并行。
层叠与残差 #
不管怎么参数化,单层的朴素贝叶斯能力总是有限的,所以需要进一步提高模型的复杂度。从神经网络的角度来看,提高模型复杂度的主要方式是增加深度,也就是层与层之间的堆叠。那么,从概率分布的角度如何理解这种堆叠呢?答案是隐变量模型。
所谓隐变量模型,就是引入隐变量z1,z2,⋯,zt−1,使得
p(xt|x1,⋯,xt−1)=∫p(xt|z1,⋯,zt−1)p(z1,⋯,zt−1|x1,⋯,xt−1)dz1⋯dzt−1
说白了,就是通过简单分布的叠加来拟合更复杂的分布,跟GMM(高斯混合模型)的思想是一致的。基于前面的讨论,p(xt|z1,⋯,zt−1)我们同样用朴素贝叶斯建模,即从特征层面就是单层Attention。而对于p(z1,⋯,zt−1|x1,⋯,xt−1),我们按照自回归模型的特点,分解为
p(z1,⋯,zt−1|x1,⋯,xt−1)=t−1∏k=1p(zk|x1,⋯,xk)
这样每个p(zk|x1,⋯,xk)形式上就跟p(xt|z1,⋯,zt−1)一样了,于是同样可以用朴素贝叶斯建模。简单起见,zk我们定义为连续型变量,p(zk|x1,⋯,xk)则定义为狄拉克分布,于是积分可以直接算出来,结果就是两层Attention的堆叠了。
最后,Transfromer中还有一个关键成分是残差,实际上它就是将式(6)一般化为
logp(xt|x1,⋯,xt−1)=logp(xt|xt−1)+∑jat,jlogp(xt|xj)+常数
可以理解为一种突出了2-gram的地位的Pooling方式,算是一种先验。最后,还剩下的FeedForward层、LayerNorm层等,这些层不涉及token之间的交互,可以理解为是更复杂地参数化的朴素贝叶斯。
当然,这样笼统的解释看上去有些勉强,但笔者原本的想法,也不是精准地解释Transformer或Attention,而是期望是能从朴素贝叶斯角度来够获得一些关于长度外推的新思路。但很遗憾,目前笔者还没有得到预期的结果。然而,尽管看上去像是盲目的自恋,但笔者依然认为上述朴素贝叶斯和隐变量模型的视角还有进一步挖掘的潜力,比如看上去我们可以从朴素贝叶斯角度解释基于Attention的语言模型的In-Context Learning为啥会有效。
文章总概述 #
本文阐述了朴素贝叶斯与Attention机制之间的关联,显示了Attention可被视为一种广义的朴素贝叶斯。从这个视角,我们还可以进一步地理解Attention中的层叠与残差等内容。
转载到请包括本文地址:https://kexue.fm/archives/9648
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Jun. 08, 2023). 《Naive Bayes is all you need ? 》[Blog post]. Retrieved from https://kexue.fm/archives/9648
@online{kexuefm-9648,
title={Naive Bayes is all you need ?},
author={苏剑林},
year={2023},
month={Jun},
url={\url{https://kexue.fm/archives/9648}},
}
June 8th, 2023
minior typo in Eq.(12):
p(z1,⋯,zt−1|x1,⋯,zt−1)→p(z1,⋯,zt−1|x1,⋯,xt−1)
fixed, thanks.
June 8th, 2023
前几天Hacker News上刷到了一样的title: https://arxiv.org/abs/2306.00238
Bytes Are All You Need: Transformers Operating Directly On File Bytes
一个是bayes,一个是bytes,貌似差得比较远。。。
哈哈,确实,看错了
June 8th, 2023
[...]Read More [...]
June 8th, 2023
跟第一篇文章对比,公式(5)右边第一项少了一个上横线
跟那篇文章的公式(8)对应。
June 9th, 2023
苏神好,想问个可能比较幼稚的问题,式(2)中的独立性假设是如何得到的呢?感觉对语言模型来说可能难以分解成独立的形式。
可能在冗长的上下文中,马尔可夫性远不如上下字那么重要,在这种情况的下,朴素贝叶斯已经可以将其建模出来了。
但这种建模对于语句块的顺序是无法展示的,还是缺失了一定的马尔可夫性导致的。
不是如何得到独立性,而是作为一个近似;它的近似程度肯定是不够好的,所以后面提出了隐变量模型,提高近似程度。举个例子,就好比ex≈1+x,这是一个近似,为了提高这个近似精度,我们可以补充更多的项如12x2(残差),也可以用迭代(层叠)。
June 9th, 2023
想请教一下苏神,毕竟naive bayes太有名了,以前也有人尝试过用来解决autoregressive model的问题,比如:
Shekofteh, Y., & Almasganj, F. (2013). Autoregressive modeling of speech trajectory transformed to the reconstructed phase space for ASR purposes. Digital Signal Processing, 23(6), 1923-1932.
还有一些其他的尝试,就不一一举例了。
考虑到transformer也是一个自回归模型,这之间又有什么诡异的联系呢?
你这篇paper下载要收费,我暂未能阅读。
至于你后一个问题,是不是想问这个 https://kexue.fm/archives/9648 ?
这里有一个PDF的,你看你那边能不能下载:
https://d1wqtxts1xzle7.cloudfront.net/54361764/j.dsp.2013.06.01120170906-20541-12kc2al-libre.pdf?1504748542=&response-content-disposition=inline%3B+filename%3DAutoregressive_modeling_of_speech_trajec.pdf&Expires=1686670350&Signature=TAi2LfLqomKU1t50R8Hkgrx19ssJ5VJxXeBFoxWCjnnJsqkyoEIorrir9gSzPr9LutGaq~VExuCgE07umdorOF5n00-Cosr1yJktr--5HQitV62p0H3GYo5V3KV58MWoIsOX8Na1tZjafGTEmHDIawachPTeMlSDQsMdpc~xG3wGoW8r7JKbnoqWWsT6WKtCIF67M~e3d2WY3GDhWRAVBf2uPPARBGSw6QXvGw5uHo1KOBcPEB2zi-05FVRN0L3MazQDyd1Z68C6dQd8xzDWYcs7nbSlIY-2o8JO8tQF~mQ2g5dxFoOjcphAwC4EJGjoGGQ7BBWzQTd1t09W2~pYfQ__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA
下不了。
看标题,似乎没看出它跟naive bayes的关系,你可以稍微介绍一下它的思路。
我通读了一下,这篇和你做的研究应该无关:
作者的目的是提取语音信号的有效特征。
特征提取方法应用于语音信号的帧上,包括以下五个步骤:
将帧中的语音样本归一化到该帧所属的均值和方差数值。因此,归一化后的语音样本具有零均值和单位标准差。
将归一化后的语音样本嵌入到重构相空间(RPS)中,使用嵌入维度 m = 8 和时间滞后参数为 6,从而为每个帧形成一个 m 维轨迹。
通过多元自回归(MVAR)方法评估语音轨迹在RPS中的 P 个系数矩阵(其中 P 是LP模型的阶数)。通过这种方式,为每个帧评估了维度为 m × m 的 P 个矩阵。
使用线性判别分析(LDA)进行维度降低,以获得最终的特征向量。LDA技术被用于同时去相关化和降低特征集的维度。
最后,使用所得特征向量作为输入,应用朴素贝叶斯分类器(NBC)对孤立音素进行分类。NBC计算每个测试样本属于不同音素类别的后验概率,并选择具有最高后验概率的类别作为预测结果。
他们是前面用Multivariate Autoregressive模型,然后用线性判别分析(LDA)进行降维,然后送到朴素贝叶斯分类器(NBC)里面去分类。
因为用到的类似的数学工具,分析连续语音又和NLP有有些类似,数学描述上和你有些公式比较象,读过后完全不同。
好的谢谢。
June 9th, 2023
苏神要不要粉丝中招几个研究生,让学生把相关论文都扫描一遍,否则自己干太累了,以免撞上类似的公式或者硬件的数学化表述,我自愿当一个。
很惭愧,我没什么组织能力(捂脸),而且很多paper不自己读一下总感觉不放心。
paper是我自己读的我自己都不放心,就得看苏神的讲解(doge
June 12th, 2023
前段时间正巧也在研究用贝叶斯的框架来解释语言模型in context learning的一些学习和泛化现象的工作,当时看到这篇文章: An Explanation of In-context Learning as Implicit Bayesian Inference (https://arxiv.org/pdf/2111.02080.pdf)
感觉和苏神的直觉有相似的地方?
谢谢推荐,我抽空读读。
June 16th, 2023
关于公式10,有两个疑问想请教一下:
1. a_tj,为什么是关于x[t-1]和x[j]的函数呢?为什么不是x[t]和x[j]的?
2. Z_t,为什么j是从1到t-1的?为什么不是整个词表长度上的?这个疑问,我同样觉得原始Attention中注意力score做softmax,是不是应该针对整个词表?
1、xt是你要预测的,不是输入,最新的token我们顶多是xt−1;
2、我们想要的是对输入token进行加权平均,不是词表中所有token。
September 5th, 2023
请教一下13式残差的设计:x是token序列,而transformer里面的残差结构是在层内,为什么残差在13式中被建模为logp(xt|xt−1) 而不是与z有关?
它只是给出单个层的形式,多个层的复合是应该改成z相关。