1. 程式人生 > >"兩步走方法"解析損失函式:mean square error,cross entropy,softmax,SVM

"兩步走方法"解析損失函式:mean square error,cross entropy,softmax,SVM

本文試圖將四類最基礎也最常用的損失函式進行統一歸納,以明白各自的區別,加深對他們的記憶和理解。

本文會從兩個步驟分析損失函式,第一個是從輸入到模型的輸出,第二個是從模型的輸出到損失函式的計算。

第一個步奏,從輸入到模型的輸出。我先解釋一下什麼是模型的輸出。比如當我們訓練好了一個迴歸模型,用來判斷房子的大小和價格的關係。那麼,當我們輸入一個房子的大小時,模型就會輸出一個房子的價格,而這房子的價格就是模型的輸出。簡單來說,如果用x表示輸入的話,f(x)可以用來表示模型的輸出。

建立在第一個問題上,我們有了損失函式的概念。損失函式就是來衡量模型的輸出與真實值之間的差距。我們試圖減小這個差距以達到訓練模型的目的。因此,可以認為損失函式是依附於模型輸入與輸出之上的。

輸入,模型輸出,損失函式的三者關係

我們可以用上圖來表示輸入,模型輸出,真實值和損失函式之間的關係。

為了更加方便的說明下面的內容,我們還需要注意到樣本的概念。這裡不試圖說明樣本概念,只是想提及一個輸入和多個輸入的區別。

多個輸入多個輸出

我們在訓練模型時,常常是對很多輸入和輸出來進行的。我們必須要明確多個輸入和多個輸出的基本規律,而我們求損失函式也是基於多個輸出的。實際上,更進一步,我們求導亦是如此。

還要注意的是,輸入的數量和輸入的維度是不一樣的概念,不要混淆。舉例來說,還是使用房價的迴歸來說,我們使用房子的大小和房子距離最近的地鐵之間的距離來擬合。這時候,每一個輸入xi應該包含兩個維度,一個是房子面積,一個是房子與某個地鐵的距離。而我們有很多個房子的價格,每一個房子資料則為一個個樣本。如下圖,這裡有三個樣本,x1,x2,x3,每個樣本則有兩個維度。

樣本數量與輸入維度

綜合以上,並進一步說明以下符號的表示:

我們用x表示輸入,xij(ij為右下角標)表示不同的輸入。其中,i表示樣本的標號,j表示維度標號。

f(xi)表示輸入為xi時的模型輸出。

用L表示所有輸入(不管是多少個輸入)的總損失,其中Li(i為右下角標)表示輸入為xi時的損失。

g(xi)表示真實值,g可以理解為groundtruth。

N表示樣本數量。

M表示輸入的維度。

用W表示引數。w1表示對應於輸入的第一個維度的引數。

還有其他符號會在給出時說明。

1.Mean Square Error

該損失函式可以認為是最基礎的損失函式,最易理解。在機器學習中可以簡寫為MSE,如tensorflow中可以用MSE來說明你使用的損失函式。

中文可以翻譯成均方誤差。其中,“均”是指平均,如下公式中的N分之一和求和符號,就是對多個輸入的損失值求和之後求平均;“方”表示對某個數的平方,如下面公式中的平方;而誤差則可以理解為兩個數之間的差,可以理解為下面公式中的減號,推廣開來可以理解為兩個向量之間的差。公式可以是:

Mean Square Error

首先,這個公式並不限制f(x)是怎麼得到的。而且直接將模型輸入使用在損失函式中。其中2分之1,是為了方便求導計算。

補充說明一下,通常f(x)可以用公式來表示,下面會用其他方式來表示不同損失函式的計算:

公式(1)- 模型的輸出

從這個輸入到輸出的函式來看,這是一個線性關係,最為簡單的關係。

大家也都對這個函式很熟悉,這裡就不過多介紹。關鍵是下面的幾個損失函式。

2.cross entropy error

可翻譯為交叉熵損失,與資訊理論中熵的概念也是有關係,這裡就不展開了。

在第一步,也就是從輸入到模型的輸出,交叉熵損失的模型輸出與均方誤差已經有所不同。上面說到,均方誤差是一個最簡單的線性關係。而對於交叉熵損失來說,需要一個非線性的對映。

在得到f(xi)之後,交叉熵損失函式將其進一步作為輸入輸進sigmoid的函式中,我們用S(f(xi))來表示:

sigmoid函式

sigmoid函式影象如下圖,這是一個輸出值在0到1之間的對映關係,屬於非線性對映。

sigmoid函式影象

總的來說,對於交叉熵損失函式從輸入到輸出與均方誤差不同的地方在於,在其得到線性對映之後又加了一個sigmoid的非線性對映

而對於該模型來說,s值就是模型的輸出。

以下分析得到模型的輸出之後,交叉熵損失如何進一步計算損失的。

如第一節所說,均方誤差是用真實值與輸出值之間的絕對誤差的平方來表示的。

