RNN - LSTM - GRU
迴圈神經網路 (Recurrent Neural Network,RNN) 是一類具有短期記憶能力的神經網路,因而常用於序列建模。本篇先總結 RNN 的基本概念,以及其訓練中時常遇到梯度爆炸和梯度消失問題,再引出 RNN 的兩個主流變種 —— LSTM 和 GRU。
Vanilla RNN
Vanilla RNN 的主體結構:

上圖中 \(\bf{X, h, y}\) 都是向量,公式如下:
\[ % <![CDATA[ \begin{align} \textbf{h}_{t} &= f_{\textbf{W}}\left(\textbf{h}_{t-1}, \textbf{x}_{t} \right) \tag{1} \\ \textbf{h}_{t} &= f\left(\textbf{W}_{hx}\textbf{x}_{t} + \textbf{W}_{hh}\textbf{h}_{t-1} + \textbf{b}_{h}\right) \tag{2a} \\ \textbf{h}_{t} &= \textbf{tanh}\left(\textbf{W}_{hx}\textbf{x}_{t} + \textbf{W}_{hh}\textbf{h}_{t-1} + \textbf{b}_{h}\right) \tag{2b} \\ \hat{\textbf{y}}_{t} &= \textbf{softmax}\left(\textbf{W}_{yh}\textbf{h}_{t} + \textbf{b}_{y}\right) \tag{3} \end{align} %]]> \]
其中
\(\textbf{W}_{hx} \in \mathbb{R}^{h \times x}, \; \textbf{W}_{hh} \in \mathbb{R}^{h \times h}, \; \textbf{W}_{yh} \in \mathbb{R}^{y \times h}, \; \textbf{b}_{h} \in \mathbb{R}^{h}, \; \textbf{b}_{y} \in \mathbb{R}^{y}\)
\((2a)\) 式中的兩個矩陣 \(\mathbf{W}\) 可以合併:
\[ \begin{align*} \textbf{h}_{t} &= f\left(\textbf{W}_{hx}\textbf{x}_{t} + \textbf{W}_{hh}\textbf{h}_{t-1} + \textbf{b}_{h}\right) \\ & = f\left(\left(\textbf{W}_{hx}, \textbf{W}_{hh}\right) \begin{pmatrix} \textbf{x}_t \\ \textbf{h}_{t-1} \end{pmatrix} + \textbf{b}_{h}\right) \\ & = f\left(\textbf{W} \begin{pmatrix} \textbf{x}_t \\ \textbf{h}_{t-1} \end{pmatrix} + \textbf{b}_{h}\right) \end{align*} \]
注意到在計算時,每一 time step 中使用的引數 \(\textbf{W}, \; \textbf{b}\) 是一樣的,也就是說每個步驟的引數都是共享的,這是RNN的重要特點。
和普通的全連線層相比,RNN 除了輸入 \(\textbf{x}_t\) 外,還有輸入隱藏層上一節點 \(\mathbf{h}_{t-1}\) ,RNN 每一層的輸出就是這兩個輸入用矩陣 \(\textbf{W}_{hx}\) , \(\textbf{W}_{hh}\) 和啟用函式進行組合的結果。從 \((2a)\) 式可以看出 \(\textbf{x}_t\) 和 \(\mathbf{h}_{t-1}\) 都是與 \(\textbf{h}_h\) 全連線的,下圖形象展示了各個時間節點 RNN 隱藏層記憶的變化。隨著時間流逝,最初的藍色結點保留地越來越少,這意味著RNN對於長時記憶的困難。

Vanishing & Exploding Gradient Problems
RNN 對於長時記憶的困難主要來源於梯度爆炸 / 消失問題,下面進行說明。RNN 中 Loss 的計算圖示例:

