初始化方法中非方阵的维度平均策略思考
By 苏剑林 | 2021-10-18 | 30605位读者 |在《从几何视角来理解模型参数的初始化策略》、《浅谈Transformer的初始化、参数化与标准化》等文章,我们讨论过模型的初始化方法,大致的思路是:如果一个$n\times n$的方阵用均值为0、方差为$1/n$的独立同分布初始化,那么近似于一个正交矩阵,使得数据二阶矩(或方差)在传播过程中大致保持不变。
那如果是$m\times n$的非方阵呢?常见的思路(Xavier初始化)是综合考虑前向传播和反向传播,所以使用均值为0、方差为$2/(m+n)$的独立同分布初始化。但这个平均更多是“拍脑袋”的,本文就来探究一下有没有更好的平均方案。
基础回顾 #
Xavier初始化是考虑如下的全连接层(设输入节点数为$m$,输出节点数为$n$)
\begin{equation} y_j = b_j + \sum_i x_i w_{i,j}\end{equation}
其中$b_j$一般初始化为0,$w_{i,j}$的初始化均值一般也为0,在《浅谈Transformer的初始化、参数化与标准化》中我们已经算得
\begin{equation}
\mathbb{E}[y_j^2] = \sum_{i} \mathbb{E}[x_i^2] \mathbb{E}[w_{i,j}^2]= m\mathbb{E}[x_i^2]\mathbb{E}[w_{i,j}^2]\end{equation}
所以为了保持二阶矩不变,我们将$w_{i,j}$的初始化方差设为$1/m$(均值为0时,方差等于二阶矩)。
但这个推导还只是考虑了前向传播,我们还需要使得模型有合理的梯度,那么还要使得模型在反向传播时也保持稳定。假设模型的损失函数为$l$,根据链式法则我们有
\begin{equation}\frac{\partial l}{\partial x_i} = \sum_j \frac{\partial l}{\partial y_j} \frac{\partial y_j}{\partial x_i}=\sum_j \frac{\partial l}{\partial y_j} w_{i,j}\end{equation}
注意这时是对$j$求和,求和的维度为$n$,所以在相同的假设下有
\begin{equation}
\mathbb{E}\left[\left(\frac{\partial l}{\partial x_i}\right)^2\right] = \sum_{j} \mathbb{E}\left[\left(\frac{\partial l}{\partial y_j}\right)^2\right] \mathbb{E}[w_{i,j}^2]= n \mathbb{E}\left[\left(\frac{\partial l}{\partial y_j}\right)^2\right]\mathbb{E}[w_{i,j}^2]\end{equation}
所以要保持反向传播的二阶矩不变,我们将$w_{i,j}$的初始化方差设为$1/n$。
一个是$1/m$,一个$1/n$,当$m\neq n$时就有冲突,但两个都同样重要,所以Xavier初始化就直接将两个维度平均一下,以$2/(m+n)$为方差进行初始化。
几何平均 #
现在让我们来考虑两个复合的全连接层(暂时忽略偏置项):
\begin{equation} y = xW_1 W_2
\end{equation}
其中$x\in\mathbb{R}^m,W_1\in\mathbb{R}^{m\times n},W_2\in\mathbb{R}^{n\times m}$,也就是说,输入是$m$维,变换为$n$维后再变换回$m$维,类似的操作比如BERT的FFN层(但FFN层中间多了个激活函数)。
根据前向传播的稳定性,我们应该要用$1/m$的方差初始化$W_1$、用$1/n$的方差初始化$W_2$。但是,如果我们要求$W_1$和$W_2$必须用同一方差初始化呢?那么很显然,为了保证$x,y$的方差不变,$W_1,W_2$都需要用方差为$1/\sqrt{mn}$的分布来初始化。如果考虑反向传播时,结果是相同的。
这样一来,我们就导出了一个新的维度平均策略:几何平均$\sqrt{mn}$。通过这个维度平均策略,我们可以使得在多层网络复合的时候,如果输入输出维度不变,那么方差就保持不变(不管前向传播还是反向传播)。而如果是代数平均$(m+n)/2$,假设$m < n$,那么根据$(m+n)^2/4\geq mn$,在前向/反向传播的时候方差就会缩小了。
二次平均 #
另外一个思考的角度是作为一个双重最小化问题:假设选用的方差为$t$,在前向传播时我们希望$(mt-1)^2$尽可能小,在反向传播时我们则希望$(nt-1)^2$尽可能小,所以综合考虑
\begin{equation}(mt-1)^2 + (nt-1)^2
\end{equation}
当$t=(m+n)/(m^2+n^2)$时,上式取到最小值,所以这得到了一个二次分式的平均方案:$(m^2+n^2)/(m+n)$。
容易证明:
\begin{equation}\frac{m^2+n^2}{m+n} \geq \frac{m+n}{2}\geq \sqrt{mn}\end{equation}
从推导过程上来看,左端的二次平均是希望每一步前向和反向传播的方差尽可能不变,因此可以认为左端是一个局部最优解;而右端的几何平均,则是希望“最初的输入”和“最终的输出”的方差尽量不变,因此可以认为右端某种意义上来说是一个全局最优解;而中间的代数平均,则是介乎全局最优和局部最优之间的一个解。
如此看来,似乎Xavier初始化“拍脑袋”的代数平均也不失为一个“中庸之道”的选择?
文章小结 #
本文简单思考了初始化方法中非方阵的维度平均方案,一直以来,大家似乎对默认的代数平均都没有什么疑问,而笔者从两种不同的角度得出了不同的平均策略的可能性。至于哪种平均策略更好,笔者也没有仔细做实验,有兴趣的读者自行尝试就好。当然,也可能在当前诸多优化策略之下,默认的初始化方案也工作得很好了,也就没有仔细调节的必要性了。
转载到请包括本文地址:https://kexue.fm/archives/8725
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Oct. 18, 2021). 《初始化方法中非方阵的维度平均策略思考 》[Blog post]. Retrieved from https://kexue.fm/archives/8725
@online{kexuefm-8725,
title={初始化方法中非方阵的维度平均策略思考},
author={苏剑林},
year={2021},
month={Oct},
url={\url{https://kexue.fm/archives/8725}},
}
October 20th, 2021
梦回均值不等式
April 27th, 2022
你好,在第二部分《几何平均》中,当$W_1,W_2$方差均为$\frac{m+n}2$时,前向传播的方差为何会减小呢,按照第一部分的推导方差$\mathbb{E}[y_j^2]$也会增大。
想了想,确实是我的问题,但正确的应该是前向传播和反向传播都会减小才对。