1. 程式人生 > >神經網路多分類任務的損失函式——交叉熵

神經網路多分類任務的損失函式——交叉熵

神經網路解決多分類問題最常用的方法是設定n個輸出節點,其中n為類別的個數。對於每一個樣例,神經網路可以得到的一個n維陣列作為輸出結果。陣列中的每一個維度(也就是每一個輸出節點)對應一個類別。在理想情況下,如果一個樣本屬於類別k,那麼這個類別所對應的輸出節點的輸出值應該為1,而其他節點的輸出都為0。

以識別手寫數字為例,0~9共十個類別。識別數字1,神經網路的輸出結果越接近[0,1,0,0,0,0,0,0,0,0]越好。交叉上是最好的評判方法之一。交叉熵刻畫了兩個概率分佈之間的距離,它是分類問題中使用比較廣的一種損失函式。


p代表正確答案,q代表的是預測值。交叉熵值越小,兩個概率分佈越接近。

需要注意的是,交叉熵刻畫的是兩個概率分佈之間的距離,然而神經網路的輸出卻不一定是一個概率分佈,很多情況下是實數。如何將神經網路前向傳播得到的結果也變成概率分佈,Softmax迴歸就是一個非常有用的方法。


Softmax將神經網路的輸出變成了一個概率分佈,這個新的輸出可以理解為經過神經網路的推導,一個樣例為不同類別的概率分別是多大。這樣就把神經網路的輸出也變成了一個概率分佈,從而可以通過交叉熵來計算預測的概率分佈和真實答案的概率分佈之間的距離了。

例子:

假設有一個三分類問題,某個樣例的正確答案是(1,0,0)。某模型經過Softmax迴歸之後的預測答案是(0.5,0,4,0.1),那麼這個預測和正確答案直接的交叉熵是:


如果另外一個模型的預測是(0.8,0.1,0.1),那麼這個預測值和真實值的交叉熵是:


從直觀上可以很容易知道第二個答案要優於第二個。通過交叉熵計算得到的結果也是一致的(第二個交叉熵的值更小)。