通过互信息思想来缓解类别不平衡问题
By 苏剑林 | 2020-07-19 | 182261位读者 |类别不平衡问题,也叫“长尾问题”,是机器学习面临的常见问题之一,尤其是来源于真实场景下的数据集,几乎都是类别不平衡的。大概在两年前,笔者也思考过这个问题,当时正好对“互信息”相关的内容颇有心得,所以构思了一种基于互信息思想的解决办法,但又想了一下,那思路似乎过于平凡,所以就没有深究。然而,前几天在arxiv上刷到Google的一篇文章《Long-tail learning via logit adjustment》,意外地发现里边包含了跟笔者当初的构思几乎一样的方法,这才意识到当初放弃的思路原来还能达到SOTA的水平~于是结合这篇论文,将笔者当初的构思过程整理于此,希望不会被读者嫌弃“马后炮”。
问题描述 #
这里主要关心的是单标签的多分类问题,假设有1,2,⋯,K共K个候选类别,训练数据为(x,y)∼D,建模的分布为pθ(y|x),那么我们的优化目标是最大似然,或者说最小化交叉熵,即
argminθE(x,y)∼D[−logpθ(y|x)]
通常来说,我们建立的概率模型最后一步都是softmax,假设softmax之前的结果为f(x;θ)(即logits),那么
−logpθ(y|x)=−logefy(x;θ)K∑i=1efi(x;θ)=log[1+∑i≠yefi(x;θ)−fy(x;θ)]
所谓类别不均衡,就是指某些类别的样本特别多,就好比“20%的人占据了80%的财富”一样,剩下的类别数很多,但是总样本数很少,如果从高到低排序的话,就好像带有一条很长的“尾巴”,所以叫做长尾现象。这种情况下,我们训练的时候采样一个batch,很少有机会采样到低频类别,因此很容易被模型忽略了低频类。但评测的时候,通常我们又更关心低频类别的识别效果,这便是矛盾之处。
常见思路 #
常见的思路大家应该也有所听说,大概就是三个方向:
1、从数据入手,通过过采样或降采样等手段,使得每个batch内的类别变得更为均衡一些;
2、从loss入手,经典的做法就是类别y的样本loss除以类别出现的频率p(y);
3、从结果入手,对正常训练完的模型在预测阶段做些调整,更偏向于低频类别,比如正样本远少于负样本,我们可以把预测结果大于0.2(而不是0.5)都视为正样本。
Google的原论文中对这三个方向的思路也列举了不少参考文献,有兴趣调研的读者可以直接阅读原论文,另外,知乎上的文章《Long-Tailed Classification (2) 长尾分布下分类问题的最新研究》也对该问题进行了介绍,读者也可以参考阅读。
学习互信息 #
回想一下,我们是怎么断定某个分类问题是不均衡的呢?显然,一般的思路是从整个训练集里边统计出各个类别的频率p(y),然后发现p(y)集中在某几个类别中。所以,解决类别不平衡问题的重点,就是如何把这个先验知识p(y)融入模型之中。
在之前构思词向量模型(如文章《更别致的词向量模型(二):对语言进行建模》)的时候,我们就强调过,相比拟合条件概率,如果模型能直接拟合互信息,那么将会学习到更本质的知识,因为互信息才是揭示核心关联的指标。但是拟合互信息没那么容易训练,容易训练的是条件概率,直接用交叉熵−logpθ(y|x)进行训练就行了。所以,一个比较理想的想法就是:如何使得模型依然使用交叉熵为loss,但本质上是在拟合互信息?
在公式(2)中,我们是建模了
pθ(y|x)=efy(x;θ)K∑i=1efi(x;θ)
现在我们改为建模互信息,那么也就是希望
logpθ(y|x)p(y)∼fy(x;θ)⇔logpθ(y|x)∼fy(x;θ)+logp(y)
按照右端的形式重新进行softmax归一化,那么就有pθ(y|x)=efy(x;θ)+logp(y)K∑i=1efi(x;θ)+logp(i),或者写成loss形式:
−logpθ(y|x)=−logefy(x;θ)+logp(y)K∑i=1efi(x;θ)+logp(i)=log[1+∑i≠yp(i)p(y)efi(x;θ)−fy(x;θ)]
原论文称之为logit adjustment loss。如果更加一般化,那么还可以加个调节因子τ:
−logpθ(y|x)=−logefy(x;θ)+τlogp(y)K∑i=1efi(x;θ)+τlogp(i)=log[1+∑i≠y(p(i)p(y))τefi(x;θ)−fy(x;θ)]
一般情况下,τ=1的效果就已经接近最优了。如果fy(x;θ)的最后一层有bias项的话,那么最简单的实现方式就是将bias项初始化为τlogp(y)。也可以写在损失函数中:
import numpy as np
import keras.backend as K
def categorical_crossentropy_with_prior(y_true, y_pred, tau=1.0):
"""带先验分布的交叉熵
注:y_pred不用加softmax
"""
prior = xxxxxx # 自己定义好prior,shape为[num_classes]
log_prior = K.constant(np.log(prior + 1e-8))
for _ in range(K.ndim(y_pred) - 1):
log_prior = K.expand_dims(log_prior, 0)
y_pred = y_pred + tau * log_prior
return K.categorical_crossentropy(y_true, y_pred, from_logits=True)
def sparse_categorical_crossentropy_with_prior(y_true, y_pred, tau=1.0):
"""带先验分布的稀疏交叉熵
注:y_pred不用加softmax
"""
prior = xxxxxx # 自己定义好prior,shape为[num_classes]
log_prior = K.constant(np.log(prior + 1e-8))
for _ in range(K.ndim(y_pred) - 1):
log_prior = K.expand_dims(log_prior, 0)
y_pred = y_pred + tau * log_prior
return K.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)
结果分析 #
很明显logit adjustment loss也属于调整loss方案之一,不同的是它是在log里边调整权重,而常规的思路则是在log外调整。至于它的好处,就是互信息的好处:互信息揭示了真正重要的关联,所以给logits补上先验分布的bias,能让模型做到“能靠先验解决的就靠先验解决,先验解决不了的本质部分才由模型解决”。
在预测阶段,根据不同的评测指标,我们可以制定不同的预测方案。从《函数光滑化杂谈:不可导函数的可导逼近》可以知道,对于整体准确率而言,我们有近似
整体准确率≈1NN∑i=1pθ(yi|xi)
其中{(xi,yi)}Ni=1是验证集。所以如果不考虑类别不均衡情况,追求更高的整体准确率,那么对于每个x我们直接输出pθ(y|x)最大的类别即可。但如果我们希望每个类的准确率都尽可能高,那么我们将上式改写成
整体准确率≈1NN∑i=1pθ(yi|xi)p(yi)×p(yi)=K∑y=1p(y)(1N∑xi∈Ωypθ(y|xi)p(y))
其中Ωy={xi|yi=y,i=1,2,⋯,N},也标签为y的x的集合,等号右边事实上就是先将同一个y的项合并起来。我们知道“整体准确率=每一类的准确率的加权平均”,而上式正好具有同样的形式,所以括号里边的1N∑xi∈Ωypθ(y|xi)p(y)就是“每一类的准确率”的一个近似了,因此,如果我们希望每一类的准确率都尽可能高,我们则要输出使得pθ(y|x)p(y)最大的类别(不加权)。结合pθ(y|x)的形式,我们有结论
y∗={argmaxyfy(x;θ)+τlogp(y),(追求整体准确率)argmaxyfy(x;θ),(希望每一类的准确率都尽可能均匀)
第一种其实就是输出条件概率最大者,而第二种就是输出互信息最大者,按具体需求选择。
至于详细的实验结果,大家可以自行看论文,总之就是好到有点意外:
文章小结 #
本文简单介绍了一种基于互信息思想的类别不平衡处理办法,该方案以前笔者也曾经构思过,不过没有深究,而最近Google的一篇论文也给出了同样的方法,遂在此简单记录分析一下,最后Google给出的实验结果显示该方法能达到SOTA的水平。
转载到请包括本文地址:https://kexue.fm/archives/7615
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Jul. 19, 2020). 《通过互信息思想来缓解类别不平衡问题 》[Blog post]. Retrieved from https://kexue.fm/archives/7615
@online{kexuefm-7615,
title={通过互信息思想来缓解类别不平衡问题},
author={苏剑林},
year={2020},
month={Jul},
url={\url{https://kexue.fm/archives/7615}},
}
December 31st, 2020
这个loss我是不是可以这么直观的理解:对于多数类样本,在其logit上扣除一个较小的值;对于少数类样本在其logit上扣除较大的值。这样的效果,难道不该造成分类更容易偏向多数类样本么?苏神可否解释一下,上面有评论提到− τ\log π(y)的,我怎么觉得反而负号的话更说得通一些
训练用的是f_y(x;\theta)+\tau\log p(y),预测用的是f_y(x;\theta),你也可以理解为训练没变,预测减去了\tau\log p(y)。
可是原论文表3(也就是博主文章里最后那个表)里的试验结果似乎表示,post-hoc的方法还不如不post-hoc?是我理解有误么
从表3没看出哪里显示“post-hoc的方法还不如不post-hoc”。
表中最后第二行和第三行既然写明了post-hoc,而表的最后一行没有写上post-hoc,那应该指的是最后一行是没用post-hoc的logit adjusment方法。而最后一行的效果也是最好的,不是说明不加post-hoc反而会更有效吗
那是你理解错人家的表格了...
logit adjusment是一个统称,它包含两种方式:1、训练用的是f_y(x;\theta)+\tau\log p(y),预测用的是f_y(x;\theta),这种方式称为logit adjusment loss;2、训练用f_y(x;\theta),预测则用f_y(x;\theta)-\tau\log p(y),这种方式称为logit adjusment post-hoc。
所以表中后三行都显示logit adjusment的优势,后三行之间没有什么相互包含关系。
多谢博主耐心解答。不过还是有个疑惑,如果是这样,post-hoc与否照理说没什么差别吧,表格后三行表明post-hoc了是有不少区别的,论文里似乎也没提差别的原因...
@三五大锅|comment-15169
post-hoc和loss是实现logit adjusment的两种方式,区别肯定是有,但是多大区别、孰优孰劣,很难从理论上给出定量分析。就好比普通模型y=f(x)与残差模型y=x+f(x),假设f(x)具有万能拟合能力,那么两者理论上就是没区别的,但实际上效果差别大得很。
对于post-hoc,为了达到最优效果,需要精调\tau得到\tau^*,但这实际上不好实现。从表格上来看,还是loss的方式对超参数更为鲁棒一些。
April 4th, 2021
公式4是不是写错了?符号右边 应该是减去logp(y)吧。
没有写错。
(说“错”是需要理由的,不是凭感觉的。式(4)的左右两端就是可以互推的,哪里错了?跟你的直觉相反就是错的?)
我的意思是公式4中的波浪线是什么含义?是代表波浪线左边的服从右边的分布吗?还是正比于的意思?这个地方没太看明白。
\sim是正相关的意思,在这里你可以理解为正比于。
多谢苏老师,苏老师确实思考问题深入浅出,通俗易懂,在下实在佩服。
谢谢
April 6th, 2021
另外问下,函数logit adjustment的pytorch实现要怎么写?prior这个要如何设置呢?
prior直接用极大似然估计计算出来,这个先验P(y)懂了。
自己根据公式实现。
我已经知道怎么实现了,多谢苏老师,要知道我可不是伸手党。
我这里主要想表达的是“PyTorch请自便”,因为我不用。
我懂,我已经做出pytorch版本了。感觉苏老师的keras代码该写的,目前正在数据集上测试看效果。
请问您能发我一份pytorch版本的代码吗,chenzipeng_1105@163.com
您好,我也在尝试使用pytorch版本的使用,但是老是出错,能分享一下您的代码吗,感谢!liutingco@163.com
April 6th, 2021
你邮箱是多少?给你发一个邮件好不?
我自己也改了一下pytorch版本的,但是程序运行到一半总是有问题,能否看一下你的代码学习一下呢,邮箱1091867691@qq.com
April 20th, 2021
好像没发起,苏神您好,想问一下post-hoc和loss都属于logit adjustment方法,那用同一个模型在两个不同的数据集上分别使用这两个方法可以吗,在做本科毕设,所以不是很明白合不合理
没理解“用同一个模型在两个不同的数据集上分别使用这两个方法”是什么操作。
就是我的任务需要用一个模型在两个数据集上都要有提升,但是在一个数据集上我的模型用post-hoc形式调参有提升,而用loss形式提升不大;而另一个用loss形式有很大的提升,post-hoc调参提升不大。因为想着都是属于logit adjustment的方案,所以就想问问一个模型分别用了这两种方式还算是同一个模型吗?
那随你怎么用呀,效果好不就行了,为啥要做这个概念上的纠结?
April 22nd, 2021
噢噢好的 主要是担心不算是同一个模型就不符合规范 感谢苏神回复
September 15th, 2021
[...]PriorLoss: 通过互信息思想来缓解类别不平衡问题[...]
September 15th, 2021
[...]PriorLoss: 通过互信息思想来缓解类别不平衡问题[...]
January 21st, 2022
如果将最后一层fc的bias初始化为τlogp(y),模型不是会偏向频率高的类别吗?是否应该初始化为τlog(ε/p(y))(ε为常数)使模型偏向频率低的类别。
有这个问题的同学都是没有认真读文章的。
将最后一层的bias初始化为\tau\log p(y)也好,或者直接在logits上加上\tau\log p(y)也好,都是“训练”方案;“预测”阶段都是要把这个\tau\log p(y)去掉的。认真理解(9)式~
苏老师您好,为什么我的bert模型训练时y_y_pred加上τlogp(y),预测时去掉τlogp(y),效果会很差
我不清楚。我觉得你需要先认真理解(9)式,确定你追求的确实是去掉\tau\log p(y)的效果。
January 24th, 2022
如果τlogp(y)加上bias上,那么预测时也要去掉bias吗?
是的(如果你追求的是每一类的准确率都尽可能高)