1. 程式人生 > >知識蒸餾(Distillation)相關論文閱讀(1)——Distilling the Knowledge in a Neural Network(以及程式碼復現)

知識蒸餾(Distillation)相關論文閱讀(1)——Distilling the Knowledge in a Neural Network(以及程式碼復現)

———————————————————————————————

《Distilling the Knowledge in a Neural Network》

Geoffrey Hintion

以往為了提高模型表現所採取的方法是對同一個資料集訓練出多個模型,再對預測結果進行平均;但通常這樣計算量過大。

引出一種模型壓縮技術:Distillation;以及介紹了一種由一個或多個完整模型(full models)以及針對/細節/特殊模型(specialist models)的組合,來學習區分僅僅是完整模型會混淆的細節。比起expects的混合,這些specialist models可以並行訓練,並且訓練起來更快速。

Intorduce:

引喻:昆蟲在幼蟲時擅於從環境中汲取能量,但是成長為成蟲後擅於其他方面,比如遷徙和繁殖等。

我們常常用相似的網路來訓練不同需求的問題:對於像語音和物件識別這樣的任務,訓練必須從非常大的、高度冗餘的資料集中提取結構,但是它不需要實時操作,並且可以使用大量的計算。

對網路新的理解:

如果能夠簡單的從資料中提取結構,我們應該樂意於訓練非常複雜的網路。複雜網路可以理解為是由一些單個的模型或者是一個由強約束條件(例如dropout)訓練得到的大型模型。一旦訓練得到複雜網路,我們可以用不同的訓練(即‘distillation’)來將知識從複雜網路轉化成更使用於應用拓展的小模型。

知識蒸餾的難點:

如何改變網路結構但是同時保留同樣的知識。拋開知識的例項化,知識可以看做為一個從輸入向量到輸出向量的只是地圖。

對於複雜模型在大量類中區分的正常訓練的目標,就是在於最大化正確答案的平均對數概率。但是缺點在於學習過程中會為所有的錯誤答案分配了概率,儘管這個概率很小。錯誤答案的相對概率反映了一個複雜網路是如何變得一般(泛化能力)的。

轉化的可能性:

訓練是的目標函式需要最大限度的反映正確目標,儘管如此模型還是通常被訓練得在測試資料上表現的最優,但真實目標其實是需要在新的資料集上表現良好。當我們從大模型到小模型做知識蒸餾的時候,我們可以像訓練大模型一下訓練好小模型。一個複雜的模型具有良好的泛化能力是因為它通常是許多不同模型的平均,這樣通過蒸餾的方式訓練出的小模型會比傳統訓練出來的小模型在測試集上表現更好。

利用複雜模型的泛化能力轉化為小模型的一種顯而易見的方法是:

將複雜模型生成的分類概率作為訓練小模型的“軟目標”。在這個轉移階段,我們可以使用相同的訓練集或其他的“轉化”訓練集。當複雜的模型是由一組更簡單的模型組成時,我們可以使用它們各自的預測分佈的算術或幾何平均值作為軟目標。當軟目標具有較高的熵值時,相對“硬目標,每個訓練用例所提供的資訊要比硬性指標多得多,它每次訓練可以提供更多的資訊和更小的梯度方差。因此小模型可以比原始的複雜模型更容易地訓練,而且使用的學習速率要高得多。

MNIST例項:

像MNIST這種任務,複雜模型可以給出很完美的結果,大部分資訊分佈在小概率的軟目標中。比如一張2的圖片被認為是3的概率為0.000001,被認為是7的概率是0.000000001。Caruana用logits(softmax層的輸入)而不是softmax層的輸出作為“軟目標”。目標是使得複雜模型和小模型分別得到的logits的平方差最小。“蒸餾法”:第一步,提升softmax表示式中的調節引數T,使得複雜模型產生一個合適的“軟目標”  第二步,採用同樣的T來訓練小模型,使得它產生相匹配的“軟目標。

並且發現,比起使用未標註的資料集原始的訓練集更好,尤其是在目標函式中加了一項的時候,能夠讓小模型預測正確的同時儘量匹配軟目標。但小模型是不能完全匹配軟目標的,正確結果的錯誤方向反而是有幫助。

Distillation:

修改後的softmax公式為:

T就是一個調節引數,通常為1;T的數值越大則所有類的分佈越‘軟’(平緩)。

一個簡單的知識蒸餾的形式是:用複雜模型得到的“軟目標”為目標(在softmax中T較大),用“轉化”訓練集訓練小模型。訓練小模型時T不變仍然較大,訓練完之後T改為1。 

