1. 程式人生 > >一文搞懂交叉熵在機器學習中的使用,透徹理解交叉熵背後的直覺

一文搞懂交叉熵在機器學習中的使用,透徹理解交叉熵背後的直覺

關於交叉熵在loss函式中使用的理解

交叉熵(cross entropy)是深度學習中常用的一個概念,一般用來求目標與預測值之間的差距。以前做一些分類問題的時候,沒有過多的注意,直接呼叫現成的庫,用起來也比較方便。最近開始研究起對抗生成網路(GANs),用到了交叉熵,發現自己對交叉熵的理解有些模糊,不夠深入。遂花了幾天的時間從頭梳理了一下相關知識點,才算透徹的理解了,特地記錄下來,以便日後查閱。

資訊理論

交叉熵是資訊理論中的一個概念,要想了解交叉熵的本質,需要先從最基本的概念講起。

1 資訊量

首先是資訊量。假設我們聽到了兩件事,分別如下: 事件A:巴西隊進入了2018世界盃決賽圈。 事件B:中國隊進入了2018世界盃決賽圈。 僅憑直覺來說,顯而易見事件B的資訊量比事件A的資訊量要大。究其原因,是因為事件A發生的概率很大,事件B發生的概率很小。所以當越不可能的事件發生了,我們獲取到的資訊量就越大。越可能發生的事件發生了,我們獲取到的資訊量就越小。那麼資訊量應該和事件發生的概率有關。

假設XX的資訊量為:

I(x0)=log(p(x0))I(x0)=−log(p(x0)),繪製為圖形如下:
這裡寫圖片描述 可見該函式符合我們對資訊量的直覺

2 熵

考慮另一個問題,對於某個事件,有nn 這樣就可以計算出某一種可能性的資訊量。舉一個例子,假設你拿出了你的電腦,按下開關,會有三種可能性,下表列出了每一種可能的概率及其對應的資訊量

序號 事件 概率p 資訊量I
A 電腦正常開機 0.7 -log(p(A))=0.36
B 電腦無法開機 0.2 -log(p(B))=1.61
C 電腦爆炸了 0.1 -log(p(C))=2.30

注:文中的對數均為自然對數

我們現在有了資訊量的定義,而熵用來表示所有資訊量的期望,即:

H(X)=i=1np(xi)log(p(xi))H(X)=−∑i=1np(xi)log(p(xi))

其中n代表所有的n種可能性,所以上面的問題結果就是

H(X)===[p(A)log(p(A))+p(B)log(p(B))+p(C))log(p(C))]0.7×0.36+0.2×1.61+0.1×2.300.804H(X)=−[p(A)log(p(A))+p(B)log(p(B))+p(C))log(p(C))]=0.7×0.36+0.2×1.61+0.1×2.30=0.804

然而有一類比較特殊的問題,比如投擲硬幣只有兩種可能,字朝上或花朝上。買彩票只有兩種可能,中獎或不中獎。我們稱之為0-1分佈問題(二項分佈的特例),對於這類問題,熵的計算方法可以簡化為如下算式:

H(X)==i=1np(xi)log(p(xi))p(x)log(p(x))(1p(x))log(1p(x))H(X)=−∑i=1np(xi)log(p(xi))=−p(x)log(p(x))−(1−p(x))log(1−p(x))

3 相對熵(KL散度)

相對熵又稱KL散度,如果我們對於同一個隨機變數 x 有兩個單獨的概率分佈 P(x) 和 Q(x),我們可以使用 KL 散度(Kullback-Leibler (KL) divergence)來衡量這兩個分佈的差異

維基百科對相對熵的定義

In the context of machine learning, DKL(P‖Q) is often called the information gain achieved if P is used instead of Q.

即如果用P來描述目標問題,而不是用Q來描述目標問題,得到的資訊增量。

在機器學習中,P往往用來表示樣本的真實分佈,比如[1,0,0]表示當前樣本屬於第一類。Q用來表示模型所預測的分佈,比如[0.7,0.2,0.1] 直觀的理解就是如果用P來描述樣本,那麼就非常完美。而用Q來描述樣本,雖然可以大致描述,但是不是那麼的完美,資訊量不足,需要額外的一些“資訊增量”才能達到和P一樣完美的描述。如果我們的Q通過反覆訓練,也能完美的描述樣本,那麼就不再需要額外的“資訊增量”,Q等價於P。

KL散度的計算公式:

DKL(p||q)=i=1np(xi)log(p(xi)q(xi))(3.1)(3.1)DKL(p||q)=∑i=1np(xi)log(p(xi)q(xi))的值越小,表示q分佈和p分佈越接近

4 交叉熵

對式3.1變形可以得到:

DKL(p||q)==i=1np(xi)log(p(xi))i=1np(xi)log(q(xi))H(p(x))+[i=1np(xi)log(q(xi))]DKL(p||q)=∑i=1np(xi)log(p(xi))−∑i=1np(xi)log(q(xi))=−H(p(x))+[−∑i=1np(xi)log(q(xi))]

等式的前一部分恰巧就是p的熵,等式的後一部分,就是交叉熵:

H(p,q)=i=1np(xi)log(q(xi))H(p,q)=−∑i=1np(xi)log(q(xi))

在機器學習中,我們需要評估label和predicts之間的差距,使用KL散度剛剛好,即DKL(y||y^)DKL(y||y^)不變,故在優化過程中,只需要關注交叉熵就可以了。所以一般在機器學習中直接用用交叉熵做loss,評估模型。

機器學習中交叉熵的應用

1 為什麼要用交叉熵做loss函式?

線上性迴歸問題中,常常使用MSE(Mean Squared Error)作為loss函式,比如:

loss=12mi=1m(yiyi^)2loss=12m∑i=1m(yi−yi^)2

這裡的m表示m個樣本的,loss為m個樣本的loss均值。 MSE線上性迴歸問題中比較好用,那麼在邏輯分類問題中還是如此麼?

2 交叉熵在單分類問題中的使用

這裡的單類別是指,每一張影象樣本只能有一個類別,比如只能是狗或只能是貓。 交叉熵在單分類問題上基本是標配的方法

loss=i=1nyilog(yi^)(2.1)(2.1)loss=−∑i=1nyilog(yi^)

上式為一張樣本的loss計算方法。式2.1中n代表著n種類別。 舉例說明,比如有如下樣本

這裡寫圖片描述

對應的標籤和預測值

* 青蛙 老鼠
Label 0 1 0
Pred 0.3 0.6 0.1

那麼

loss==(0×log(0.3)+1×log(0.6)+0×log(0.1)log(0.6)loss=−(0×log(0.3)+1×log(0.6)+0×log(0.1)=−log(0.6)

對應一個batch的loss就是

loss=1mj=1mi=1nyjilog(yji^)loss=−1m∑j=1m∑i=1nyjilog(yji^)

m為當前batch的樣本數

3 交叉熵在多分類問題中的使用

這裡的多類別是指,每一張影象樣本可以有多個類別,比如同時包含一隻貓和一隻狗 和單分類問題的標籤不同,多分類的標籤是n-hot。 比如下面這張樣本圖,即有青蛙,又有老鼠,所以是一個多分類問題

這裡寫圖片描述

對應的標籤和預測值

* 青蛙 老鼠
Label 0 1 1
Pred 0.1 0.7 0.8

值得注意的是,這裡的Pred不再是通過softmax計算的了,這裡採用的是sigmoid。將每一個節點的輸出歸一化到[0,1]之間。所有Pred值的和也不再為1。換句話說,就是每一個Label都是獨立分佈的,相互之間沒有影響。所以交叉熵在這裡是單獨對每一個節點進行計算,每一個節點只有兩種可能值,所以是一個二項分佈。前面說過對於二項分佈這種特殊的分佈,熵的計算可以進行簡化。

同樣的,交叉熵的計算也可以簡化,即

loss=ylog(y^)(1y)log(1y^)loss=−ylog(y^)−(1−y)log(1−y^)

注意,上式只是針對一個節點的計算公式。這一點一定要和單分類loss區分開來。 例子中可以計算為:

losslossloss===0×log(0.1)(10)log(10.1)=log(0.9)1×log(0.7)(11)log(10.7)=log(0.7)1×log(0.8)(11)log(10.8)=log(0.8)loss貓=−0×log(0.1)−(1−0)log(1−0.1)=−log(0.9)loss蛙=−1×log(0.7)−(1−1)log(1−0.7)=−log(0.7)loss鼠=−1×log(0.8)−(1−1)log(1−0.8)=−log(0.8)

單張樣本的loss即為loss=loss+loss+lossloss=loss貓+loss蛙+loss鼠

式中m為當前batch中的樣本量,n為類別數。

總結

路漫漫,要學的東西還有很多啊。

參考: