1. 程式人生 > >【深度學習】說一說深度學習的weight initialization

【深度學習】說一說深度學習的weight initialization

  先說一下幹嗎了?

感謝:

https://zhuanlan.zhihu.com/p/34879333 Batch Normalization原理與實戰

https://www.leiphone.com/news/201703/3qMp45aQtbxTdzmK.html 谷歌工程師:聊一聊深度學習的weight initialization

 

  深度學習模型訓練的過程本質是對weight(即引數 W)進行更新,這需要每個引數有相應的初始值。有人可能會說:“引數初始化有什麼難點?直接將所有weight初始化為0或者初始化為隨機數!”對一些簡單的機器學習模型,或當optimization function是convex function時,這些簡單的方法確實有效。

 

初始化為0的可行性?

為什麼將所有W初始化為0是錯誤的呢?是因為如果所有的引數都是0,那麼所有神經元的輸出都將是相同的,那在back propagation的時候同一層內所有神經元的行為也是相同的 --- gradient相同,weight update也相同。這顯然是一個不可接受的結果。

 

可行的幾種初始化方式

 

pre-training

 pre-training是早期訓練神經網路的有效初始化方法,一個便於理解的例子是先使用greedy layerwise auto-encoder做unsupervised pre-training,然後再做fine-tuning。具體過程可以參見UFLDL的一個

tutorialhttp://ufldl.stanford.edu/wiki/index.php/Stacked_Autoencoders

因為這不是本文重點,就在這裡簡略的說一下:

  • pre-training階段,將神經網路中的每一層取出,構造一個auto-encoder做訓練,使得輸入層和輸出層保持一致。在這一過程中,引數得以更新,形成初始值

  • fine-tuning階段,將pre-train過的每一層放回神經網路,利用pre-train階段得到的引數初始值和訓練資料對模型進行整體調整。在這一過程中,引數進一步被更新,形成最終模型。

隨著資料量的增加以及activation function (參見我的另一篇文章

https://zhuanlan.zhihu.com/p/25110450 )的發展,pre-training的概念已經漸漸發生變化。目前,從零開始訓練神經網路時我們也很少採用auto-encoder進行pre-training,而是直奔主題做模型訓練。不想從零開始訓練神經網路時,我們往往選擇一個已經訓練好的在任務A上的模型(稱為pre-trained model),將其放在任務B上做模型調整(稱為fine-tuning)。其實就是遷移學習吧

random initialization

隨機初始化是很多人目前經常使用的方法,然而這是有弊端的,一旦隨機分佈選擇不當,就會導致網路優化陷入困境。下面舉幾個例子。

核心程式碼見下方,完整程式碼請參見原作者的Githubhttps://github.com/feixia586/zhihu_material/tree/master/weight_initialization

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

這裡我們建立了一個10層的神經網路,非線性變換為tanh,每一層的引數都是隨機正態分佈,均值為0,標準差為0.01。下圖給出了每一層輸出值分佈的直方圖。

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

隨著層數的增加,我們看到輸出值迅速向0靠攏,在後幾層中,幾乎所有的輸出值 x 都很接近0!回憶優化神經網路的back propagation演算法,根據鏈式法則,gradient等於當前函式的gradient乘以後一層的gradient,這意味著輸出值 x 是計算gradient中的乘法因子,直接導致gradient很小,使得引數難以被更新!

讓我們將初始值調大一些:

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

均值仍然為0,標準差現在變為1,下圖是每一層輸出值分佈的直方圖:

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

幾乎所有的值集中在-1或1附近,神經元saturated了!注意到tanh在-1和1附近的gradient都接近0,這同樣導致了gradient太小,引數難以被更新。

Xavier initialization

Xavier initialization可以解決上面的問題!其初始化方式也並不複雜。Xavier初始化的基本思想是保持輸入和輸出的方差一致,這樣就避免了所有輸出值都趨向於0。注意,為了問題的簡便,Xavier初始化的推導過程是基於線性函式的,但是它在一些非線性神經元中也很有效。讓我們試一下:

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

Woohoo!輸出值在很多層之後依然保持著良好的分佈,這很有利於我們優化神經網路!之前談到Xavier initialization是線上性函式上推導得出,這說明它對非線性函式並不具有普適性,所以這個例子僅僅說明它對tanh很有效,那麼對於目前最常用的ReLU神經元呢(關於不同非線性神經元的比較請參考這裡https://zhuanlan.zhihu.com/p/25110450)?繼續做一下實驗:

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

前面看起來還不錯,後面的趨勢卻是越來越接近0。

但,,,,,幸運的是,He initialization可以用來解決ReLU初始化的問題。

 

He initialization

He initialization的思想是:在ReLU網路中,假定每一層有一半的神經元被啟用,另一半為0,所以,要保持variance不變,只需要在Xavier的基礎上再除以2:

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

看起來效果非常好,推薦在ReLU網路中使用!

Batch Normalization Layer

Batch Normalization是一種巧妙而粗暴的方法來削弱bad initialization的影響,其基本思想是:If you want it, just make it!

我們想要的是在非線性activation之前,輸出值應該有比較好的分佈(例如高斯分佈),以便於back propagation時計算gradient,更新weight。Batch Normalization將輸出值強行做一次Gaussian Normalization和線性變換:

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

Batch Normalization中所有的操作都是平滑可導,這使得back propagation可以有效執行並學到相應的引數γ,β。需要注意的一點是Batch Normalization在training和testing時行為有所差別。Training時μβ和σβ由當前batch計算得出;在Testing時μβ和σβ應使用Training時儲存的均值或類似的經過處理的值,而不是由當前batch計算。

隨機初始化,無Batch Normalization:

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

隨機初始化,有Batch Normalization:

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

è°·æ­å·¥ç¨å¸ï¼èä¸è深度学习çweight initialization

很容易看到,Batch Normalization的效果非常好,推薦使用!

參考資料

Xavier initialization是由Xavier Glorot et al.在2010年提出,He initialization是由Kaiming He et al.在2015年提出,Batch Normalization是由Sergey Ioffe et al.在2015年提出。

1、Xavier Glorot et al., Understanding the Difficult of Training Deep Feedforward Neural Networks

2、Kaiming He et al., Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classfication

3、Sergey Ioffe et al., Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

4、Standord CS231n