1. 程式人生 > >在神經網路中提取知識 [Distilling the Knowledge in a Neural Network]

在神經網路中提取知識 [Distilling the Knowledge in a Neural Network]

論文題目:Distilling the Knowledge in a Neural Network

思想總結:
深度神經網路對資訊的提取有著很強的能力,可以從大量的資料中學習到有用的知識,比如學習如何將手寫數字圖片進行0~9的分類。
層數越多(越深),神經單元個數越多的網路,可以在大量的資料中獲取的知識越豐富,能力越強。
然而當我們使用一個十分複雜的網路對一個較大的訓練集進行訓練時,引數眾多,網路模型複雜,計算成本太高而無法部署到大量使用者。

那麼我們是否可以使用一種方法,將複雜的網路獲取的知識,轉移(提取)到一個相對較小,較簡單的網路中呢?使得他們對相同的問題有著相同甚至更優的泛化能力呢?


從而,大大降低計算成本,並使其可以大量部署。

如何將一個已經訓練好的、較大、較複雜的網路模型所學習到的知識,提取到另外一個較小、較簡單的網路模型中,並使得它們對測試資料集有著相同的泛化能力。

在這裡插入圖片描述
如上圖所示,對於一個數字0~4的5個類別的分類問題,我們如何將上面複雜的網路經過訓練後學習到的知識,轉移到下面較為簡單的網路中。

我們使用同樣的資料集訓練下面較為簡單的網路。
(1)要想兩個網路對同一問題有相同的能力,首先我們要確保的是對於同一個輸入x而言,它們輸出的類別結果是一樣的。
對於上面的兩個網路,由於他們都使用softmax函式,所以要確保他們輸出的最大概率的類別號相同。

然而,保證(1)就可以了嗎?我們來看看這樣一些現象。
1:對於像MNIST這樣的任務,複雜的模型幾乎總是以非常高的置信度產生正確的答案,模型學習到的很多資訊存在於概率非常小的比率中

。例如,一種型別的2,被分類為3的概率為10^-6 ,被分類為7的概率為10^−9,而另一個型別的2可能相反。這些不同點是有價值的資訊,它定義了資料上豐富的相似結構(說明了哪些2像3而哪些2更像7),但它對傳遞階段的交叉熵代價函式影響很小,因為概率非常接近於零。
2:在圖片分類問題中,將一輛寶馬錯誤的分類為拖拉機的概率遠大於將其分類為胡蘿蔔的概率。
這些現象說明,當我們進行分類問題時,正確的類別所得到的概率是有用的,並且其他錯誤的類別所得到的概率同樣蘊含這有用的資訊,也是網路通過訓練而學習到的資訊。

所以,在將複雜網路學習到的知識轉移到簡單網路時,不僅僅要學習正確的類別上的概率,還要學習錯誤類別上的概率。因為錯誤類別上的概率同樣體現了模型的泛化方式,和所學習到的知識。雖然在錯誤的類別上的概率較小,但是在有一些錯誤類別上的概率要比另外一些錯誤類別的概率大很多,所以這種大小關係任然體現了複雜模型所學習到的一些知識。

(2)因此,在簡單網路中,我們不僅要使得正確類別的概率最大,並且要保證其他類別的概率與複雜網路中得到的概率相同。即P1~P5都要相同。
我們將此作為簡單模型的目標進行訓練,就可以使得知識較好的轉移。

同理,我們也可以將一個複雜的網路所學習到的知識,轉移到多個簡單的網路中,使得每個簡單的網路獲得複雜網路的一個子功能。
比如:
複雜網路學習到了分類數字0~9的知識。我們可以將其轉移到3個簡單網路中,使得:
網路1:可以對1~4進行分類
網路2:可以對4~7進行分類
網路3:可以對7~9進行分類
再結合網路1,2,3即可完成複雜網路的所有功能,並且大大降低計算量。