當正確的標籤是所有的或部分的傳輸集時,這個方法可以通過訓練被蒸餾的模型產生正確的標籤。一種方法是使用正確的標籤來修改軟目標,但是我們發現更好的方法是簡單地使用兩個不同目標函式的加權平均值。第一個目標函式是帶有軟目標的交叉熵,這種交叉熵是在蒸餾模型的softmax中使用相同的T計算的,用於從繁瑣的模型中生成軟目標。第二個目標函式是帶有正確標籤的交叉熵。這是在蒸餾模型的softmax中使用完全相同的邏輯,但在T=1下計算。我們發現,在第二個目標函式中,使用一個較低權重的條件,得到了最好的結果。由於軟目標尺度所產生的梯度的大小為1/T^2,所以在使用硬的和軟的目標時將它們乘以T^2是很重要的。這確保了在使用T時,硬和軟目標的相對貢獻基本保持不變。

———————————————————————————————

筆者個人理解以及Pytorch程式碼實現:

可能讀者看到此處發現並不特別清楚論文中具體的蒸餾步驟以及T引數的意義,以下對幾個關鍵點的理解做出個人的解釋,歡迎指導和討論:

      1. T引數是什麼?有什麼作用?

        T引數為了對應蒸餾的概念,在論文中叫的是Temperature,也就是蒸餾的溫度。T越高對應的分佈概率越平緩,為什麼要使得分佈概率變平緩?舉一個例子,假設你是每次都是進行負重登山,雖然過程很辛苦,但是當有一天你取下負重,正常的登山的時候,你就會變得非常輕鬆,可以比別人登得高登得遠。

        同樣的,在這篇文章裡面的T就是這個負重包,我們知道對於一個複雜網路來說往往能夠得到很好的分類效果,錯誤的概率比正確的概率會小很多很多,但是對於一個小網路來說它是無法學成這個效果的。我們為了去幫助小網路進行學習,就在小網路的softmax加一個T引數,加上這個T引數以後錯誤分類再經過softmax以後輸出會變大(softmax中指數函式的單增特性,這裡不做具體解釋),同樣的正確分類會變小。這就人為的加大了訓練的難度,一旦將T重新設定為1,分類結果會非常的接近於大網路的分類效果

     2. soft target(“軟目標”)是什麼?

        soft就是對應的帶有T的目標,是要儘量的接近於大網路加入T後的分佈概率。

     3. hard target(“硬目標”)是什麼?

         hard就是正常網路訓練的目標,是要儘量的完成正確的分類。

     4. 兩個目標函式究竟是什麼?

兩個目標函式也就是對應的上面的soft target和hard target。這個體現在Student Network會有兩個loss,分別對應上面兩個問題求得的交叉熵,作為小網路訓練的loss function。

     5. 具體蒸餾是如何訓練的?

        Teacher:  對softmax(T=20)的輸出與原始label求loss。

        Student:(1)對softmax(T=20)的輸出與Teacher的softmax(T=20)的輸出求loss1。

                         (2)對softmax(T=1)的輸出與原始label求loss2。

                         (3)loss = loss1+loss2

在弄清楚上面的問題以後,我們就可以進行程式碼復現了,筆者選擇的是Pytorch,具體程式碼和效果對比會稍後整理一下上傳至本人的github,先佔坑~

相關推薦

知識蒸餾Distillation相關論文閱讀1——Distilling the Knowledge in a Neural Network以及程式碼復現

———————————————————————————————《Distilling the Knowledge in a Neural Network》Geoffrey Hintion以往為了提高模型表現所採取的方法是對同一個資料集訓練出多個模型,再對預測結果進行平均;但通

在神經網路中提取知識 [Distilling the Knowledge in a Neural Network]

論文題目:Distilling the Knowledge in a Neural Network 思想總結: 深度神經網路對資訊的提取有著很強的能力,可以從大量的資料中學習到有用的知識,比如學習如何將手寫數字圖片進行0~9的分類。 層數越多(越深),神經單元個數越多的網路,可以在大

蒸餾神經網路(Distill the Knowledge in a Neural Network) 論文筆記 蒸餾神經網路(Distill the Knowledge in a Neural Network) 論文筆記

轉 蒸餾神經網路(Distill the Knowledge in a Neural Network) 論文筆記 2017年08月06日 16:19:48 haoji00

蒸餾神經網路(Distill the Knowledge in a Neural Network)

本文是閱讀Hinton 大神在2014年NIPS上一篇論文:蒸餾神經網路的筆記,特此說明。此文讀起來很抽象,大篇的論述,鮮有公式和圖表。但是鑑於和我的研究方向:神經網路的壓縮十分相關,因此決定花氣力好好理解一下。  1、Introduction   文章開篇用一個比喻來引

(論文閱讀筆記1)Collaborative Metric Learning(一)WWW2017

