1. 程式人生 > >交叉熵損失函式和均方誤差損失函式

交叉熵損失函式和均方誤差損失函式

交叉熵

分類問題中,預測結果是(或可以轉化成)輸入樣本屬於n個不同分類的對應概率。比如對於一個4分類問題,期望輸出應該為 g0=[0,1,0,0] ,實際輸出為 g1=[0.2,0.4,0.4,0] ,計算g1與g0之間的差異所使用的方法,就是損失函式,分類問題中常用損失函式是交叉熵。

交叉熵(cross entropy)描述的是兩個概率分佈之間的距離,距離越小表示這兩個概率越相近,越大表示兩個概率差異越大。對於兩個概率分佈 p 和 q ,使用 q 來表示 p 的交叉熵為:

由公式可以看出來,p 與 q 之間的交叉熵 和 q 與 p 之間的交叉熵不是等價的。上式表示的物理意義是使用概率分佈 q 來表示概率分佈 p 的困難程式,q 是預測值,p 是期望值。
 
 

神經網路的輸出,也就是前向傳播的輸出可以通過Softmax迴歸變成概率分佈,之後就可以使用交叉熵函式計算損失了。

交叉熵一般會跟Softmax一起使用,在tf中對這兩個函式做了封裝,就是 tf.nn.softmax_cross_entropy_with_logits 函式,可以直接計算神經網路的交叉熵損失。

cross_entropy = tf.nn.softmax_cross_entropy_with_logits(y, y_)

其中 y 是網路的輸出,y_ 是期望輸出。

針對分類任務中,正確答案往往只有一個的情況,tf提供了更加高效的 tf.nn.sparse_softmax_cross_entropy_with_logits

函式來求交叉熵損失。

均方誤差

與分類任務對應的是迴歸問題,迴歸問題的任務是預測一個具體的數值,例如雨量預測、股價預測等。迴歸問題的網路輸出一般只有一個節點,這個節點就是預測值。這種情況下就不方便使用交叉熵函式求損失函數了。

迴歸問題中常用的損失函式式均方誤差(MSE,mean squared error),定義如下:

均方誤差的含義是求一個batch中n個樣本的n個輸出與期望輸出的差的平方的平均值。

tf中實現均方誤差的函式為:

mse = tf.reduce_mean(tf.square(y_ - y))

在有些特定場合,需要根據情況自定義損失函式,例如對於非常重要場所的安檢工作,把一個正常物品錯識別為危險品和把一個危險品錯識別為正常品的損失顯然是不一樣的,寧可錯判成危險品,不能漏判一個危險品,所以就要在定義損失函式的時候就要區別對待,對漏判加一個較大的比例係數。在tf中可以通過以下函式自定義:

loss = tf.reduce_sum(tf.select(tf.greater(v1,v2),loss1,loss2))