由於各種原因,我們在使用交叉熵損失函式的時候,真實值只有兩種情況,一種是0,一種是1。而交叉熵損失中的模型輸出在0到1之間。我們使用log函式來表示真實值與模型輸出之間的關係。而真實值有兩個,我們可以分開討論。

交叉熵的損失函式

這裡,我使用了ln 來 代替 log,這樣可以更方便的求導。我們可以接著繼續觀察一下ln函式的影象。

-ln(x)函式

這裡,我們可以看到當輸入在0到1之間時,-ln(x)的值域在0到正無窮之間。

損失函式的價值可以分析一下,當真實值為1時,如果模型輸出越接近1,則損失值越接近於0;如果模型輸出越接近0,則損失值則接近於正無窮。當真實值為0時,如果模型越接近0,則損失值也接近於0,反之亦然。這裡可以認為起到了衡量真實值與模型輸出之間的作用。

為了統一一下上面的兩個真實值的公式,我們有下面這個公式(我們使用g表示g(xi),使用f表示f(xi)):

統一格式的交叉熵損失

最後,我們再次總結一下,交叉熵損失函式從模型輸出到損失函式,使用了對數函式,其特點是真實值要麼是0要麼是1。損失函式可以讓越接近真實值的損失值接近於0,遠離真實值的損失值趨向於無窮大。其常用於分類問題中。

3.softmax

交叉熵損失可以解決二分類問題,但是卻無法一次性針對多分類問題。這時,我們可以使用softmax進行分類。

同均方誤差函式一樣,我們得到了f(xi)之後,我們進一步處理。可以描述為求e次方,然後再歸一化。得到的一個所有分類總和為1的數,我們用概率來描述這樣的方式,實際上,每個數也表示了對應的分類概率,我們用p來作為記號(這裡我們用fi來表示f(xi)):

softmax的模型輸出

這裡只表示一個樣本的情況,g=k表示真實情況為k類,X=xi表示輸入為xi,j就是一個樣本分類的數量。以上就是模型的輸出。

而損失函式則是描述真實值和模型輸出之間的差距。我們計算好了模型輸出之後,我們就可以計算損失函式:

softmax的損失函式

前面第二節已經說到,ln函式的影象關係。我們知道當某一分類的模型輸出接近於1的時候,損失值是接近0的。

但是,這裡我們貌似只看到了模型的輸出,而並沒有見到真實值。實際情況並非如此,我們對應的損失函式是某樣本下真實值對應的模型輸出值的ln值。如果這個值越接近於1,說明我們模型認為這個樣本應該就是該類。為了更加形象的說明,下面舉出一個實際的例子,例子中只有一個樣本,一個樣本對應的分類可能有三個,分別為貓,汽車,和人,我們直接從得到了f開始計算:

一個計算softmax的例子

這裡關鍵是要理解損失函式的值與真實值的關係。真實值由於是正確的分類,並沒有直接參與計算,而是其對應標籤下的概率值參與計算。

可以知道,softmax也是一個非線性的對映,處理的是多分類問題,輸出是每個類的概率值,損失函式是正確分類的對數函式值。

4.SVM

還是針對多分類問題,我們在得到了f值之後,進一步進行計算。我們知道模型的輸出是我們對樣本的分類,而實際上SVM的輸出是與核相關的,這也是其獨特之處。這裡我並不是很想講核函式是什麼,與輸出又有什麼關係。我們直接認為SVM的輸出就是均方誤差裡的輸出,也就是上面一直提到的f(xi),當然,這還不是最後的輸出,SVM的輸出就是每個樣本的在每個分類下最大值對應的類。

下面我們講SVM的損失怎麼由f(x)計算得到。由於是多分類問題,而我們又使用當前樣本在每個分類下最大值對應的類作為我們的預測值,我們可以假設如果正確分類下的值是最大值,則我們的預測是正確的。那麼如何衡量這個值了,可以使用差來表示。而又有多個分類,我們進一步可以使用差值的和來進行衡量。

而SVM一個特性就是,只有當真實標籤下對應的值比其他值大上一個值才認為是零損失。個人稱這個值是保險值。整體公式如下:

SVM損失函式

公式中1就是上面說到的那個值。累加符號表示樣本的每一個分類(除了正確的那個分類)。

舉個例子,同上面的那個例子,還是從得到了f值開始:

SVM損失的一個例子

另外提一句,我們說SVM其實主要分為兩個部分,第一個部分是通過輸入求模型的輸出,這裡面核函式起到重要作用,這裡我試圖用簡單的話來解釋核函式有什麼作用,它就相當於把輸入的x對映到另一個空間上,可以認為這是利用一組不同的基來表示一些向量(也就是原始輸入)。而第二個部分就是利用輸出得到損失函式,這中間那個值起到重要作用。而在深度學習中,我們通常直接使用第二部分。

總結

最後,本文做一個總結。

總結