1. 程式人生 > >【機器學習】交叉熵函式的使用及推導

【機器學習】交叉熵函式的使用及推導

前言

說明:本文只討論Logistic迴歸的交叉熵,對Softmax迴歸的交叉熵類似。 minist手寫數字識別就是用交叉熵作為代價函式。

 

1.從方差代價函式說起

 

代價函式經常用方差代價函式(即採用均方誤差MSE),比如對於一個神經元(單輸入單輸出,sigmoid函式),定義其代價函式為:

其中y是我們期望的輸出,a為神經元的實際輸出【 a=σ(z), where z=wx+b 】。

在訓練神經網路過程中,我們通過梯度下降演算法來更新w和b,因此需要計算代價函式對w和b的導數:

然後更新w、b:

w <—— w - η* ∂C/∂w = w - η * a *σ′(z)

b <—— b - η* ∂C/∂b = b - η * a * σ′(z)

因為sigmoid函式的性質,導致σ′(z)在z取大部分值時會很小(如下圖標出來的兩端,幾近於平坦),這樣會使得w和b更新非常慢(因為η * a * σ′(z)這一項接近於0)。

另外對於邏輯迴歸問題如果使用均方誤差作為代價函式,則會最後優化函式不是凸函式,而是一個震盪函式,會存在很多區域性最小值,很難得到全域性最優。

 

 

交叉熵損失函式

首先,我們二話不說,先放出交叉熵的公式: 

J(θ)=−1m∑i=1my(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i))),J(θ)=−1m∑i=1my(i)log⁡(hθ(x(i)))+(1−y(i))log⁡(1−hθ(x(i))),


以及J(θ)對J(θ)對引數θθ的偏導數(用於諸如梯度下降法等優化演算法的引數更新),如下: 

∂∂θjJ(θ)=1m∑i=1m(hθ(x(i))−y(i))x(i)j∂∂θjJ(θ)=1m∑i=1m(hθ(x(i))−y(i))xj(i)


但是在大多論文或數教程中,也就是直接給出了上面兩個公式,而未給出推導過程,而且這一過程並不是一兩步就可以得到的,這就給初學者造成了一定的困惑,所以我特意在此詳細介紹了它的推導過程,跟大家分享。

 

我們一共有m組已知樣本,(x(i),y(i))(x(i),y(i))表示第 ii 組資料及其對應的類別標記。其中x(i)=(1,x(i)1,x(i)2,...,x(i)p)Tx(i)=(1,x1(i),x2(i),...,xp(i))T為p+1維向量(考慮偏置項),y(i)y(i)則為表示類別的一個數:

  • logistic迴歸(是非問題)中,y(i)y(i)取0或者1;
  • softmax迴歸(多分類問題)中,y(i)y(i)取1,2…k中的一個表示類別標號的一個數(假設共有k類)。

這裡,只討論logistic迴歸,輸入樣本資料x(i)=(1,x(i)1,x(i)2,...,x(i)p)Tx(i)=(1,x1(i),x2(i),...,xp(i))T,模型的引數為θ=(θ0,θ1,θ2,...,θp)Tθ=(θ0,θ1,θ2,...,θp)T,因此有 

θTx(i):=θ0+θ1x(i)1+⋯+θpx(i)p.θTx(i):=θ0+θ1x1(i)+⋯+θpxp(i).


假設函式(hypothesis function)定義為: 

hθ(x(i))=11+e−θTx(i)hθ(x(i))=11+e−θTx(i)


因為Logistic迴歸問題就是0/1的二分類問題,可以有 

P(y^(i)=1|x(i);θ)=hθ(x(i))P(y^(i)=1|x(i);θ)=hθ(x(i))

P(y^(i)=0|x(i);θ)=1−hθ(x(i))P(y^(i)=0|x(i);θ)=1−hθ(x(i))


現在,我們不考慮“熵”的概念,根據下面的說明,從簡單直觀角度理解,就可以得到我們想要的損失函式:我們將概率取對數,其單調性不變,有

logP(y^(i)=1|x(i);θ)=loghθ(x(i))=log11+e−θTx(i),log⁡P(y^(i)=1|x(i);θ)=log⁡hθ(x(i))=log⁡11+e−θTx(i),

logP(y^(i)=0|x(i);θ)=log(1−hθ(x(i)))=loge−θTx(i)1+e−θTx(i).log⁡P(y^(i)=0|x(i);θ)=log⁡(1−hθ(x(i)))=log⁡e−θTx(i)1+e−θTx(i).


那麼對於第ii組樣本,假設函式表徵正確的組合對數概率為: 

I{y(i)=1}logP(y^(i)=1|x(i);θ)+I{y(i)=0}logP(y^(i)=0|x(i);θ)=y(i)logP(y^(i)=1|x(i);θ)+(1−y(i))logP(y^(i)=0|x(i);θ)=y(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i)))I{y(i)=1}log⁡P(y^(i)=1|x(i);θ)+I{y(i)=0}log⁡P(y^(i)=0|x(i);θ)=y(i)log⁡P(y^(i)=1|x(i);θ)+(1−y(i))log⁡P(y^(i)=0|x(i);θ)=y(i)log⁡(hθ(x(i)))+(1−y(i))log⁡(1−hθ(x(i)))


