1. 程式人生 > >蒸餾神經網路(Distill the Knowledge in a Neural Network)

蒸餾神經網路(Distill the Knowledge in a Neural Network)

本文是閱讀Hinton 大神在2014年NIPS上一篇論文:蒸餾神經網路的筆記,特此說明。此文讀起來很抽象,大篇的論述,鮮有公式和圖表。但是鑑於和我的研究方向:神經網路的壓縮十分相關,因此決定花氣力好好理解一下。 

1、Introduction  

文章開篇用一個比喻來引入網路蒸餾:

    昆蟲作為幼蟲時擅於從環境中汲取能量,但是成長為成蟲後確是擅於其他方面,比如遷徙和繁殖等。

同理神經網路訓練階段從大量資料中獲取網路模型,訓練階段可以利用大量的計算資源且不需要實時響應。然而到達使用階段,神經網路需要面臨更加嚴格的要求包括計算資源限制,計算速度要求等等。

由昆蟲的例子我們可以這樣理解神經網路:一個複雜的網路結構模型是若干個單獨模型組成的集合,或者是一些很強的約束條件下(比如dropout率很高)訓練得到的一個很大的網路模型。一旦複雜網路模型訓練完成,我們便可以用另一種訓練方法:“蒸餾”,把我們需要配置在應用端的縮小模型從複雜模型中提取出來。

      “蒸餾”的難點在於如何縮減網路結構但是把網路中的知識保留下來。知識就是一幅將輸入向量導引至輸出向量的地圖。做複雜網路的訓練時,目標是將正確答案的概率最大化,但這引入了一個副作用:這種網路為所有錯誤答案分配了概率,即使這些概率非常小。 

      我們將複雜模型轉化為小模型時需要注意保留模型的泛化能力,一種方法是利用由複雜模型產生的分類概率作為“軟目標”來訓練小模型。在轉化階段,我們可以用同樣的訓練集或者是另外的“轉化”訓練集。當複雜模型是由簡單模型複合而成時,我們可以用各自的概率分佈的代數或者幾何平均數作為“軟目標”。當“軟目標的”熵值較高時,相對“硬目標”,它每次訓練可以提供更多的資訊和更小的梯度方差,因此小模型可以用更少的資料和更高的學習率進行訓練。 

    像MNIST這種任務,複雜模型可以給出很完美的結果,大部分資訊分佈在小概率的軟目標中。比如一張2的圖片被認為是3的概率為0.000001,被認為是7的概率是0.000000001。Caruana用logits(softmax層的輸入)而不是softmax層的輸出作為“軟目標”。他們目標是是的複雜模型和小模型分別得到的logits的平方差最小。而我們的“蒸餾法”:第一步,提升softmax表示式中的調節引數T,使得複雜模型產生一個合適的“軟目標”  第二步,採用同樣的T來訓練小模型,使得它產生相匹配的“軟目標”

   “轉化”訓練集可以由未打標籤的資料組成,也可以用原訓練集。我們發現使用原訓練集效果很好,特別是我們在目標函式中加了一項之後。這一項的目的是是的小模型在預測實際目標的同時儘量匹配“軟目標”。要注意的是,小模型並不能完全無誤的匹配“軟目標”,而正確結果的犯錯方向是有幫助的。

2、Distillation 


    softmax層的公式如下: 


                                       

  T就是調節引數,一般設為1。T越大,分類的概率分佈越“軟” 

   “蒸餾”最簡單的形式就是:以從複雜模型得到的“軟目標”為目標(這時T比較大),用“轉化”訓練集訓練小模型。訓練小模型時T不變仍然較大,訓練完之後T改為1。 

   當“轉化”訓練集中部分或者所有資料都有標籤時,這種方式可以通過一起訓練模型使得模型得到正確的標籤來大大提升效果。一種實現方法是用正確標籤來修正“軟目標”,但是我們發現一種更好的方法是:對兩個目標函式設定權重係數。第一個目標函式是“軟目標”的交叉熵,這個交叉熵用開始的那個比較大的T來計算。第二個目標函式是正確標籤的交叉熵,這個交叉熵用小模型softmax層的logits來計算且T等於1。我們發現當第二個目標函式權重較低時可以得到最好的結果 


3、Preliminary experiments on MNIST 


  我的理解:將遷移資料集中的3或者7、8去掉是為了證明小模型也能夠從soft target中學得知識。 

4、Experiments on Speech Recognition 


5、Training ensembles of specialists on very big datasets