一、摘要     度量學習演算法產生的距離度量捕獲資料之間的重要關係。這裡,我們將度量學習和協同過濾聯絡起來,提出了協同度量學習(CML),它可以學習出一個共同的度量空間來編碼使用者偏好和user-user 和 item-item的相似度。 二、背景

論文閱讀筆記四十五:Region Proposal by Guided AnchoringCVPR2019

分類 cascade 忽略 出了 advance ive 獲得 ams ons 論文原址:https://arxiv.org/abs/1901.03278 github:code will be available 摘要 區域anchor是現階段目標檢

論文閱讀筆記五十四:Gradient Harmonized Single-stage DetectorCVPR2019

advance splay 出發 產生 sigmoid for 問題 信息 eee 論文原址:https://arxiv.org/pdf/1811.05181.pdf github:https://github.com/libuyu/GHM_Detection

論文筆記——An online EEG-based brain-computer interface for controlling hand grasp using an adaptive probabilistic neural network10年被引用66次

不同 -s evel 模型 his ren 虛擬 dem virt 題目:利用自適應概率網絡設計一種在線腦機接口樓方法控制手部抓握 概要:這篇文章提出了一種新的腦機接口方法,控制手部,系列手部抓握動作和張開在虛擬現實環境中。這篇文章希望在現實生活中利用腦機接口技術控制抓握。

經典論文閱讀——DeepFashion: Powering Robust Clothes Recognition and Retrieval with Rich Annotations CVPR 2

DeepFashion: Powering Robust Clothes Recognition and  Retrieval with Rich Annotations (CVPR 2016) link:http://mmlab.ie.cuhk.edu.hk/proj

540. Single Element in a Sorted ArrayLeetCode

find you span code ace urn dup duplicate which Given a sorted array consisting of only integers where every element appears twice except

Leetcode#557. Reverse Words in a String III反轉字符串中的單詞 III

etc println urn pen eof reverse 同時 string i++ 題目描述 給定一個字符串,你需要反轉字符串中每個單詞的字符順序,同時仍保留空格和單詞的初始順序。 示例 1: 輸入: "Let's take LeetCode co

Make your own neural networkPython神經網路程式設計

  這本書應該算我第一本深度學習的程式碼入門書了吧,之前看阿里云云棲社和景略集智都有推薦這本書就去看了,   成功建立了自己的第一個神經網路,也瞭解一些關於深度學習的內容,再加上這學期的概率論與數理統計的課,   現在再來看李大大的機器學習課程,終於能看懂LogisticsRegression概率那部分公

Make your own neural networkPython神經網路程式設計

前兩篇程式碼寫了初始化與查詢,知道了S函式,初始權重矩陣。以及神經網路的計算原理,總這一篇起就是最重要的神經網路的訓練了。 神經網路的訓練簡單點來講就是讓輸出的東西更加接近我們的預期。比如我們輸出的想要是1,但是輸出了-0.5,明顯不是 我們想要的。 誤差=(期望的數值)-(實際輸出),那麼我們的誤差就是

1151 LCA in a Binary Tree30 point(s)

1151 LCA in a Binary Tree(30 point(s)) The lowest common ancestor (LCA) of two nodes U and V in a tree is the deepest node that has both U and V

論文閱讀筆記】Deep Learning based Recommender System: A Survey and New Perspectives

【論文閱讀筆記】Deep Learning based Recommender System: A Survey and New Perspectives 2017年12月04日 17:44:15 cskywit 閱讀數:1116更多 個人分類: 機器學習

CS229 6.17 Neurons Networks convolutional neural networkcnn

之前所講的影象處理都是小 patchs ,比如28*28或者36*36之類,考慮如下情形,對於一副1000*1000的影象,即106,當隱層也有106節點時,那麼W(1)的數量將達到1012級別,為了減少引數規模,加快訓練速度,CNN應運而生。CNN就像辟邪劍譜一樣,正常人練得很挫,一旦自宮後,就變得很厲害。

1151 LCA in a Binary Tree 30 分

The lowest common ancestor (LCA) of two nodes U and V in a tree is the deepest node that has both U and V as descendants. Given any two nodes in a

1151 LCA in a Binary Tree30 分

題目大意:給出中序序列和先序序列,再給出兩個點,求這兩個點的最近公共祖先。 解題思路:不用建樹~已知某個樹的根結點,若a和b在根結點的左邊,則a和b的最近公共祖先在當前子樹根結點的左子樹尋

1151 LCA in a Binary Tree 30 分(cj)

1151 LCA in a Binary Tree (30 分) The lowest common ancestor (LCA) of two nodes U and V in a tree is the deepest node that has both U and

LeetCode——Peak Index in a Mountain Array852

Let's call an array A a mountain if the following properties hold: A.length >= 3 There exists some 0 < i < A.length - 1 such th