總的 Loss 是每個 time step 的加和 : \(\mathcal{\large{L}} (\hat{\textbf{y}}, \textbf{y}) = \sum_{t = 1}^{T} \mathcal{ \large{L} }(\hat{\textbf{y}_t}, \textbf{y}_{t})\)
由 backpropagation through time (BPTT) 演算法,引數的梯度為:
\[ \frac{\partial \boldsymbol{\mathcal{L}}}{\partial \textbf{W}} = \sum_{t=1}^{T} \frac{\partial \boldsymbol{\mathcal{L}}_{t}}{\partial \textbf{W}} = \sum_{t=1}^{T} \frac{\partial \boldsymbol{\mathcal{L}}_t}{\partial \textbf{y}_{t}} \frac{\partial \textbf{y}_{t}}{\partial \textbf{h}_{t}} \overbrace{\frac{\partial \textbf{h}_{t}}{\partial \textbf{h}_{k}}}^{ \bigstar } \frac{\partial \textbf{h}_{k}}{\partial \textbf{W}} \]
其中 \(\frac{\partial \textbf{h}_{t}}{\partial \textbf{h}_{k}}\) 包含一系列 \(\text{Jacobian}\) 矩陣,
\[ \frac{\partial \textbf{h}_{t}}{\partial \textbf{h}_{k}} = \frac{\partial \textbf{h}_{t}}{\partial \textbf{h}_{t-1}} \frac{\partial \textbf{h}_{t-1}}{\partial \textbf{h}_{t-2}} \cdots \frac{\partial \textbf{h}_{k+1}}{\partial \textbf{h}_{k}} = \prod_{i=k+1}^{t} \frac{\partial \textbf{h}_{i}}{\partial \textbf{h}_{i-1}} \]
由於 RNN 中每個 time step 都是用相同的 \(\textbf{W}\) ,所以由 \((2a)\) 式可得:
\[ \prod_{i=k+1}^{t} \frac{\partial \textbf{h}_{i}}{\partial \textbf{h}_{i-1}} = \prod_{i=k+1}^{t} \textbf{W}^\top \text{diag} \left[ f'\left(\textbf{h}_{i-1}\right) \right] \]
由於 \(\textbf{W}_{hh} \in \mathbb{R}^{h \times h}\) 為方陣,對其進行特徵值分解:
\[ \mathbf{W} = \mathbf{V} \, \text{diag}(\boldsymbol{\lambda}) \, \mathbf{V}^{-1} \]
由於上式是連乘 \(\text{t}\) 次 \(\mathbf{W}\) :
\[ \mathbf{W}^t = (\mathbf{V} \, \text{diag}(\boldsymbol{\lambda}) \, \mathbf{V}^{-1})^t = \mathbf{V} \, \text{diag}(\boldsymbol{\lambda})^t \, \mathbf{V}^{-1} \]
連乘的次數多了之後,則若最大的特徵值 \(\lambda >1\),會產生梯度爆炸;
\(\lambda < 1\)
,則會產生梯度消失 。不論哪種情況,都會導致模型難以學到有用的模式。
下左圖顯示一個 time step 中 tanh 函式的計算結果,右圖顯示整個神經網路的計算結果,可以清楚地看到哪個區域最容易產生梯度爆炸/消失問題。

梯度爆炸的解決辦法:
(1) Truncated Backpropagation through time :每次只 BP 固定的 time step 數,類似於 mini-batch SGD。缺點是喪失了長距離記憶的能力。

(2) Clipping Gradients : 當梯度超過一定的 threshold 後,就進行 element-wise 的裁剪,該方法的缺點是又引入了一個新的引數 threshold。同時該方法也可視為一種基於瞬時梯度大小來自適應 learning rate 的方法:
\[ \text{if} \quad \lVert \textbf{g} \rVert \ge \text{threshold} \\[1ex] \textbf{g} \leftarrow \frac{\text{threshold}}{\lVert \textbf{g} \rVert} \textbf{g} \]

梯度消失的解決辦法
(1) 使用 LSTM、GRU等升級版 RNN,使用各種 gates 控制資訊的流通。
(2) 在這篇論文 ( https://arxiv.org/pdf/1602.06662.pdf ) 中提出將權重矩陣 \(\textbf{W}\) 初始化為正交矩陣。正交矩陣有如下性質: \(A^T A =A A^T = I, \; A^T = A^{-1}\) , 正交矩陣的特徵值的絕對值為 \(\text{1}\) 。證明如下, 對矩陣 \(A\) 有:
\[ \begin{align*} & A \mathbf{v} = \lambda \mathbf{v} \\[1ex] ||A \mathbf{v}||^2& = (A \mathbf{v})^\text{T} (A \mathbf{v}) \\ &= \mathbf{v}^\text{T}A ^{\text{T}}A \mathbf{v} \\ & = \mathbf{v}^{\text{T}}\mathbf{v} \\ & = ||\mathbf{v}||^2 \\ & = |\lambda|^2 ||\mathbf{v}||^2 \end{align*} \]
由於 \(\mathbf{v}\) 為特徵向量, \(\mathbf{v} \neq 0\) ,所以 \(|\lambda| = 1\),這樣連乘之後
\(\lambda^t\)
不會出現越來越小的情況。
(3) 反轉輸入序列。像在機器翻譯中使用 seq2seq 模型,若使用正常序列輸入,則輸入序列的第一個詞和輸出序列的第一個詞相距較遠,難以學到長期依賴。將輸入序列反向後,輸入序列的第一個詞就會和輸出序列的第一個詞非常接近,二者的相互關係也就比較容易學習了。這樣模型可以先學前幾個詞的短期依賴,再學後面詞的長期依賴關係。見下圖正常輸入順序是 \(|\text{ABC}|\) ,反向是 \(|\text{CBA}|\) ,則 \(\text{A}\) 與第一個輸出詞 \(\text{W}\) 接近:

LSTM
雖然 Vanilla RNN 理論上可以建立長時間間隔狀態之間的依賴關係,但由於梯度爆炸或消失問題,實際上只能學到短期依賴關係。為了學到長期依賴關係,LSTM 中引入了門控機制來控制資訊的累計速度,包括有選擇地加入新的資訊,並有選擇地遺忘之前累計的資訊,整個 LSTM 單元結構如下圖所示:

\[ \begin{align} \text{input gate}&: \quad \textbf{i}_t = \sigma(\textbf{W}_i\textbf{x}_t + \textbf{U}_i\textbf{h}_{t-1} + \textbf{b}_i)\tag{1} \\ \text{forget gate}&: \quad \textbf{f}_t = \sigma(\textbf{W}_f\textbf{x}_t + \textbf{U}_f\textbf{h}_{t-1} + \textbf{b}_f) \tag{2}\\ \text{output gate}&: \quad \textbf{o}_t = \sigma(\textbf{W}_o\textbf{x}_t + \textbf{U}_o\textbf{h}_{t-1} + \textbf{b}_o) \tag{3}\\ \text{new memory cell}&: \quad \tilde{\textbf{c}}_t = \text{tanh}(\textbf{W}_c\textbf{x}_t + \textbf{U}_c\textbf{h}_{t-1} + \textbf{b}_c) \tag{4}\\ \text{final memory cell}& : \quad \textbf{c}_t = \textbf{f}_t \odot \textbf{c}_{t-1} + \textbf{i}_t \odot \tilde{\textbf{c}}_t \tag{5}\\ \text{final hidden state} &: \quad \textbf{h}_t= \textbf{o}_t \odot \text{tanh}(\textbf{c}_t) \tag{6} \end{align} \]
式 $(1) \sim (4) $ 的輸入都一樣,因而可以合併:
\[ \begin{pmatrix} \textbf{i}_t \\ \textbf{f}_{t} \\ \textbf{o}_t \\ \tilde{\textbf{c}}_t \end{pmatrix} = \begin{pmatrix} \sigma \\ \sigma \\ \sigma \\ \text{tanh} \end{pmatrix} \left(\textbf{W} \begin{bmatrix} \textbf{x}_t \\ \textbf{h}_{t-1} \end{bmatrix} + \textbf{b} \right) \]
$\tilde{\textbf{c}}_t $ 為時刻 t 的候選狀態, \(\textbf{i}_t\) 控制 \(\tilde{\textbf{c}}_t\) 中有多少新資訊需要儲存, \(\textbf{f}_{t}\) 控制上一時刻的內部狀態 \(\textbf{c}_{t-1}\) 需要遺忘多少資訊, \(\textbf{o}_t\) 控制當前時刻的內部狀態 \(\textbf{c}_t\) 有多少資訊需要輸出給外部狀態 \(\textbf{h}_t\) 。
下表顯示 forget gate 和 input gate 的關係,可以看出 forget gate 其實更應該被稱為 “remember gate”, 因為其開啟時之前的記憶資訊 \(\textbf{c}_{t-1}\) 才會被保留,關閉時則會遺忘所有:
forget gate | input gate | result |
---|---|---|
1 | 0 | 保留上一時刻的狀態 \(\textbf{c}_{t-1}\) |
1 | 1 | 保留上一時刻 \(\textbf{c}_{t-1}\) 和新增新資訊 \(\tilde{\textbf{c}}_t\) |
0 | 1 | 清空歷史資訊,引入新資訊 \(\tilde{\textbf{c}}_t\) |
0 | 0 | 清空所有新舊資訊 |
對比 Vanilla RNN,可以發現在時刻 t,Vanilla RNN 通過 \(\textbf{h}_t\) 來儲存和傳遞資訊,上文已分析瞭如果時間間隔較大容易產生梯度消失的問題。 LSTM 則通過記憶單元 \(\textbf{c}_t\) 來傳遞資訊,通過 \(\textbf{i}_t\) 和 \(\textbf{f}_{t}\) 的調控, \(\textbf{c}_t\) 可以在 t 時刻捕捉到某個關鍵資訊,並有能力將此關鍵資訊儲存一定的時間間隔。
原始的 LSTM 中是沒有 forget gate 的,即:
\[ \textbf{c}_t = \textbf{c}_{t-1} + \textbf{i}_t \odot \tilde{\textbf{c}}_t \]
這樣 \(\frac{\partial \textbf{c}_t}{\partial \textbf{c}_{t-1}}\) 恆為 \(\text{1}\) 。但是這樣 \(\textbf{c}_t\) 會不斷增大,容易飽和從而降低模型效能。後來引入了 forget gate ,則梯度變為 \(\textbf{f}_{t}\) ,事實上連乘多個 \(\textbf{f}_{t} \in (0,1)\)同樣會導致梯度消失,但是 LSTM 的一個初始化技巧就是將 forget gate 的 bias 置為正數(例如 1 或者 5,如 tensorflow 中的預設值就是
\(1.0\)
),這樣一來模型剛開始訓練時 forget gate 的值都接近 1,不會發生梯度消失 (反之若 forget gate 的初始值過小則意味著前一時刻的大部分資訊都丟失了,這樣很難捕捉到長距離依賴關係)。 隨著訓練過程的進行,forget gate 就不再恆為 1 了。不過,一個訓好的模型裡各個 gate 值往往不是在 [0, 1] 這個區間裡,而是要麼 0 要麼 1,很少有類似 0.5 這樣的中間值,其實相當於一個二元的開關。假如在某個序列裡,forget gate 全是 1,那麼梯度不會消失;某一個 forget gate 是 0,模型選擇遺忘上一時刻的資訊。
LSTM 的一種變體增加 peephole 連線,這樣三個 gate 不僅依賴於 \(\textbf{x}_t\) 和 \(\textbf{h}_{t-1}\) ,也依賴於記憶單元 \(\textbf{c}\) :
\[ \begin{align*} \text{input gate}&: \quad \textbf{i}_t = \sigma(\textbf{W}_i\textbf{x}_t + \textbf{U}_i\textbf{h}_{t-1} + \textbf{V}_i\textbf{c}_{t-1} + \textbf{b}_i) \\ \text{forget gate}&: \quad \textbf{f}_t = \sigma(\textbf{W}_f\textbf{x}_t + \textbf{U}_f\textbf{h}_{t-1} + \textbf{V}_f\textbf{c}_{t-1} +\textbf{b}_f) \\ \text{output gate}&: \quad \textbf{o}_t = \sigma(\textbf{W}_o\textbf{x}_t + \textbf{U}_o\textbf{h}_{t-1} + \textbf{V}_o\textbf{c}_{t} +\textbf{b}_o) \\ \end{align*} \]
注意 input gate 和 forget gate 連線的是 \(\textbf{c}_{t-1}\) ,而 output gate 連線的是 \(\textbf{c}_t\) 。下圖來自 《LSTM: A Search Space Odyssey》 ,標註了 peephole 連線的樣貌。

GRU
相比於 Vanilla RNN (每個 time step 有一個輸入 \(\textbf{x}_t\) ),從上面的 \((1) \sim (4)\) 式可以看出 一個 LSTM 單元有四個輸入 (如下圖,不考慮 peephole) ,因而引數是 Vanilla RNN 的四倍,帶來的結果是訓練起來很慢,因而在2014年 Cho 等人提出了 GRU ,對 LSTM 進行了簡化,在不影響效果的前提下加快了訓練速度。

\(\large\scr{LSTM:}\)
\[ \normalsize \begin{align} \text{input gate}&: \quad \textbf{i}_t = \sigma(\textbf{W}_i\textbf{x}_t + \textbf{U}_i\textbf{h}_{t-1} + \textbf{b}_i)\tag{1} \\ \text{forget gate}&: \quad \textbf{f}_t = \sigma(\textbf{W}_f\textbf{x}_t + \textbf{U}_f\textbf{h}_{t-1} + \textbf{b}_f) \tag{2}\\ \text{output gate}&: \quad \textbf{o}_t = \sigma(\textbf{W}_o\textbf{x}_t + \textbf{U}_o\textbf{h}_{t-1} + \textbf{b}_o) \tag{3}\\ \text{new memory cell}&: \quad \tilde{\textbf{c}}_t = \text{tanh}(\textbf{W}_c\textbf{x}_t + \textbf{U}_c\textbf{h}_{t-1} + \textbf{b}_c) \tag{4}\\ \text{final memory cell}& : \quad \textbf{c}_t = \textbf{f}_t \odot \textbf{c}_{t-1} + \textbf{i}_t \odot \tilde{\textbf{c}}_t \tag{5}\\ \text{final hidden state} &: \quad \textbf{h}_t= \textbf{o}_t \odot \text{tanh}(\textbf{c}_t) \tag{6} \end{align} \]
在式 \((5)\) 中 forget gate 和 input gate 是互補關係,因而比較冗餘,GRU 將其合併為一個 update gate。同時 GRU 也不引入額外的記憶單元 (LSTM 中的 \(\textbf{c}\) ) ,而是直接在當前狀態 \(\textbf{h}_t\)和歷史狀態
\(\textbf{h}_{t-1}\)
之間建立線性依賴關係。

\(\large\scr{GRU:}\)
\[ \normalsize \begin{align} \text{reset gate}&: \quad \textbf{r}_t = \sigma(\textbf{W}_r\textbf{x}_t + \textbf{U}_r\textbf{h}_{t-1} + \textbf{b}_r)\tag{7} \\ \text{update gate}&: \quad \textbf{z}_t = \sigma(\textbf{W}_z\textbf{x}_t + \textbf{U}_z\textbf{h}_{t-1} + \textbf{b}_z)\tag{8} \\ \text{new memory cell}&: \quad \tilde{\textbf{h}}_t = \text{tanh}(\textbf{W}_h\textbf{x}_t + \textbf{r}_t \odot (\textbf{U}_h\textbf{h}_{t-1}) + \textbf{b}_h) \tag{9}\\ \text{final hidden state}&: \quad \textbf{h}_t = \textbf{z}_t \odot \textbf{h}_{t-1} + (1 - \textbf{z}_t) \odot \tilde{\textbf{h}}_t \tag{10} \end{align} \]
$ \tilde{\textbf{h}}_t $ 為時刻 t 的候選狀態, \(\textbf{r}_t\) 控制 $ \tilde{\textbf{h}}_t $ 有多少依賴於上一時刻的狀態 \(\textbf{h}_{t-1}\) ,如果 \(\textbf{r}_t = 1\) ,則式 \((9)\) 與 Vanilla RNN 一致,對於短依賴的 GRU 單元,reset gate 通常會更新頻繁。 \(\textbf{z}_t\) 控制當前的內部狀態 \(\textbf{h}_t\) 中有多少來自於上一時刻的 \(\textbf{h}_{t-1}\) 。如果 \(\textbf{z}_t = 1\),則會每步都傳遞同樣的資訊,和當前輸入
\(\textbf{x}_t\)
無關。
另一方面看, \(\textbf{r}_t\) 與 LSTM 中的 \(\textbf{o}_t\) 角色有些類似,因為將上面的 \((6)\) 式代入 \((4)\) 式可以得到:
\[ \begin{align*} \tilde{\textbf{c}}_t &= \text{tanh}(\textbf{W}_c\textbf{x}_t + \textbf{U}_c\textbf{h}_{t-1} + \textbf{b}_c) \\ \textbf{h}_t &= \textbf{o}_t \odot \text{tanh}(\textbf{c}_t) \end{align*} \quad \Longrightarrow \quad \tilde{\textbf{c}}_t = \text{tanh}(\textbf{W}_c\textbf{x}_t + \textbf{U}_c \left(\textbf{o}_{t-1} \odot \text{tanh}(\textbf{c}_{t-1})\right) + \textbf{b}_c) \]
最後是 cs224n 中提出的 RNN 訓練 tips:

/