1. 程式人生 > >[ch03-02] 交叉熵損失函式

[ch03-02] 交叉熵損失函式

系列部落格,原文在筆者所維護的github上:https://aka.ms/beginnerAI,
點選star加星不要吝嗇,星越多筆者越努力。

3.2 交叉熵損失函式

交叉熵(Cross Entropy)是Shannon資訊理論中一個重要概念,主要用於度量兩個概率分佈間的差異性資訊。在資訊理論中,交叉熵是表示兩個概率分佈 \(p,q\) 的差異,其中 \(p\) 表示真實分佈,\(q\) 表示非真實分佈,那麼\(H(p,q)\)就稱為交叉熵:

\[H(p,q)=\sum_i p_i \cdot \ln {1 \over q_i} = - \sum_i p_i \ln q_i \tag{1}\]

交叉熵可在神經網路中作為損失函式,\(p\) 表示真實標記的分佈,\(q\) 則為訓練後的模型的預測標記分佈,交叉熵損失函式可以衡量 \(p\) 與 \(q\) 的相似性。

交叉熵函式常用於邏輯迴歸(logistic regression),也就是分類(classification)。

3.2.1 交叉熵的由來

資訊量

資訊理論中,資訊量的表示方式:

\[I(x_j) = -\ln (p(x_j)) \tag{2}\]

\(x_j\):表示一個事件

\(p(x_j)\):表示\(x_j\)發生的概率

\(I(x_j)\):資訊量,\(x_j\)越不可能發生時,它一旦發生後的資訊量就越大

假設對於學習神經網路原理課程,我們有三種可能的情況發生,如表3-2所示。

表3-2 三種事件的概論和資訊量

事件編號 事件 概率 \(p\) 資訊量 \(I\)
\(x_1\) 優秀 \(p=0.7\) \(I=-\ln(0.7)=0.36\)
\(x_2\) 及格 \(p=0.2\) \(I=-\ln(0.2)=1.61\)
\(x_3\) 不及格 \(p=0.1\) \(I=-\ln(0.1)=2.30\)

WoW,某某同學不及格!好大的資訊量!相比較來說,“優秀”事件的資訊量反而小了很多。

\[H(p) = - \sum_j^n p(x_j) \ln (p(x_j)) \tag{3}\]

則上面的問題的熵是:

\[ \begin{aligned} H(p)&=-[p(x_1) \ln p(x_1) + p(x_2) \ln p(x_2) + p(x_3) \ln p(x_3)] \\ &=0.7 \times 0.36 + 0.2 \times 1.61 + 0.1 \times 2.30 \\ &=0.804 \end{aligned} \]

相對熵(KL散度)

相對熵又稱KL散度,如果我們對於同一個隨機變數 \(x\) 有兩個單獨的概率分佈 \(P(x)\) 和 \(Q(x)\),我們可以使用 KL 散度(Kullback-Leibler (KL) divergence)來衡量這兩個分佈的差異,這個相當於資訊理論範疇的均方差。

KL散度的計算公式:

\[D_{KL}(p||q)=\sum_{j=1}^n p(x_j) \ln{p(x_j) \over q(x_j)} \tag{4}\]

\(n\) 為事件的所有可能性。\(D\) 的值越小,表示 \(q\) 分佈和 \(p\) 分佈越接近。

交叉熵

把上述公式變形:

\[ \begin{aligned} D_{KL}(p||q)&=\sum_{j=1}^n p(x_j) \ln{p(x_j)} - \sum_{j=1}^n p(x_j) \ln q(x_j) \\ &=- H(p(x)) + H(p,q) \end{aligned} \tag{5} \]

等式的前一部分恰巧就是p的熵,等式的後一部分,就是交叉熵:

\[H(p,q) =- \sum_{j=1}^n p(x_j) \ln q(x_j) \tag{6}\]

在機器學習中,我們需要評估label和predicts之間的差距,使用KL散度剛剛好,即\(D_{KL}(y||a)\),由於KL散度中的前一部分\(H(y)\)不變,故在優化過程中,只需要關注交叉熵就可以了。所以一般在機器學習中直接用交叉熵做損失函式來評估模型。

\[loss =- \sum_{j=1}^n y_j \ln a_j \tag{7}\]

