1. 程式人生 > >交叉熵在loss函式中使用的理解

交叉熵在loss函式中使用的理解

交叉熵(cross entropy)是深度學習中常用的一個概念,一般用來求目標與預測值之間的差距。以前做一些分類問題的時候,沒有過多的注意,直接呼叫現成的庫,用起來也比較方便。最近開始研究起對抗生成網路(GANs),用到了交叉熵,發現自己對交叉熵的理解有些模糊,不夠深入。遂花了幾天的時間從頭梳理了一下相關知識點,才算透徹的理解了,特地記錄下來,以便日後查閱。 資訊理論 交叉熵是資訊理論中的一個概念,要想了解交叉熵的本質,需要先從最基本的概念講起。 1 資訊量 首先是資訊量。假設我們聽到了兩件事,分別如下: 事件A:巴西隊進入了2018世界盃決賽圈。 事件B:中國隊進入了2018世界盃決賽圈。 僅憑直覺來說,顯而易見事件B的資訊量比事件A的資訊量要大。究其原因,是因為事件A發生的概率很大,事件B發生的概率很小。所以當越不可能的事件發生了,我們獲取到的資訊量就越大。越可能發生的事件發生了,我們獲取到的資訊量就越小。那麼資訊量應該和事件發生的概率有關。 假設是一個離散型隨機變數,其取值集合為,概率分佈函式,則定義事件的資訊量為:   由於是概率所以的取值範圍是,繪製為圖形如下:  

 

    可見該函式符合我們對資訊量的直覺 2 熵 考慮另一個問題,對於某個事件,有種可能性,每一種可能性都有一個概率 這樣就可以計算出某一種可能性的資訊量。舉一個例子,假設你拿出了你的電腦,按下開關,會有三種可能性,下表列出了每一種可能的概率及其對應的資訊量
序號 事件 概率p 資訊量I
A 電腦正常開機 0.7 -log(p(A))=0.36
B 電腦無法開機 0.2 -log(p(B))=1.61
C 電腦爆炸了 0.1 -log(p(C))=2.30
注:文中的對數均為自然對數 我們現在有了資訊量的定義,而熵用來表示所有資訊量的期望,即: 其中n代表所有的n種可能性,所以上面的問題結果就是 然而有一類比較特殊的問題,比如投擲硬幣只有兩種可能,字朝上或花朝上。買彩票只有兩種可能,中獎或不中獎。我們稱之為0-1分佈問題(二項分佈的特例),對於這類問題,熵的計算方法可以簡化為如下算式: 3 相對熵(KL散度) 相對熵又稱KL散度,如果我們對於同一個隨機變數 x 有兩個單獨的概率分佈 P(x) 和 Q(x),我們可以使用 KL 散度(Kullback-Leibler (KL) divergence)來衡量這兩個分佈的差異 維基百科對相對熵的定義 In the context of machine learning, DKL(P‖Q) is often called the information gain achieved if P is used instead of Q. 即如果用P來描述目標問題,而不是用Q來描述目標問題,得到的資訊增量。 在機器學習中,P往往用來表示樣本的真實分佈,比如[1,0,0]表示當前樣本屬於第一類。Q用來表示模型所預測的分佈,比如[0.7,0.2,0.1] 直觀的理解就是如果用P來描述樣本,那麼就非常完美。而用Q來描述樣本,雖然可以大致描述,但是不是那麼的完美,資訊量不足,需要額外的一些“資訊增量”才能達到和P一樣完美的描述。如果我們的Q通過反覆訓練,也能完美的描述樣本,那麼就不再需要額外的“資訊增量”,Q等價於P。 KL散度的計算公式:   n為事件的所有可能性。 的值越小,表示q分佈和p分佈越接近 4 交叉熵 對式3.1變形可以得到: 等式的前一部分恰巧就是p的熵,等式的後一部分,就是交叉熵: 在機器學習中,我們需要評估label和predicts之間的差距,使用KL散度剛剛好,即,由於KL散度中的前一部分不變,故在優化過程中,只需要關注交叉熵就可以了。所以一般在機器學習中直接用用交叉熵做loss,評估模型。 機器學習中交叉熵的應用 1 為什麼要用交叉熵做loss函式? 線上性迴歸問題中,常常使用MSE(Mean Squared Error)作為loss函式,比如: 這裡的m表示m個樣本的,loss為m個樣本的loss均值。 MSE線上性迴歸問題中比較好用,那麼在邏輯分類問題中還是如此麼? 2 交叉熵在單分類問題中的使用 這裡的單類別是指,每一張影象樣本只能有一個類別,比如只能是狗或只能是貓。 交叉熵在單分類問題上基本是標配的方法 上式為一張樣本的loss計算方法。式2.1中n代表著n種類別。 舉例說明,比如有如下樣本  

 

  對應的標籤和預測值
* 青蛙 老鼠
Label 0 1 0
Pred 0.3 0.6 0.1
那麼 對應一個batch的loss就是 m為當前batch的樣本數 3 交叉熵在多分類問題中的使用 這裡的多類別是指,每一張影象樣本可以有多個類別,比如同時包含一隻貓和一隻狗 和單分類問題的標籤不同,多分類的標籤是n-hot。 比如下面這張樣本圖,即有青蛙,又有老鼠,所以是一個多分類問題  

 

對應的標籤和預測值
* 青蛙 老鼠
Label 0 1 1
Pred 0.1 0.7 0.8
值得注意的是,這裡的Pred不再是通過softmax計算的了,這裡採用的是sigmoid。將每一個節點的輸出歸一化到[0,1]之間。所有Pred值的和也不再為1。換句話說,就是每一個Label都是獨立分佈的,相互之間沒有影響。所以交叉熵在這裡是單獨對每一個節點進行計算,每一個節點只有兩種可能值,所以是一個二項分佈。前面說過對於二項分佈這種特殊的分佈,熵的計算可以進行簡化。 同樣的,交叉熵的計算也可以簡化,即 注意,上式只是針對一個節點的計算公式。這一點一定要和單分類loss區分開來。 例子中可以計算為: 單張樣本的loss即為 每一個batch的loss就是: 式中m為當前batch中的樣本量,n為類別數。 總結 路漫漫,要學的東西還有很多啊。 參考: https://www.zhihu.com/question/65288314/answer/244557337 https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence https://jamesmccaffrey.wordpress.com/2013/11/05/why-you-should-use-cross-entropy-error-instead-of-classification-error-or-mean-squared-error-for-neural-network-classifier-training/