不成功的尝试:将多标签交叉熵推广到“n个m分类”上去
By 苏剑林 | 2022-07-15 | 27554位读者 |可能有读者留意到,这次更新相对来说隔得比较久了。事实上,在上周末时就开始准备这篇文章了,然而笔者低估了这个问题的难度,几乎推导了整整一周,仍然还没得到一个完善的结果出来。目前发出来的,仍然只是一个失败的结果,希望有经验的读者可以指点指点。
在文章《将“Softmax+交叉熵”推广到多标签分类问题》中,我们提出了一个多标签分类损失函数,它能自动调节正负类的不平衡问题,后来在《多标签“Softmax+交叉熵”的软标签版本》中我们还进一步得到了它的“软标签”版本。本质上来说,多标签分类就是“n个2分类”问题,那么相应的,“n个m分类”的损失函数又该是怎样的呢?
这就是本文所要探讨的问题。
类比尝试 #
在软标签推广的文章《多标签“Softmax+交叉熵”的软标签版本》中,我们是通过直接将“n个2分类”的sigmoid交叉熵损失,在log内做一阶截断来得到最终结果的。同样的过程确实也可以推广到“n个m分类”的softmax交叉熵损失,这是笔者的第一次尝试。
记softmax(si,j)=esi,j∑jesi,j,si,j为预测结果,而ti,j则为标签,那么
−∑i∑jti,jlogsoftmax(si,j)=∑i∑jti,jlog(1+∑k≠jesi,k−si,j)=∑jlog∏i(1+∑k≠jesi,k−si,j)ti,j=∑jlog(1+∑iti,j∑k≠jesi,k−si,j+⋯)
对i的求和默认是1∼n,对j的求和默认是1∼m。截断⋯的高阶项,得到
l=∑jlog(1+∑i,k≠jti,je−si,j+si,k)
这就是笔者开始得到的loss,它是之前的结果到“n个m分类”的自然推广。事实上,如果ti,j是硬标签,那么该loss基本上没什么问题。但笔者希望它像《多标签“Softmax+交叉熵”的软标签版本》一样,对于软标签也能得到推导出相应的解析解。为此,笔者对它进行求导:
∂l∂si,j=−ti,je−si,j∑k≠jesi,k1+∑i,k≠jti,je−si,j+si,k+∑h≠jti,he−si,hesi,j1+∑i,k≠hti,he−si,h+si,k
所谓解析解,就是通过方程∂l∂si,j=0来解出。然而笔者尝试了好几天,都求不出方程的解,估计并没有简单的显式解,因此,第一次尝试失败。
结果倒推 #
尝试了几天实在没办法后,笔者又反过来想:既然直接类比出来的结果无法求解,那么我干脆从结果倒推好了,即先把解确定,然后再反推方程应该是怎样的。于是,笔者开始了第二次尝试。
首先,观察发现原来的多标签损失,或者前面得到的损失(2),都具有如下的形式:
l=∑jlog(1+∑iti,je−f(si,j))
我们就以这个形式为出发点,求导
∂l∂si,k=∑j−ti,je−f(si,j)∂f(si,j)∂si,k1+∑iti,je−f(si,j)
我们希望ti,j=softmax(f(si,j))=ef(si,j)/Zi就是∂l∂si,k=0的解析解,其中Zi=∑jef(si,j)。那么代入得到
0=∂l∂si,k=∑j−(1/Zi)∂f(si,j)∂si,k1+∑i1/Zi=−(1/Zi)∂(∑jf(si,j))∂si,k1+∑i1/Zi
所以要让上式自然成立,我们发现只需要让∑jf(si,j)等于一个跟i,j都无关的常数。简单起见,我们让
f(si,j)=si,j−ˉsi,ˉsi=1m∑jsi,j
这样自然地有∑jf(si,j)=0,对应的优化目标就是
l=∑jlog(1+∑iti,je−si,j+ˉsi)
ˉsi不影响归一化结果,所以它的理论最优解是ti,j=softmax(si,j)。
然而,看上去很美好,然而它实际上的效果会比较糟糕,ti,j=softmax(si,j)确实是理论最优解,但实际上标签越接近硬标签,它的效果会越差。因为我们知道对于损失(8)来说,只要si,j≫ˉsi,损失就会很接近于0,而要达到si,j≫ˉsi,si,j不一定是si,1,si,2,⋯,si,m中的最大者,这就无法实现分类目标了。
思考分析 #
现在我们得到了两个结果,式(2)是原来多标签交叉熵的类比推广,它在硬标签的情况下效果还是不错的,但是由于求不出软标签情况下的解析解,因此软标签的情况无法做理论评估;式(8)是从结果理论倒推出来的,理论上它的解析解就是简单的softmax,但由于实际优化算法的限制,硬标签的表现通常很差,甚至无法保证目标logits是最大值。特别地,当m=2时,式(2)和式(8)都能退化为多标签交叉熵。
我们知道,多标签交叉熵能够自动调节正负样本不平衡的问题,同样地,虽然我们目前还没能得到一个完美的推广,但理论上推广到“n个m分类”后依然能够自动调节m个类的不平衡问题。那么平衡的机制是怎样的呢?其实不难理解,不管是类比推广的式(2),还是一般的假设式(4),对i的求和都放在了log里边,原本每个类的损失占比大体上是正比于“该类的样本数”的,改为放在了log里边求和后,每个类的损失占就大致等于“该类的样本数的对数”,从而缩小了每个类的损失差距,自动缓解了不平衡问题。
遗憾的是,本文还没有得出关于“n个m分类”的完美推广——它应该包含两个特性:1、通过log的方法自动调节类别不平衡现象;2、能够求出软标签情况下的解析解。对于硬标签来说,直接用式(2)应该是足够了;而对于软标签来说,笔者实在是没辙了,欢迎有兴趣的读者一起思考交流。
文章小结 #
本文尝试将之前的多标签交叉熵推广到“n个m分类”上去,遗憾的是,这一次的推广并不算成功,暂且将结果分享在此,希望有兴趣的读者能一起参与改进。
转载到请包括本文地址:https://kexue.fm/archives/9158
更详细的转载事宜请参考:《科学空间FAQ》
如果您还有什么疑惑或建议,欢迎在下方评论区继续讨论。
如果您觉得本文还不错,欢迎分享/打赏本文。打赏并非要从中获得收益,而是希望知道科学空间获得了多少读者的真心关注。当然,如果你无视它,也不会影响你的阅读。再次表示欢迎和感谢!
如果您需要引用本文,请参考:
苏剑林. (Jul. 15, 2022). 《不成功的尝试:将多标签交叉熵推广到“n个m分类”上去 》[Blog post]. Retrieved from https://kexue.fm/archives/9158
@online{kexuefm-9158,
title={不成功的尝试:将多标签交叉熵推广到“n个m分类”上去},
author={苏剑林},
year={2022},
month={Jul},
url={\url{https://kexue.fm/archives/9158}},
}
July 16th, 2022
假設n為1,那麼問題就退化為1個m分類,此處m為單選。
那麼loss應該會回到最原始的softmax
−logestn∑i=1esi=−st+logn∑i=1esi
然而從(4)式開始仔細看,
l=∑jlog(1+∑iti,je−f(si,j))...(4)
當n=1,i的部分就可以當作沒有,變成下式:
l=∑jlog(1+tje−f(sj))...(4.1)
假如t是硬標籤{0,1},就會變成:
l=log(1+e−f(st))...(4.2)
這形式是不是很像 [将“softmax+交叉熵”推广到多标签](https://kexue.fm/archives/7359) 中的(6)式:
−logestn∑i=1esi=−log1n∑i=1esi−st=logn∑i=1esi−st=log(1+n∑i=1,i≠tesi−st)...(6)
只是少了sum over i,這裡的i是類別數。
因此,從這裡反推回去,我覺得(4)式應該改成:
當n=1時,
l=log(1+∑j,j≠tesj−st)...(4.3)
=log(1+∑j(1−tj)esj)...(4.3.1)
其中j表示類別,st 表示目標類的得分,tj為軟標籤。
當n>1時,應該改成:
l=∑ilog(1+∑i,j(1−ti,j)esi,j)...(4.4)
然而當m=2時,就退化成n個2分類,
l=∑ilog(1+∑i((1−ti,0)esi,0+(1−ti,1)esi,1))...(4.5)
回到 [多标签“Softmax+交叉熵”的软标签版本](https://kexue.fm/archives/9064) 式(5):
log(1+∑itie−si)+log(1+∑i(1−ti)esi)...(5)
整理一下,得到:
=log((1+∑itie−si)(1+∑i(1−ti)esi))...(5.1)
=log(1+∑itie−si+∑i(1−ti)esi+∑itie−si∑i(1−ti)esi)...(5.2)
=log(1+∑i(tie−si+(1−ti)esi)+...)...(5.3)
可以發現跟(4.5)的差別只在
1. log前面沒有sum over i
2. log中多了高階項
因此,可以將(4.4)改造成:
=log(m∏j=1(1+n∑i=1(1−ti,j)esi,j))...(4.6)
=log(1+∑j∑i(1−ti,j)esi,j+...)...(4.7)
式(4.7)至少符合了上面兩種退化的情況,分別是「1個m分類,m為單選」及「n個2分類」。
如果要再擴展到「1個m分類,m為多選」,則可以改成這樣:
=log(∏j∈Ωneg(1+n∑i=1(1−ti,j)esi,j)∏j∈Ωpos(1+n∑i=1(ti,j)e−si,j))...(4.8)
因為乘法的關係,所以每一項都會比較到,就像[将“softmax+交叉熵”推广到多标签分类问题](https://kexue.fm/archives/7359)中的式(8)。再把它展開來:
=log((1+∑j∈Ωneg∑i(1−ti,j)esi,j+...)(1+∑j∈Ωpos∑i(ti,j)e−si,j+...))...(4.9)
=log(1+∑j∈Ωneg∑i(1−ti,j)esi,j+∑j∈Ωpos∑i(ti,j)e−si,j+...)...(4.10)
我認為這應該就是最一般的形式了。
雖然感覺像是湊出來的,但是至少都滿足了幾種退化的情況。
所不確定的是,是否引入高階項會讓結果更好?
July 16th, 2022
從(4.9)開始,若考慮到neg跟pos交叉項:
=log((1+∑j∈Ωneg∑i(1−ti,j)esi,j+...)(1+∑j∈Ωpos∑i(ti,j)e−si,j+...))...(4.9)
=log(1+∑j∈Ωneg∑i(1−ti,j)esi,j+∑j∈Ωpos∑i(ti,j)e−si,j+(∑j∈Ωneg∑i(1−ti,j)esi,j)(∑j∈Ωpos∑i(ti,j)e−si,j)...)...(4.11)
若令n=1,t為硬標籤{0,1},則退化成這樣:
=log(1+∑j∈Ωnegesj+∑j∈Ωpose−sj+(∑j∈Ωnegesj)(∑j∈Ωpose−sj)...)...(4.12)
而其中的
(∑j∈Ωnegesj)(∑j∈Ωpose−sj)...(4.12.1)
展開來就是:
=(esneg1+esneg2+esneg3+...)(e−spos1+e−spos2+e−spos3+...)...(4.12.2)
=esneg1−spos1+esneg1−spos2+esneg1−spos3+...+esneg2−spos1+esneg2−spos2+esneg2−spos3+...+esneg3−spos1+esneg3−spos2+esneg3−spos3+......(4.12.3)
這裡就出現了跟 [将“softmax+交叉熵”推广到多标签分类问题](https://kexue.fm/archives/7359) 式(7) 類似的正負樣本兩兩相減的形式。
而最小化這些正負樣本的交叉項,就確保了負樣本得分不會高於正樣本。
當m為單選時,表示pos只有一項,記為t,那就應該會退化成一般的softmax
=log(1+∑j∈Ωnegesj+e−st+∑j∈Ωnegesj−st...)...(4.12.4)
可以看到比[将“softmax+交叉熵”推广到多标签](https://kexue.fm/archives/7359) 中的(6)式多了一階項:
∑j∈Ωnegesj+e−st...(4.12.5)
考慮到最小化這些一階項的作用其實已經包含到交叉項裡面,
因此,式(4.11)可以再簡化,把一階項拿掉:
l=log(1+(∑j∈Ωneg∑i(1−ti,j)esi,j)(∑j∈Ωpos∑i(ti,j)e−si,j))...(4.11.1)
我覺得可以做實驗比較看看(4.11)跟(4.11.1)哪個較好?
不過我個人認為,在正類別數量K固定的情況下,可能(4.11.1)保留正負樣本的交叉項就夠了。因為我們希望的是「非目標項得分盡可能小於目標項」,只要知道相對大小即可。
但如果是正類別K數量不是固定的,需要設定一個threshold,那麼(4.11)中的一階項,就有必要了。因為一階項的作用似乎就是限制分數的絕對值,讓正樣本得分盡可能大於0,負樣本小於0。這樣就可以以0為基準,拿出分數大於0的K類。
July 16th, 2022
從 (4.12.5) 來看,最小化這個式子會讓正類的 sj>0,負類的 sj<0。
但若我們希望不要從0開始,而有一個偏移b呢?那就可以改成:
∑j∈Ωnegesj−b+e−(st−b)...(4.12.5.1)
這樣會鼓勵正類的sj>b,負類的sj<b。
那如果我們還想加入一個間隔(margin)呢?可以加上m:
∑j∈Ωnege(sj−b)+m+e−(st−b)+m...(4.12.5.2)
這樣會鼓勵正類的sj>b+m,負類的sj<b−m。而從式(4.12.3)可知,正負兩類間隔相差2m。
上面這裡的b跟m都是一個固定的數值,現在考慮把軟標籤放回來:
∑j∈Ωneg(1−tj)e(sj−b)+m+tpose−(st−b)+m...(4.12.5.3)
=∑j∈Ωnege(sj−b)+m+ln(1−tj)+e−(st−b)+m+ln(tpos)...(4.12.5.4)
任取一個負類,正負類的差距為:
((sneg−b)+m+ln(1−tneg))+(−(spos−b)+m+ln(tpos))
=sneg−spos+2m+ln(1−tneg)+ln(tpos)
=sneg−spos+2m+ln((1−tneg)tpos)
由於m是固定的,可知軟標籤t的作用為根據樣本的分錯程度動態調整margin。
當 tneg=0.9, tpos=0.1時,有margin 2m+ln(0.01);
當 tneg=tpos=0.5時,有margin 2m+ln(0.25);
當 tneg=0.1, tpos=0.9時,有margin 2m+ln(0.81);
可以看到,當樣本分錯的很離譜的時候(tneg=0.9, tpos=0.1),margin 中的ln項是負的,而且負很多,表示它希望把正負樣本的錯誤拉回來;而當樣本分類正確的信心程度越高(tneg=0.1, tpos=0.9),margin中的ln 會越負越少,直到接近0。這時候最終的margin大小就會由m來決定。模型的目標就會變成把正負類的差距拉開到至少2m。
但是為了避免 2m+ln((1−tneg)tpos)<0 造成模型一直停在分錯的狀態,需要選擇一個夠大的m,讓大部分分錯的情況下,2m+ln((1−tneg)tpos)>0 ,這樣模型才有動力往反方向修正。
當然如果太極端的錯誤,如tneg=0.999, tpos=0.001,就可能要選m>3以上,我想m=5應該就很夠用了。
所以最一般的形式應該長這樣:
l=log(∏j∈Ωneg(1+n∑i=1(1−ti,j)esi,j−b+m)∏j∈Ωpos(1+n∑i=1(ti,j)e−si,j+b+m))=log(1+∑j∈Ωneg∑i(1−ti,j)esi,j−b+m+∑j∈Ωpos∑i(ti,j)e−si,j+b+m+(∑j∈Ωneg∑i(1−ti,j)esi,j−b+m)(∑j∈Ωpos∑i(ti,j)e−si,j+b+m)...)...(4.11.2)
取b=0, m=5。
July 18th, 2022
@allenyl|comment-19491
感谢你的详细推导哈,我这里统一回复一下。
1、你的(4.3)是不等于(4.3.1)的,所以后面大体上都有问题~
2、如果只考虑“1个m分类(硬标签)“和”n个2分类(软标签)“两种场景的退化,其实本文的(2)就满足;
3、总的来说,你全程都在类比演绎,但是没考虑结果的合理性,我遇到的主要困难在于如何求出最优解时ti,j与si,j的显式关系。
September 21st, 2022
仔细看式(2)和(8)可以注意到一个inductive bias,就是不同分类任务的第j个类别分数出现在同一个log里面。那么需要想一下为什么“n个m分类”任务的m个标签之间会有对应关系,在这样的情况下所设想的loss在平衡什么关系。
你这个分析没错,所以我就是致力于得到具有像(8)一样理论性质的loss,其最优解是解耦开了n个m分类任务之间的关系的,但很遗憾没找到。