1. 程式人生 > >Adam和學習率衰減(learning rate decay)

Adam和學習率衰減(learning rate decay)

目錄

  • 梯度下降法更新引數
  • Adam 更新引數
  • Adam + 學習率衰減
    • Adam 衰減的學習率
  • References

本文先介紹一般的梯度下降法是如何更新引數的,然後介紹 Adam 如何更新引數,以及 Adam 如何和學習率衰減結合。

梯度下降法更新引數

梯度下降法引數更新公式:
\[ \theta_{t+1} = \theta_{t} - \eta \cdot \nabla J(\theta_t) \]

其中,\(\eta\) 是學習率,\(\theta_t\) 是第 \(t\) 輪的引數,\(J(\theta_t)\) 是損失函式,\(\nabla J(\theta_t)\) 是梯度。

在最簡單的梯度下降法中,學習率 \(\eta\) 是常數,是一個需要實現設定好的超引數,在每輪引數更新中都不變,在一輪更新中各個引數的學習率也都一樣。

為了表示簡便,令 \(g_t = \nabla J(\theta_t)\),所以梯度下降法可以表示為:
\[ \theta_{t+1} = \theta_{t} - \eta \cdot g_t \]

Adam 更新引數

Adam,全稱 Adaptive Moment Estimation,是一種優化器,是梯度下降法的變種,用來更新神經網路的權重。

Adam 更新公式:
\[ \begin{aligned} m_{t} &=\beta_{1} m_{t-1}+\left(1-\beta_{1}\right) g_{t} \\ v_{t} &=\beta_{2} v_{t-1}+\left(1-\beta_{2}\right) g_{t}^{2} \\ \hat{m}_{t} &=\frac{m_{t}}{1-\beta_{1}^{t}} \\ \hat{v}_{t} &=\frac{v_{t}}{1-\beta_{2}^{t}} \\ \theta_{t+1}&=\theta_{t}-\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon} \hat{m}_{t} \end{aligned} \]

在 Adam 原論文以及一些深度學習框架中,預設值為 \(\eta = 0.001\),\(\beta_1 = 0.9\),\(\beta_2 = 0.999\),\(\epsilon = 1e-8\)。其中,\(\beta_1\) 和 \(\beta_2\) 都是接近 1 的數,\(\epsilon\) 是為了防止除以 0。\(g_{t}\) 表示梯度。

咋一看很複雜,接下一一分解:

  • 前兩行:
    \[ \begin{aligned} m_{t} &=\beta_{1} m_{t-1}+\left(1-\beta_{1}\right) g_{t} \\ v_{t} &=\beta_{2} v_{t-1}+\left(1-\beta_{2}\right) g_{t}^{2} \end{aligned} \]

這是對梯度和梯度的平方進行滑動平均,即使得每次的更新都和歷史值相關。

  • 中間兩行:
    \[ \begin{aligned} \hat{m}_{t} &=\frac{m_{t}}{1-\beta_{1}^{t}} \\ \hat{v}_{t} &=\frac{v_{t}}{1-\beta_{2}^{t}} \end{aligned} \]
    這是對初期滑動平均偏差較大的一個修正,叫做 bias correction,當 \(t\) 越來越大時,\(1-\beta_{1}^{t}\) 和 \(1-\beta_{2}^{t}\) 都趨近於 1,這時 bias correction 的任務也就完成了。

  • 最後一行:
    \[ \theta_{t+1}=\theta_{t}-\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon} \hat{m}_{t} \]

這是引數更新公式。

學習率為 \(\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon}\),每輪的學習率不再保持不變,在一輪中,每個引數的學習率也不一樣了,這是因為 \(\eta\) 除以了每個引數 \(\frac{1}{1- \beta_2} = 1000\) 輪梯度均方和的平方根,即 \(\sqrt{\frac{1}{1000}\sum_{k = t-999}^{t}g_k^2}\)。而每個引數的梯度都是不同的,所以每個引數的學習率即使在同一輪也就不一樣了。(可能會有疑問,\(t\) 前面沒有 999 輪更新怎麼辦,那就有多少輪就算多少輪,這個時候還有 bias correction 在。)

而引數更新的方向也不只是當前輪的梯度 \(g_t\) 了,而是當前輪和過去共 \(\frac{1}{1- \beta_1} = 10\) 輪梯度的平均。

有關滑動平均的理解,可以參考我之前的部落格:理解滑動平均(exponential moving average)。

Adam + 學習率衰減

在 StackOverflow 上有一個問題 Should we do learning rate decay for adam optimizer - Stack Overflow,我也想過這個問題,對 Adam 這些自適應學習率的方法,還應不應該進行 learning rate decay?

我簡單的做了個實驗,在 cifar-10 資料集上訓練 LeNet-5 模型,一個採用學習率衰減 tf.keras.callbacks.ReduceLROnPlateau(patience=5),另一個不用。optimizer 為 Adam 並使用預設的引數,\(\eta = 0.001\)。結果如下:


加入學習率衰減和不加兩種情況在 test 集合上的 accuracy 分別為: 0.5617 和 0.5476。(實驗結果取了兩次的平均,實驗結果的偶然性還是有的)

通過上面的小實驗,我們可以知道,學習率衰減還是有用的。(當然,這裡的小實驗僅能代表一小部分情況,想要說明學習率衰減百分之百有效果,得有理論上的證明。)

當然,在設定超引數時就可以調低 \(\eta\) 的值,使得不用學習率衰減也可以達到很好的效果,只不過引數更新變慢了。

將學習率從預設的 0.001 改成 0.0001,epoch 增大到 120,實驗結果如下所示:

加入學習率衰減和不加兩種情況在 test 集合上的 accuracy 分別為: 0.5636 和 0.5688。(三次實驗平均,實驗結果仍具有偶然性)

這個時候,使用學習率衰減帶來的影響可能很小。

那麼問題來了,Adam 做不做學習率衰減呢?
我個人會選擇做學習率衰減。(僅供參考吧。)在初始學習率設定較大的時候,做學習率衰減比不做要好;而當初始學習率設定就比較小的時候,做學習率衰減似乎有點多餘,但從 val set 上的效果看,做了學習率衰減還是可以有丁點提升的。

ReduceLROnPlateau 在 val_loss 正常下降的時候,對學習率是沒有影響的,只有在 patience(預設為 10)個 epoch 內,val_loss 都不下降 1e-4 或者直接上升了,這個時候降低學習率確實是可以很明顯提升模型訓練效果的,在 val_acc 曲線上看到一個快速上升的過程。對於其它型別的學習率衰減,這裡沒有過多地介紹。

Adam 衰減的學習率

從上述學習率曲線來看,Adam 做學習率衰減,是對 \(\eta\) 進行,而不是對 \(\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon}\) 進行,但有區別嗎?

學習率衰減一般如下:

  • exponential_decay:
    decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

  • natural_exp_decay:
    decayed_learning_rate = learning_rate * exp(-decay_rate * global_step / decay_steps)

  • ReduceLROnPlateau
    如果被監控的值(如‘val_loss’)在 patience 個 epoch 內都沒有下降,那麼學習率衰減,乘以一個 factor
    decayed_learning_rate = learning_rate * factor

這些學習率衰減都是直接在原學習率上乘以一個 factor ,對 \(\eta\) 或對 \(\frac{\eta}{\sqrt{\hat{v}_{t}}+\epsilon}\) 操作,結果都是一樣的。

References

An overview of gradient descent optimization algorithms -- Sebastian Ruder
Should we do learning rate decay for adam optimizer - Stack Overflow
Tensorflow中learning rate decay的奇技淫巧 -- Elev