1. 程式人生 > >為什麼用交叉熵作為損失函式

為什麼用交叉熵作為損失函式

交叉熵(cross entropy)經常用來做機器學習中的損失函式。
要講交叉熵就要從最基本的資訊熵說起。

1.資訊熵

資訊熵是消除不確定性所需資訊量的度量。(多看幾遍這句話)

資訊熵就是資訊的不確定程度,資訊熵越小,資訊越確定。

=x=1n(x×x)資訊熵 = \sum_{x=1}^{n}(資訊x發生的概率 × 驗證資訊x需要的資訊量)

(因為事件都有個概率分佈,這裡我們只考慮離散分佈)

舉個列子,比如說:今年中國取消高考了,這句話我們很不確定(甚至心裡還覺得這TM是扯淡),那我們就要去查證了,這樣就需要很多資訊量(去查證);反之如果說今年正常高考,大家回想:這很正常啊,不怎麼需要查證,這樣需要的資訊量就很小。從這裡我們可以學到:根據資訊的真實分佈

,我們能夠找到一個最優策略,以最小的代價消除系統的不確定性,即最小資訊熵

簡而言之,概率越低,需要越多的資訊去驗證,所以驗證真假需要的資訊量和概率成反比。我們需要用數學表示式把它描述出來,推導:

考慮一個離散的隨機變數 xx,已知資訊的量度依賴於概率分佈 p(x)p(x),因此我們想要尋找一個函式 I(x)I(x),它是概率p(x)p(x)的單調函式,表達了資訊的內容。
怎麼尋找呢?如果我們有兩個不相關的事件 xxyy,那麼觀察兩個事件同時發生時獲得的資訊量應該等於觀察到事件各自發生時獲得的資訊之和,即:
I(x,y)=I(x)+I(y)I(x,y)=I(x)+I(y)

因為兩個事件是獨立不相關的,因此
p(x,y)=p(x)p(y)p(x,y)=p(x)p(y)

根據這兩個關係,很容易看出 I(x)I(x)一定與 P(x)P(x) 的對數有關
因為對數的運演算法則是
loga(mn)=logam+loganlog_a(mn)=log_am+log_an

因此,我們有
I(x)=log(p(x))I(x)=−log(p(x))

其中負號是用來保證資訊量是正數或者零。而 log 函式基的選擇是任意的(資訊理論中基常常選擇為2,因此資訊的單位為位元bits;而機器學習中基常常選擇為自然常數,因此單位常常被稱為奈特nats)。I

(x)I(x) 也被稱為隨機變數 x 的自資訊 (self-information),描述的是隨機變數的某個事件發生所帶來的資訊量。

以上推導引用自https://www.cnblogs.com/kyrieng/p/8694705.html

資訊熵即所有資訊量的期望:
H(X)=xp(x)log(p(x))=i=1np(xi)log(p(xi))H(X)=−∑_xp(x)log(p(x))=−∑_{i=1}^np(x_i)log(p(x_i))

注:其中n為事件的所有可能性。

2.相對熵(KL散度)

相對熵又稱KL散度,如果對於同一個隨機變數xx有兩個單獨的概率分佈p(x)p(x)q(x)q(x),可以使用相對熵來衡量這兩個分佈的差異。
DKL(pq)=i=1np(xi)log(p(xi)q(xi))D_{KL}(p||q)=\sum_{i=1}^np(x_i)log(\frac{p(x_i)}{q(x_i)})

注:DKLD_{KL}越小,表示p(x)和q(x)的分佈越近。

3.交叉熵

交叉熵公式:
H(p,q)=i=1np(xi)log(q(xi))H(p,q)=-\sum_{i=1}^np(x_i)log(q(x_i))

相對熵的推導:
DKL(pq)=i=1np(xi)log(p(xi))i=1np(xi)log(q(xi))=[i=1np(xi)log(q(xi))]H(p(x)) \begin{array}{l} \quad D_{KL}(p||q) \\\\ = \sum_{i=1}^np(x_i)log(p(x_i))-\sum_{i=1}^np(x_i)log(q(x_i)) \\\\ = [-\sum_{i=1}^np(x_i)log(q(x_i))]-H(p(x))\\ \end{array}

在機器學習中,往往用p(x)用來描述真實分佈,q(x)用來描述模型預測的分佈。

計算損失,理應使用相對熵來計算概率分佈的差異,然而由相對熵推匯出的結果看:

=相對熵=交叉熵-資訊熵

由於資訊熵描述的是消除(p,即真實分佈)的不確定性所需資訊量的度量,所以其值應該是最小的、固定的。那麼:優化減小相對熵也就是優化交叉熵,所以在機器學習中使用交叉熵就可以了。

4.為什麼使用交叉熵

在機器學習中,我們希望模型在訓練資料上學到的預測資料分佈真實資料分佈越相近越好,上面講過了,用相對熵,但是為了簡便計算使用交叉熵就可以了。

注意此處真實資料分佈指的就是訓練資料的分佈(標註)。

交叉熵損失函式:

L=[ylogy^+(1y)log(1y^)]L=-[ylog\ \hat y+(1-y)log\ (1-\hat y)]

交叉熵損失函式一般用來代替均方差損失函式與sigmoid啟用函式組合。
sigmoid啟用函式表示式:
σ(z)=11+ez\sigma(z) = \frac{1}{1+e^{-z}}

從圖中可以看出,對於sigmoid,當xx的取值越大或越小,函式曲線變得越平緩,意味著導數σ(x)σ′(x)越趨近於0。

以單個樣本的一次梯度下降為例:

z=wx+bz= wx+b

y^=a=σ(z)\hat{y}= a =\sigma(z)

L1(y,a)=12(ya)2L_1(y,a)=\frac{1}{2}(y-a)^2

L2(y,a)=(ylog(a)+(1y)log(1a))L_2(y,a)=-(ylog(a)+(1-y)log(1-a))

前兩個公式公式分別是前向傳播的線性和非線性部分,第三個公式公式是均方差損失函式,第四個公式是交叉熵損失函式。梯度下降的目的,直白地說:是減小真實值和預測值的距離,而損失函式用來度量真實值和預測值之間距離,所以梯度下降目的也就是減小損失函式的值。怎麼減小損失函式的值呢?變數只有wwbb,所以我們要做的就是不斷修改wwbb的值以使損失函式越來越小。(這裡例子只有一步,只修改一次)

wwbb的更新:=×引數=引數-學習率×損失函式對引數的偏導

w=wαL(y,a)ww = w - \alpha \frac{\partial L(y,a)}{\partial w}

b=bαL(y,a)wb = b - \alpha \frac{\partial L(y,a)}{\partial w}

其中α\alpha 表示學習率,用來控制步長,即向下走一步的長度

為什麼要這樣更新引數呢,講完下面的關鍵點我們會解釋一下。

關鍵點來了,為什麼用交叉熵而不是均方差呢?

均方差對引數的偏導:

L1(y,a)w=yσ(z)σ(z)x\frac{\partial L_1(y,a)}{\partial w}=-|y-\sigma(z)|\sigma'(z)x

L1(y,a)b=yσ(z)σ(z)\frac{\partial L_1(y,a)}{\partial b}=-|y-\sigma(z)|\sigma'(z)