其中,I{y(i)=1}I{y(i)=1}I{y(i)=0}I{y(i)=0}為示性函式(indicative function),簡單理解為{ }內條件成立時,取1,否則取0,這裡不贅言。 
那麼對於一共mm組樣本,我們就可以得到模型對於整體訓練樣本的表現能力: 

∑i=1my(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i)))∑i=1my(i)log⁡(hθ(x(i)))+(1−y(i))log⁡(1−hθ(x(i)))


由以上表徵正確的概率含義可知,我們希望其值越大,模型對資料的表達能力越好。而我們在引數更新或衡量模型優劣時是需要一個能充分反映模型表現誤差的損失函式(Loss function)或者代價函式(Cost function)的,而且我們希望損失函式越小越好。由這兩個矛盾,那麼我們不妨領代價函式為上述組合對數概率的相反數: 

J(θ)=−1m∑i=1my(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i)))J(θ)=−1m∑i=1my(i)log⁡(hθ(x(i)))+(1−y(i))log⁡(1−hθ(x(i)))


上式即為大名鼎鼎的交叉熵損失函式。(說明:如果熟悉“資訊熵“的概念E[−logpi]=−∑mi=1pilogpiE[−log⁡pi]=−∑i=1mpilog⁡pi,那麼可以有助理解叉熵損失函式)

 

交叉熵損失函式的求導

這步需要用到一些簡單的對數運算公式,這裡先以編號形式給出,下面推導過程中使用特意說明時都會在該步驟下腳標標出相應的公式編號,以保證推導的連貫性。 
①  logab=loga−logb  log⁡ab=log⁡a−log⁡b 
②  loga+logb=log(ab)  log⁡a+log⁡b=log⁡(ab) 
③  a=logea  a=log⁡ea 
另外,值得一提的是在這裡涉及的求導均為矩陣、向量的導數(矩陣微商),這裡有一篇教程總結得精簡又全面,非常棒,推薦給需要的同學。 
下面開始推導: 
交叉熵損失函式為: 

J(θ)=−1m∑i=1my(i)log(hθ(x(i)))+(1−y(i))log(1−hθ(x(i)))J(θ)=−1m∑i=1my(i)log⁡(hθ(x(i)))+(1−y(i))log⁡(1−hθ(x(i)))


其中, 

loghθ(x(i))=log11+e−θTx(i)=−log(1+e−θTx(i)) ,log(1−hθ(x(i)))=log(1−11+e−θTx(i))=log(e−θTx(i)1+e−θTx(i))=log(e−θTx(i))−log(1+e−θTx(i))=−θTx(i)−log(1+e−θTx(i))①③ .log⁡hθ(x(i))=log⁡11+e−θTx(i)=−log⁡(1+e−θTx(i)) ,log⁡(1−hθ(x(i)))=log⁡(1−11+e−θTx(i))=log⁡(e−θTx(i)1+e−θTx(i))=log⁡(e−θTx(i))−log⁡(1+e−θTx(i))=−θTx(i)−log⁡(1+e−θTx(i))①③ .


由此,得到 

J(θ)=−1m∑i=1m[−y(i)(log(1+e−θTx(i)))+(1−y(i))(−θTx(i)−log(1+e−θTx(i)))]=−1m∑i=1m[y(i)θTx(i)−θTx(i)−log(1+e−θTx(i))]=−1m∑i=1m[y(i)θTx(i)−logeθTx(i)−log(1+e−θTx(i))]③=−1m∑i=1m[y(i)θTx(i)−(logeθTx(i)+log(1+e−θTx(i)))]②=−1m∑i=1m[y(i)θTx(i)−log(1+eθTx(i))]J(θ)=−1m∑i=1m[−y(i)(log⁡(1+e−θTx(i)))+(1−y(i))(−θTx(i)−log⁡(1+e−θTx(i)))]=−1m∑i=1m[y(i)θTx(i)−θTx(i)−log⁡(1+e−θTx(i))]=−1m∑i=1m[y(i)θTx(i)−log⁡eθTx(i)−log⁡(1+e−θTx(i))]③=−1m∑i=1m[y(i)θTx(i)−(log⁡eθTx(i)+log⁡(1+e−θTx(i)))]②=−1m∑i=1m[y(i)θTx(i)−log⁡(1+eθTx(i))]


這次再計算J(θ)J(θ)對第jj個引數分量θjθj求偏導: 

∂∂θjJ(θ)=∂∂θj(1m∑i=1m[log(1+eθTx(i))−y(i)θTx(i)])=1m∑i=1m[∂∂θjlog(1+eθTx(i))−∂∂θj(y(i)θTx(i))]=1m∑i=1m⎛⎝x(i)jeθTx(i)1+eθTx(i)−y(i)x(i)j⎞⎠=1m∑i=1m(hθ(x(i))−y(i))x(i)j∂∂θjJ(θ)=∂∂θj(1m∑i=1m[log⁡(1+eθTx(i))−y(i)θTx(i)])=1m∑i=1m[∂∂θjlog⁡(1+eθTx(i))−∂∂θj(y(i)θTx(i))]=1m∑i=1m(xj(i)eθTx(i)1+eθTx(i)−y(i)xj(i))=1m∑i=1m(hθ(x(i))−y(i))xj(i)


這就是交叉熵對引數的導數: 

∂∂θjJ(θ)=1m∑i=1m(hθ(x(i))−y(i))x(i)j

from:https://blog.csdn.net/jasonzzj/article/details/52017438

       :https://blog.csdn.net/u012162613/article/details/44239919