不用批歸一化也能訓練萬層ResNet,新型初始化方法Fixup瞭解一下
前幾天,我們介紹了「 機器學習領域的七大謠傳 」,其中一個謠傳就是「訓練超 深度殘差網路 怎麼少得了批歸一化(BN)!」。文中介紹了論文《Fixup Initialization: Residual Learning Without Normalization》表明在不引入任何歸一化方法的情況下,通過使用原版 SGD ,可以有效地訓練一個 10,000 層的深度網路。也就是說「訓練超深殘差網路可以不用批歸一化」。
近日,Tesla AI 總監 Andrej Karpathy 也發推提及了這篇論文,這篇論文到底如何,我們一起來看一下吧。
歸一化層是當前最優神經網路架構的重要組成部分。人們普遍認為歸一化層可以穩定訓練、實現較高的學習率、加速收斂並提高泛化能力,儘管其有效的原因仍然是一個活躍的研究課題。這篇論文就挑戰了這一「共識」,認為這些好處並不獨屬於歸一化。
研究者提出了一種新型初始化方法 fixed-update initialization (Fixup),試圖在訓練開始時通過恰當地調整初始化來解決梯度爆炸和梯度消失問題。實驗證明,使用 Fixup 訓練殘差網路與使用歸一化訓練殘差網路一樣穩定,甚至在訓練 10000 層的深度網路時也是如此。此外,經過恰當的正則化後,Fixup 使殘差網路在不使用歸一化的情況下也能在影象分類和機器翻譯任務中達到當前最優效能。
左圖是 ResNet,紅色為批歸一化層。中間圖是堆疊在一起也能穩定訓練的簡單 Fixup 網路(移除了偏置項)。右圖是新增偏置項引數後的 Fixup。
作者在論文開頭就提出了兩個問題:
-
沒有歸一化,深度殘差網路可以得到可靠的訓練嗎?
-
如果可以的話,那麼使用和不使用歸一化的深度殘差網路,在相同學習率和最優化方法時,收斂速率和泛化效能能否一致?
而這篇論文中給出的答案是兩個「Yes」!
該研究解釋了給出這個答案的原因:
-
歸一化為什麼有助於模型訓練?論文作者推匯出殘差網路在初始化時,梯度範數的下界,從而解釋了為什麼在使用標準初始化時,歸一化技術對於用大學習率訓練深度殘差網路是必需的。
-
關於不使用歸一化的模型訓練,作者提出了一種新初始化方法 Fixup。通過調整網路架構對殘差分支的標準初始化進行重縮放,Fixup 可使深度殘差網路在不使用歸一化技術的情況下以最大學習率進行穩定訓練。
作者還在影象分類和機器翻譯任務上對以上第二點結論進行了驗證:
-
影象分類:作者在影象分類基準資料集 CIFAR-10 上訓練 Wide-ResNet、在 ImageNet 資料集上訓練 ResNet,但將批歸一化技術替換成了 Fixup,結果發現經過恰當正則化的 Fixup 效能堪比使用歸一化技術的精調基線模型。
-
機器翻譯:作者在機器翻譯基準IWSLT 和 WMT 上訓練 Transformer 模型,但將層歸一化替換成了 Fixup,結果發現新模型優於基線模型,且同樣的架構輸出結果卻達到了新高。
深度網路標準初始化的問題
標準初始化方法嘗試設定網路的初始引數,以使啟用函式不會消失也不會爆炸。但是,據觀察在沒有歸一化技術的時候,標準初始化無法恰當處理殘差連線的梯度流,從而導致梯度爆炸。
作者分析後指出了深度殘差網路標準初始化的失敗模式:某些層的梯度範數下界會隨著網路深度增加而無限增長,即梯度爆炸。具體推導過程參見論文第 2 章。
新型初始化方法 Fixup
作者同時指出擺脫該失敗模式未必就能帶來成功的訓練,畢竟我們所關心的是把整個網路作為函式,而不是一個層或一個網路塊。因此,作者提出了一種新型初始化方法 Fixup,它使用自上而下的設計,通過調整標準初始化來確保網路函式的更新(Gradient)保持在恰當範圍內。作者用 η 表示學習率,將目標設定為:
換言之,其目標是設計一種初始化,使網路函式的 SGD 更新保持在合適的尺度內,且獨立於網路深度。
總結來看,該研究提出的新型初始化方法 Fixup 可使在不使用歸一化技術的情況下完成殘差網路訓練。其原理如下:
-
1. 將分類層和每個殘差分支的最後一層初始化為 0。
-
2. 使用標準方法對其他層執行初始化,然後按
縮放殘差分支中的權重層。
-
3. 在每個分支中新增一個標量乘數((initialized at 1),在每個卷積、線性和元素級啟用層前面新增一個標量偏差((initialized at 0)。
其中規則 2 是必要的。
實驗結果
該研究在深度網路和影象分類、機器翻譯任務上對這一新方法進行了測試。
深度網路
以預設學習率在 CIFAR-10 資料集上進行第一個 epoch 訓練後,Fixup 的效能與批歸一化不相上下,甚至對 10000 層的深度網路也是如此。
圖 3:不同方法以預設學習率在 CIFAR-10 資料集上訓練一個 epoch,訓練後的殘差網路深度和測試準確率對比。從圖中可以看到 Fixup 在訓練非常深的網路時可以與批歸一化具備同樣的學習率(越高越好)。
影象識別
研究者在影象分類應用上將批歸一化替換成 Fixup,來測試 Fixup 的能力。
表 1:使用 ResNet-110 在 CIFAR-10 資料集上的結果(5 次訓練的平均值,分值越低越好)。
Fixup 能夠在 CIFAR-10 上以高學習速率訓練一個 110 層的深度殘差網路,得到的測試集表現和利用批歸一化訓練的同結構網路效果相當。
作者還在 ImageNet 資料集上進行了測試。
表 2:使用 ResNet 架構在 ImageNet 資料集上的測試結果(分值越低越好)。
機器翻譯
為了展示 Fixup 的通用性,研究者還使用 Transformer 在機器翻譯任務中進行了測試。
表 3:在機器翻譯任務上對比 Fixup 和 LayerNorm(分值越高越好)。
論文連結:https://arxiv.org/pdf/1901.09321v1.pdf