1. 程式人生 > >三步理解--門控迴圈單元(GRU),TensorFlow實現

三步理解--門控迴圈單元(GRU),TensorFlow實現

1. 什麼是GRU

在迴圈神經⽹絡中的梯度計算⽅法中,我們發現,當時間步數較⼤或者時間步較小時,迴圈神經⽹絡的梯度較容易出現衰減或爆炸。雖然裁剪梯度可以應對梯度爆炸,但⽆法解決梯度衰減的問題。通常由於這個原因,迴圈神經⽹絡在實際中較難捕捉時間序列中時間步距離較⼤的依賴關係。

門控迴圈神經⽹絡(gated recurrent neural network)的提出,正是為了更好地捕捉時間序列中時間步距離較⼤的依賴關係。它通過可以學習的⻔來控制資訊的流動。其中,門控迴圈單元(gatedrecurrent unit,GRU)是⼀種常⽤的門控迴圈神經⽹絡。

2. ⻔控迴圈單元

2.1 重置門和更新門

GRU它引⼊了重置⻔(reset gate)和更新⻔(update gate)的概念,從而修改了迴圈神經⽹絡中隱藏狀態的計算⽅式。

門控迴圈單元中的重置⻔和更新⻔的輸⼊均為當前時間步輸⼊ \(X_t\) 與上⼀時間步隱藏狀態\(H_{t-1}\),輸出由啟用函式為sigmoid函式的全連線層計算得到。 如下圖所示:

具體來說,假設隱藏單元個數為 h,給定時間步 t 的小批量輸⼊ \(X_t\in_{}\mathbb{R}^{n*d}\)(樣本數為n,輸⼊個數為d)和上⼀時間步隱藏狀態 \(H_{t-1}\in_{}\mathbb{R}^{n*h}\)。重置⻔ \(H_t\in_{}\mathbb{R}^{n*h}\) 和更新⻔ \(Z_t\in_{}\mathbb{R}^{n*h}\) 的計算如下:

\[R_t=\sigma(X_tW_{xr}+H_{t-1}W_{hr}+b_r)\]

\[Z_t=\sigma(X_tW_{xz}+H_{t-1}W_{hz}+b_z)\]

sigmoid函式可以將元素的值變換到0和1之間。因此,重置⻔ \(R_t\) 和更新⻔ \(Z_t\) 中每個元素的值域都是[0, 1]。

2.2 候選隱藏狀態

接下來,⻔控迴圈單元將計算候選隱藏狀態來輔助稍後的隱藏狀態計算。我們將當前時間步重置⻔的輸出與上⼀時間步隱藏狀態做按元素乘法(符號為)。如果重置⻔中元素值接近0,那麼意味著重置對應隱藏狀態元素為0,即丟棄上⼀時間步的隱藏狀態。如果元素值接近1,那麼表⽰保留上⼀時間步的隱藏狀態。然後,將按元素乘法的結果與當前時間步的輸⼊連結,再通過含啟用函式tanh的全連線層計算出候選隱藏狀態,其所有元素的值域為[-1,1]。

具體來說,時間步 t 的候選隱藏狀態 \(\tilde{H}\in_{}\mathbb{R}^{n*h}\) 的計算為:

\[\tilde{H}_t=tanh(X_tW_{xh}+(R_t⊙H_{t-1})W_{hh}+b_h)\]

從上⾯這個公式可以看出,重置⻔控制了上⼀時間步的隱藏狀態如何流⼊當前時間步的候選隱藏狀態。而上⼀時間步的隱藏狀態可能包含了時間序列截⾄上⼀時間步的全部歷史資訊。因此,重置⻔可以⽤來丟棄與預測⽆關的歷史資訊。

2.3 隱藏狀態

最後,時間步t的隱藏狀態 \(H_t\in_{}\mathbb{R}^{n*h}\) 的計算使⽤當前時間步的更新⻔\(Z_t\)來對上⼀時間步的隱藏狀態 \(H_{t-1}\) 和當前時間步的候選隱藏狀態 \(\tilde{H}_t\) 做組合:

值得注意的是,更新⻔可以控制隱藏狀態應該如何被包含當前時間步資訊的候選隱藏狀態所更新,如上圖所⽰。假設更新⻔在時間步 \(t^{′}到t(t^{′}<t)\) 之間⼀直近似1。那麼,在時間步 \(t^{′}到t\) 間的輸⼊資訊⼏乎沒有流⼊時間步 t 的隱藏狀態\(H_t\)實際上,這可以看作是較早時刻的隱藏狀態 \(H_{t^{′}-1}\) 直通過時間儲存並傳遞⾄當前時間步 t。這個設計可以應對迴圈神經⽹絡中的梯度衰減問題,並更好地捕捉時間序列中時間步距離較⼤的依賴關係。

我們對⻔控迴圈單元的設計稍作總結:

  • 重置⻔有助於捕捉時間序列⾥短期的依賴關係;
  • 更新⻔有助於捕捉時間序列⾥⻓期的依賴關係。

3. 程式碼實現GRU

MNIST--GRU實現

【機器學習通俗易懂系列文章】

4. 參考文獻

《動手學--深度學習》


作者:@mantchs

GitHub:https://github.com/NLP-LOVE/ML-NLP

歡迎大家加入討論!共同完善此專案!群號:【541954936】