其中,\(n\) 並不是樣本個數,而是分類個數。所以,對於批量樣本的交叉熵計算公式是:

\[J =- \sum_{i=1}^m \sum_{j=1}^n y_{ij} \ln a_{ij} \tag{8}\]

\(m\) 是樣本數,\(n\) 是分類數。

有一類特殊問題,就是事件只有兩種情況發生的可能,比如“學會了”和“沒學會”,稱為\(0/1\)分佈或二分類。對於這類問題,由於\(n=2\),所以交叉熵可以簡化為:

\[loss =-[y \ln a + (1-y) \ln (1-a)] \tag{9}\]

二分類對於批量樣本的交叉熵計算公式是:

\[J= - \sum_{i=1}^m [y_i \ln a_i + (1-y_i) \ln (1-a_i)] \tag{10}\]

3.2.2 二分類問題交叉熵

把公式10分解開兩種情況,當\(y=1\)時,即標籤值是1,是個正例,加號後面的項為0:

\[loss = -\ln(a) \tag{11}\]

橫座標是預測輸出,縱座標是損失函式值。y=1意味著當前樣本標籤值是1,當預測輸出越接近1時,損失函式值越小,訓練結果越準確。當預測輸出越接近0時,損失函式值越大,訓練結果越糟糕。

當y=0時,即標籤值是0,是個反例,加號前面的項為0:

\[loss = -\ln (1-a) \tag{12}\]

此時,損失函式值如圖3-10。

圖3-10 二分類交叉熵損失函式圖

假設學會了課程的標籤值為1,沒有學會的標籤值為0。我們想建立一個預測器,對於一個特定的學員,根據出勤率、課堂表現、作業情況、學習能力等等來預測其學會課程的概率。

對於學員甲,預測其學會的概率為0.6,而實際上該學員通過了考試,真實值為1。所以,學員甲的交叉熵損失函式值是:

\[ loss_1 = -(1 \times \ln 0.6 + (1-1) \times \ln (1-0.6)) = 0.51 \]

對於學員乙,預測其學會的概率為0.7,而實際上該學員也通過了考試。所以,學員乙的交叉熵損失函式值是:

\[ loss_2 = -(1 \times \ln 0.7 + (1-1) \times \ln (1-0.7)) = 0.36 \]

由於0.7比0.6更接近1,是相對準確的值,所以 \(loss2\) 要比 \(loss1\) 小,反向傳播的力度也會小。

3.2.3 多分類問題交叉熵

當標籤值不是非0即1的情況時,就是多分類了。假設期末考試有三種情況:

  1. 優秀,標籤值OneHot編碼為\([1,0,0]\)
  2. 及格,標籤值OneHot編碼為\([0,1,0]\)
  3. 不及格,標籤值OneHot編碼為\([0,0,1]\)

假設我們預測學員丙的成績為優秀、及格、不及格的概率為:\([0.2,0.5,0.3]\),而真實情況是該學員不及格,則得到的交叉熵是:

\[ loss_1 = -(0 \times \ln 0.2 + 0 \times \ln 0.5 + 1 \times \ln 0.3) = 1.2 \]

假設我們預測學員丁的成績為優秀、及格、不及格的概率為:\([0.2,0.2,0.6]\),而真實情況是該學員不及格,則得到的交叉熵是:

\[ loss_2 = -(0 \times \ln 0.2 + 0 \times \ln 0.2 + 1 \times \ln 0.6) = 0.51 \]

可以看到,0.51比1.2的損失值小很多,這說明預測值越接近真實標籤值(0.6 vs 0.3),交叉熵損失函式值越小,反向傳播的力度越小。

3.2.4 為什麼不能使用均方差做為分類問題的損失函式?

  1. 迴歸問題通常用均方差損失函式,可以保證損失函式是個凸函式,即可以得到最優解。而分類問題如果用均方差的話,損失函式的表現不是凸函式,就很難得到最優解。而交叉熵函式可以保證區間內單調。

  2. 分類問題的最後一層網路,需要分類函式,Sigmoid或者Softmax,如果再接均方差函式的話,其求導結果複雜,運算量比較大。用交叉熵函式的話,可以得到比較簡單的計算結果,一個簡單的減法就可以得到反向誤差。