1. 程式人生 > >【深度學習原理】交叉熵損失函式的實現

【深度學習原理】交叉熵損失函式的實現

交叉熵損失函式

一般我們學習交叉熵損失函式是在二元分類情況下:

L = [ y l o g y

^ + ( 1 y ) l o g
( 1 y ^ ) ] L=−[ylog ŷ +(1−y)log (1−ŷ )]

推而廣之,我們可以得到下面這個交叉熵損失函式公式:

E = k t k l o g ( y k ) E=-\sum_k{t_k}log(y_k)

從機器學習的角度看,這裡的 y k y_k 是神經網路的輸出, t k t_k 是正確解的標籤。

而分類標籤有兩種方式:

  • One-Hot編碼
  • 非One-Hot編碼

One-Hot編碼下的損失函式實現

使用One-Hot編碼時, t k t_k 中,只有正確解的標籤才為1,其他的都是0,所以在相乘時,這項就為0,但是我們知道 l o g ( 0 ) log(0) 是負無窮,顯然我們需要特別在程式碼中處理一下。

先不看負無窮的問題,在One-Hot編碼時, t k t_k 中只有為1的這項,才有輸出,也就是說,我們計算交叉熵損失函式,只用計算對應正確解的輸出的自然對數即可

程式碼如下:

def cross_entropy_error(y, t):
	delta = 1e-7
	return -np.sum(t * np.log(y + delta))

比如:

t = [0,0,1,0,0,0,0,0,0,0]
y = [0.1,0.05,0.6,0.0,0.05,0.1,0.0,0.1,0.0,0.0]
cross_entropy_error(y,t) # ==> 0.510825...
y = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
cross_entropy_error(y,t) # ==> 2.30258... 

第一個案例中,正確的標籤是2,輸出的Softmax概率中2對應的標籤的概率最大,為0.6,由此計算出來的損失函式值為0.51;第二個案例,預測的概率最大為0.1,以第一個作為預測結果,即0是預測值,得出損失函式值為2.3,可見預測錯了損失函式值偏大。

總之,用One-Hot編碼,是將
標籤值和預測值的編碼一一對應,按照交叉熵的公式處理。

非One-hot編碼

如果只有一個值,單個樣本的損失函式計算如下:

def cross_entropy_error(y, t):
	delta = 1e-7
	return -np.log(y + delta)

這是從前面的One-Hot編碼那裡推導來的,我們只需要神經網路在正確標籤處的輸出,就可以計算交叉熵誤差。

如果是Mini-Batch呢,需要做哪些變化?

Mini-Batch下的交叉熵函式

One-Hot編碼

def cross_entropy_error(y, t):
	if y.ndim == 1:
		t = t.reshape(1, t.size) # ndarray的size屬性是存在的
		y = y.reshape(1, y.size)
	batch_size = y.shape[0]
	return -np.sum(t * np.log(y+ 1e-7)) / batch_size

這裡既是y和t是小批量的形式,即二維矩陣,按照Numpy的調性,矩陣直接相乘是按照元素相乘,最後聚和再除以總體個數即可。看起來就除了batch_size,其實是聚和了二維矩陣相乘的結果。

非One-Hot編碼

def cross_entropy_error(y, t):
	if y.ndim == 1:
		t = t.reshape(1, t.size) # ndarray的size屬性是存在的
		y = y.reshape(1, y.size)
	batch_size = y.shape[0]
	return -np.sum(np.log(y[np.arange(batch_size),t] + 1e-7)) / batch_size

這裡還是需要注意這句話:我們只需要神經網路在正確標籤處的輸出,就可以計算交叉熵誤差。所以看起來很複雜的y[np.arange(batch_size),t]目的也是為了獲得神經網路的輸出,取出的是多行與多列的組合。

END.

參考:
《深度學習入門:基於Python的理論和實現》

https://jamesmccaffrey.wordpress.com/2013/11/05/why-you-should-use-cross-entropy-error-instead-of-classification-error-or-mean-squared-error-for-neural-network-classifier-training/

https://www.jianshu.com/p/474439106874