多任务学习漫谈(一):以损失之名
By 苏剑林 | 2022-01-18 | 157894位读者 |能提升模型性能的方法有很多,多任务学习(Multi-Task Learning)也是其中一种。简单来说,多任务学习是希望将多个相关的任务共同训练,希望不同任务之间能够相互补充和促进,从而获得单任务上更好的效果(准确率、鲁棒性等)。然而,多任务学习并不是所有任务堆起来就能生效那么简单,如何平衡每个任务的训练,使得各个任务都尽量获得有益的提升,依然是值得研究的课题。
最近,笔者机缘巧合之下,也进行了一些多任务学习的尝试,借机也学习了相关内容,在此挑部分结果与大家交流和讨论。
加权求和 #
从损失函数的层面看,多任务学习就是有多个损失函数$\mathcal{L}_1,\mathcal{L}_2,\cdots,\mathcal{L}_n$,一般情况下它们有大量的共享参数、少量的独立参数,而我们的目标是让每个损失函数都尽可能地小。为此,我们引入权重$\alpha_1,\alpha_2,\cdots,\alpha_n\geq 0$,通过加权求和的方式将它转化为如下损失函数的单任务学习
\begin{equation}\mathcal{L} = \sum_{i=1}^n \alpha_i \mathcal{L}_i\label{eq:w-loss}\end{equation}
在这个视角下,多任务学习的主要难点就是如何确定各个$\alpha_i$了。
初始状态 #
按道理,在没有任务先验和偏见的情况下,最自然的选择就是平等对待每个任务,即$a_i=1/n$。然而,事实上每个任务可能有很大差别,比如不同类别数的分类任务混合、分类与回归任务混合、分类与生成任务混合等等,从物理的角度看,每个损失函数的量纲和量级都不一样,直接相加是没有意义的。
如果我们将每个损失函数看成具有不同量纲的物理量,那么从“无量纲化”的思想出发,我们可以用损失函数的初始值倒数作为权重,即
\begin{equation}\mathcal{L} = \sum_{i=1}^n \frac{\mathcal{L}_i}{\mathcal{L}_i^{(\text{init})}}\label{eq:init}\end{equation}
其中$\mathcal{L}_i^{(\text{init})}$表示任务$i$的初始损失值。该式关于每个$\mathcal{L}_i$是“齐次”的,所以它的一个明显优点是缩放不变性,即如果让任务$i$的损失乘上一个常数,那么结果不会变化。此外,由于每个损失都除以了自身的初始值,较大的损失会缩小,较小的损失会放大,从而使得每个损失能够大致得到平衡。
那么,怎么估计$\mathcal{L}_i^{(\text{init})}$呢?最直接的方法当然是直接拿几个batch的数据来估算一下。除此之外,我们可以基于一些假设得到一个理论值。比如,在主流的初始化之下,我们可以认为初始模型(加激活函数之前)的输出是一个零向量,如果加上softmax则是均匀分布,那么对于一个“$K$分类+交叉熵”问题,它的初始损失就是$\log K$;对于“回归+L2损失”问题,则可以用零向量来估计初始损失,即$\mathbb{E}_{y\sim \mathcal{D}}[\Vert y-0\Vert^2] = \mathbb{E}_{y\sim \mathcal{D}}[\Vert y\Vert^2]$,$\mathcal{D}$是训练集的全体标签。
先验状态 #
用初始损失的一个问题是初始状态不一定能很好地反应当前任务的学习难度,更好的方案应该是将“初始状态”改为“先验状态”:
\begin{equation}\mathcal{L} = \sum_{i=1}^n \frac{\mathcal{L}_i}{\mathcal{L}_i^{(\text{prior})}}\label{eq:prior}\end{equation}
比如,如果$K$分类中每个类的频率分别是$[p_1,p_2,\dots,p_K]$(先验分布),那么虽然初始状态的预测分布为均匀分布,但我们可以合理地认为模型可以很容易学会将每个样本的结果都预测为$[p_1,p_2,\dots,p_K]$,此时模型的损失为熵
\begin{equation}\mathcal{L}_i^{(\text{prior})}=\mathcal{H} = -\sum_{i=1}^K p_i\log p_i\end{equation}
某种意义上来说,“先验分布”比“初始分布”更能体现出“初始”的本质,它是“就算模型啥都学不会,也知道按照先验分布来随机出结果”的体现,所以此时的损失值更能代表当前任务的初始难度,因此用$\mathcal{L}_i^{(\text{prior})}$代替$\mathcal{L}_i^{(\text{init})}$应该更加合理;类似地,对于“回归+L2损失”问题,它的先验结果应该是全体标签的期望$\mu = \mathbb{E}_{y\sim \mathcal{D}}[y]$,所以我们用$\mathcal{L}_i^{(\text{prior})}=\mathbb{E}_{y\sim \mathcal{D}}[\Vert y-\mu\Vert^2]$代替$\mathcal{L}_i^{(\text{init})}=\mathbb{E}_{y\sim \mathcal{D}}[\Vert y\Vert^2]$,有望取得更合理的结果。
动态调节 #
不管是用初始状态的式$\eqref{eq:init}$还是先验状态的式$\eqref{eq:prior}$,它们的任务权重在确定之后就保持不变了,并且它们确定权重的方法不依赖于学习过程。然而,尽管我们可以通过先验分布等信息简单感知一下学习难度,但究竟有多难其实要真正去学习才知道,所以更合理的方案应该是根据训练进程动态地调整权重。
实时状态 #
纵观前文,式$\eqref{eq:init}$和式$\eqref{eq:prior}$的核心思想都是用损失值的倒数来作为任务权重,那么能不能干脆用“实时”的损失值倒数来实现动态调整权重?即
\begin{equation}\mathcal{L} = \sum_{i=1}^n \frac{\mathcal{L}_i}{\mathcal{L}_i^{(\text{sg})}}\label{eq:sg}\end{equation}
这里的$\mathcal{L}_i^{(\text{sg})}$是$\text{stop_gradient}(\mathcal{L}_i)$的简写。在这个方案中,每个任务的损失函数都被调整恒为1,所以不管是量纲还是量级上都是一致的。由于$\text{stop_gradient}$算子的存在,虽然损失恒为1,但梯度并非恒为0:
\begin{equation}\nabla_{\theta}\left(\frac{\mathcal{L}_i}{\mathcal{L}_i^{(\text{sg})}}\right) = \frac{\nabla_{\theta}\mathcal{L}_i}{\mathcal{L}_i^{(\text{sg})}} = \frac{\nabla_{\theta}\mathcal{L}_i}{\mathcal{L}_i}\label{eq:sg-grad}\end{equation}
简单来说就是某个函数被$\text{stop_gradient}$算子包住后,就变成了一个新函数,其值与原来的函数恒等,但是它的导函数强制设为了0,所以最终结果就是以动态权重$1/\mathcal{L}_i$来实时调整了梯度的比例。很多“民间实验”表明,式$\eqref{eq:sg}$确实在多数情况下都可以作为一个相当不错的baseline。
等价梯度 #
我们可以从另一个角度来看该方案。从式$\eqref{eq:sg-grad}$我们可以得到
\begin{equation}\nabla_{\theta}\left(\frac{\mathcal{L}_i}{\mathcal{L}_i^{(\text{sg})}}\right) = \frac{\nabla_{\theta}\mathcal{L}_i}{\mathcal{L}_i} = \nabla_{\theta} \log \mathcal{L}_i\end{equation}
因此从梯度上看,式$\eqref{eq:sg}$与$\mathcal{L} = \sum\limits_{i=1}^n \log \mathcal{L}_i$没有实质区别,而我们进一步有
\begin{equation}\mathcal{L} = \sum_{i=1}^n \log \mathcal{L}_i = n\log \sqrt[n]{\prod_{i=1}^n\mathcal{L}_i}\end{equation}
由于$\log$是单调递增的,所以式$\eqref{eq:sg}$与下式在梯度方向上是一致:
\begin{equation}\mathcal{L} = \sqrt[n]{\prod_{i=1}^n\mathcal{L}_i}\end{equation}
广义平均 #
显然,上式正是$\mathcal{L}_1,\mathcal{L}_2,\cdots,\mathcal{L}_n$的“几何平均”,而如果我们约定$a_i$恒等于$1/n$,那么原始的式$\eqref{eq:w-loss}$就是$\mathcal{L}_1,\mathcal{L}_2,\cdots,\mathcal{L}_n$的“代数平均”。也就是说,我们发现这一系列的推导其实隐藏了从代数平均到几何平均的转变,这启发我们或许可以考虑“广义平均”:
\begin{equation}\mathcal{L}(\gamma) = \sqrt[\gamma]{\frac{1}{n}\sum_{i=1}^n\mathcal{L}_i^{\gamma}}\end{equation}
也就是将每个损失函数算$\gamma$次方后再平均最后再开$\gamma$次方,这里的$\gamma$可以是任意实数,代数平均对应$\gamma=1$,而几何平均对应$\gamma=0$(需要取极限)。可以证明,$\mathcal{L}(\gamma)$是关于$\gamma$的单调递增函数,并且有
\begin{equation}\min(\mathcal{L}_1,\cdots,\mathcal{L}_n)=\lim_{\gamma\to-\infty} \mathcal{L}(\gamma) \leq\cdots\leq \mathcal{L}(\gamma) \leq\cdots\leq \lim_{\gamma\to+\infty} \mathcal{L}(\gamma)=\max(\mathcal{L}_1,\cdots,\mathcal{L}_n)\end{equation}
这就意味着,当$\gamma$增大时,模型愈发关心损失中的最大值,反之则更关心损失中的最小值。这样一来,虽然依然存在超参数$\gamma$要调整,但是相比于原始的式$\eqref{eq:w-loss}$,超参数的个数已经从$n$个变为只有1个,简化了调参过程。
平移不变 #
重新回顾式$\eqref{eq:init}$、式$\eqref{eq:prior}$和式$\eqref{eq:sg}$,它们都是通过每个任务损失除以自身的某个状态来调节权重,并且获得了缩放不变性。然而,尽管它们都具备了缩放不变性,但却失去了更基本的“平移不变性”,也就是说,如果每个损失都加上一个常数,$\eqref{eq:init}$、式$\eqref{eq:prior}$和式$\eqref{eq:sg}$的梯度方向是有可能改变的,这对于优化来说并不是一个好消息,因为原则上来说常数没有带来任何有意义的信息,优化结果不应该随之改变。
理想目标 #
一方面,我们用损失函数(的某个状态)的倒数作为当前任务的权重,但损失函数的导数不具备平移不变性;另一方面,损失函数可以理解为当前模型与目标状态的距离,而梯度下降本质上是在寻找梯度为0的点,所以梯度的模长其实也能起到类似作用,因此我们可以用梯度的模长来替换掉损失函数,从而将式$\eqref{eq:sg}$变成
\begin{equation}\mathcal{L} = \sum_{i=1}^n \frac{\mathcal{L}_i}{\Vert\nabla_{\theta}\mathcal{L}_i\Vert^{(\text{sg})}}\label{eq:grad}\end{equation}
跟损失函数的一个明显区别是,梯度模长显然具备平移不变性,并且分子分母关于$\mathcal{L}_i$依然是齐次的,所以上式还保留了缩放不变性。因此,这是一个能同时具备平移和缩放不变性的理想目标。
梯度归一 #
对式$\eqref{eq:grad}$求梯度,我们得到
\begin{equation}\nabla_{\theta}\mathcal{L} = \sum_{i=1}^n \frac{\nabla_{\theta}\mathcal{L}_i}{\Vert\nabla_{\theta}\mathcal{L}_i\Vert}\label{eq:grad-norm}\end{equation}
可以看到,式$\eqref{eq:grad}$本质上是将每个任务损失的梯度进行归一化后再把梯度累加起来。它同时也告诉了我们一种实现方案,即可以让每个任务依次训练,每次只训练一个任务,然后将每个任务的梯度归一化后累积起来再更新,这样就免除了在定义损失函数的时候就要算梯度的麻烦了。
关于梯度归一化,笔者能找到相关工作是《GradNorm: Gradient Normalization for Adaptive Loss Balancing in Deep Multitask Networks》,它本质上是式$\eqref{eq:init}$和式$\eqref{eq:grad-norm}$的混合,里边也包含了对梯度模长重新标定的思想,但却要通过额外的优化来确定任务权重,个人认为显得繁琐和冗余了。
本文小结 #
在损失函数的视角下,多任务学习的关键问题是如何调节每个任务的权重来平衡各自的损失,本文从缩放不变和平移不变两个角度介绍了一些参考做法,并补充了“广义平均”的概念,将多个任务的权重调节转化为单个参数的调节问题,可以简化调参难度。
转载到请包括本文地址:https://kexue.fm/archives/8870
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Jan. 18, 2022). 《多任务学习漫谈(一):以损失之名 》[Blog post]. Retrieved from https://kexue.fm/archives/8870
@online{kexuefm-8870,
title={多任务学习漫谈(一):以损失之名},
author={苏剑林},
year={2022},
month={Jan},
url={\url{https://kexue.fm/archives/8870}},
}
January 19th, 2022
试了sum(log Li) 确实有用,期待 多任务学习漫谈(二)和公式(12)的实现。
January 20th, 2022
苏神,请教下12式的方法,如果是像bert这样的模型,每次要对110m的参数求解模长是不是太耗时了
接下来的“梯度归一”一节就说了,式$\eqref{eq:grad}$的最佳实现方式应该是梯度归一化加梯度累积,这只需要通过修改优化器来实现,不会明显增加计算量。
January 20th, 2022
「式(2)和式(3)的核心思想都是用损失值的倒数来作为任务权重」这个里边是不是「导数」而不是「倒数」?
很显然是“倒数”。
January 21st, 2022
苏神,(7)式子感觉有点疑惑,所以想请教一下
分母的L^{sg}我理解是不是一个确实的值了,虽然他的值等于L_{i},但是它已经不能对参数theta进行求导操作了,也就不能转成log了。假设进行梯度归一化后,logL_{i}=1,那么L_{i}=e,那么梯度就永远不会改变啦?望苏神轻喷
首先,$\text{stop_gradient}$是一个硬性的运算规则,即值不变,但是导数为0。所以式$\eqref{eq:sg-grad}$成立没问题吧?第二个等号是因为我不需要再求梯度了,所以$\mathcal{L}_i^{(\text{sg})}$跟$\mathcal{L}_i$没有任何区别。
另一方面$\nabla_{\theta} \log \mathcal{L}_i = \frac{\nabla_{\theta}\mathcal{L}_i}{\mathcal{L}_i}$也没问题吧,这是标准的复合函数求导而已。
如果上面都没有问题,所以$(7)$成立,不是很显然的吗?
苏神,还是觉得有点奇怪,我理解logL_{i}跟logL_{sg}值相等,这两者的区别是L_{sg}应该是不进行反向传播,而L_{i}会进行反向传播,那么(7)式中间的式子分子L_{i}对参数theta可导,而分母的L_{i}也就是L_{sg}对参数theta不可导(导数为0),那么分子可导,分母不可导,就不能用复式求导法则
你不要理解成可导不可导,你要理解为强制让导数为0,其他运算规则完全不变。
根据公式
$$\left(\frac{f(x)}{g(x)}\right)'=\frac{f'(x)g(x) - f(x)g'(x)}{g(x)^2}$$
强制让$g(x)$的导数$g'(x)$为0,结果就是
$$\frac{f'(x)}{g(x)}$$
所以也就意味着
$$\left(\frac{f(x)}{\text{sg}[g(x)]}\right)'=\frac{f'(x)}{g(x)}$$
再次强调,只是强制导数为0,其余都不变。
谢谢苏神,我模拟了一把,再结合你的公式推导,理解了
import tensorflow as tf
a = tf.Variable(1.0)
b = tf.Variable(3.0)
c = tf.add(a, b)
c_stoped = tf.stop_gradient(c)
d = tf.add(a, b)
e = tf.divide(d, c_stoped)
gradients_e = tf.gradients(e, xs=[a, b])
with tf.Session() as sess:
tf.global_variables_initializer().run()
print('(6)式结果:', sess.run(gradients_e))
a = tf.Variable(1.0)
b = tf.Variable(3.0)
c = tf.add(a, b)
d = tf.log(tf.add(a, b))
gradients_e = tf.gradients(d, xs=[a, b])
with tf.Session() as sess:
tf.global_variables_initializer().run()
print('(7)式结果:', sess.run(gradients_e))
(6)式结果: [0.25, 0.25]
(7)式结果: [0.25, 0.25]
恭喜~
感谢提问,在提问区整明白了
January 21st, 2022
苏神,你好,拜读完这篇文章深受启发,但我比较菜,有一个小白问题,就是多任务学习一般来说是为了学习到共有的知识或是更好地表示,那么对于像bert这种模型,我在微调阶段用多任务学习好吗?还是说多任务学习就应该用在有大量多个相关任务的标注数据时做才好?
理论上来说,这篇文章只是介绍“怎么做好多任务”,至于“应不应该做多任务”不是本文的主题。
当然,就你这个问题而言,多任务最理想的状态是“差异+互补”,即多个任务之间有差异,但也有明显的逻辑关联,可能通过多任务训练达到互补的效果。至于这种关联如何定量描述,我认为很难,所以我也没有很大的答案。
January 21st, 2022
您好,有个疑惑想要请教一下,请问如果同时使用式(5)和梯度裁剪的话,个人感觉梯度裁剪会对原本式(5)调整的梯度比例有影响,想问下您的看法。如果不用梯度裁剪缓解梯度爆炸的话,是否只能依靠细调学习率或者warm up去解决。
说实话,我认为梯度爆炸本身属于模型设计不合理的问题,所以我认为正确的处理方法应该是修改好模型(尤其是初始化)来修正梯度消失后,再来进行多任务或其他做法。
至于梯度裁剪,其实我没用过,所以我也没经验。但我觉得,如果真的出现了梯度爆炸,那么也许$(13)$式更适合你。
January 21st, 2022
如果不用多任务,交替训练(先训练任务A再训练任务B),因为调优是丢掉最后一层,如果不为了提升预测效果,不同任务同一bert模型迁移学习是不是也可以?
如果交替训练,并且能共用一个优化器(同一参数的滑动平均量也共享),那么多任务和交替训练其实区别不大。但如果优化器的参数不共享,那么我认为区别还是很大的,交替训练更容易过拟合。
January 28th, 2022
公式(13)中的梯度归一方法公式感觉有点问题。Loss一般在Loss下降后回传的梯度也是应该在下降(不同的Loss计算方式,对于不同的误差的惩罚程度是不太一样的,体现方式就是回传的梯度),如果像梯度归一的这种方式,不管梯度是大是小,都被归一化了,感觉失去了动态容忍关注的能力,对于一些已经计算得很好的样本,只要梯度没有归0,都会在梯度上依然造成很大的影响。我可能理解得有点问题,请苏神解惑。
显然梯度下降的更新量直接依赖于梯度,而梯度归一目的是让每个任务的梯度在混合之前保持一致的大小,不至于某些任务被其他任务“压倒”,符合每个任务都同等重要的初衷。
当然,梯度归一化后,直觉上的后果就是“更新几乎不会停止了”,因为每个任务的梯度都被归一化,无法自然地趋于0,所以我们要通过手动控制学习率来迫使更新停止。事实上诸如Adam的优化器本身也有类似的问题,它的分子分母都是关于梯度齐次的,所以理论上更新量就是常数量级的,也要手动调整学习率,所以这也不是一个很大的问题。
我还是感觉这种处理方式不是特别好,每个样本不管是否训练得已经特别好都会产生同等的作用,就算手动调整学习率这个问题还是会存在的,总感觉每个项前面应该再乘一个与Loss相关的调节系数,来平衡训练样本在不同阶段所起的作用。
不是归一化每个样本的梯度,是归一化每个任务的梯度,将每个任务平等对待。
January 29th, 2022
苏神大佬,有个疑惑请教一下,(7)式对于Li小于1的情况可以适用么?例如:
ctr_loss = tf.reduce_mean(tf.log(ctr_loss), 0)
ctcvr_loss = tf.reduce_mean(tf.log(ctcvr_loss), 0)
loss = tf.add(ctr_loss, ctcvr_loss, name="total_loss")
这里ctr_loss,ctcvr_loss都是小于1大于0的,然后我对loss进行梯度优化的过程会出现nan异常,所以这里不是很理解,谢谢.
是要每个任务算出自己的总loss后,再log然后求和。对于上述代码,即先reduce_mean然后再log。如果要防止nan,log的时候要加个epsilon。
February 8th, 2022
请教苏神,还是没有理解(12)式为什么具有平移不变性?对梯度做了归一之后,不同任务的梯度sum起来方向还是会变呀
什么叫做平移不变性?指的是损失函数加上一个常数,梯度不会产生变化。
然而,尽管它们都具备了缩放不变性,但却失去了更基本的“平移不变性”,也就是说,如果每个损失都加上一个常数,(2)、式(3)和式(5)的梯度方向是有可能改变的
-------------------------------------------------------
原文里面是这么说的,按我理解就是说公式12用梯度模长做分母的话梯度方向不变,这一点还是有些疑惑
“按我理解就是说公式12用梯度模长做分母的话梯度方向不变”,你觉得你这句话没毛病吗?什么叫做“用梯度模长做分母的话梯度方向不变”?什么叫做“不变”?
不变是指“在某些操作之下某些结果不会改变”。所以你要回答“某些操作”是什么,“某些结果”是什么。对于缩放不变性,“某些操作”就是缩放变换,“某些结果”就是梯度;对于平移不变形,“某些操作”就是平移变换,“某些结果”也是梯度。也就是说,我希望损失函数乘上一个非零常数,或者加上任意常数,梯度都不会变化。
“用梯度模长做分母的话梯度方向不变”,在你这句话中,“某些操作”是什么?这句话根本就逻辑不通吧...