1. 程式人生 > >【轉】人人都能看懂的LSTM

【轉】人人都能看懂的LSTM

轉自:https://zhuanlan.zhihu.com/p/32085405

這是在看了臺大李巨集毅教授的深度學習視訊之後的一點總結和感想。看完介紹的第一部分RNN尤其LSTM的介紹之後,整個人醍醐灌頂。本篇部落格就是對視訊的一些記錄加上了一些個人的思考。

0. 從RNN說起

迴圈神經網路(Recurrent Neural Network,RNN)是一種用於處理序列資料的神經網路。相比一般的神經網路來說,他能夠處理序列變化的資料。比如某個單詞的意思會因為上文提到的內容不同而有不同的含義,RNN就能夠很好地解決這類問題。

1. 普通RNN

先簡單介紹一下一般的RNN。

其主要形式如下圖所示(圖片均來自臺大李巨集毅教授的PPT):

這裡:

x 為當前狀態下資料的輸入, h 表示接收到的上一個節點的輸入。

y 為當前節點狀態下的輸出,而 h' 為傳遞到下一個節點的輸出。

通過上圖的公式可以看到,輸出 h'xh 的值都相關。

y 則常常使用 h' 投入到一個線性層(主要是進行維度對映)然後使用softmax進行分類得到需要的資料。

對這裡的y如何通過 h' 計算得到往往看具體模型的使用方式。

通過序列形式的輸入,我們能夠得到如下形式的RNN。

2. LSTM

2.1 什麼是LSTM

長短期記憶(Long short-term memory, LSTM)是一種特殊的RNN,主要是為了解決長序列訓練過程中的梯度消失和梯度爆炸問題。簡單來說,就是相比普通的RNN,LSTM能夠在更長的序列中有更好的表現。

LSTM結構(圖右)和普通RNN的主要輸入輸出區別如下所示。

相比RNN只有一個傳遞狀態 h^t ,LSTM有兩個傳輸狀態,一個 c^t (cell state),和一個 h^t (hidden state)。(Tips:RNN中的 h^t 對於LSTM中的 c^t

其中對於傳遞下去的 c^t 改變得很慢,通常輸出的 c^t 是上一個狀態傳過來的 c^{t-1} 加上一些數值。

h^t 則在不同節點下往往會有很大的區別。

2.2 深入LSTM結構

下面具體對LSTM的內部結構來進行剖析。

首先使用LSTM的當前輸入 x^t 和上一個狀態傳遞下來的 h^{t-1} 拼接訓練得到四個狀態。

其中, z^fz^iz^o 是由拼接向量乘以權重矩陣之後,再通過一個 sigmoid 啟用函式轉換成0到1之間的數值,來作為一種門控狀態。而 z

則是將結果通過一個 tanh 啟用函式將轉換成-1到1之間的值(這裡使用 tanh 是因為這裡是將其做為輸入資料,而不是門控訊號)。

下面開始進一步介紹這四個狀態在LSTM內部的使用。(敲黑板)

\odot 是Hadamard Product,也就是操作矩陣中對應的元素相乘,因此要求兩個相乘矩陣是同型的。 \oplus 則代表進行矩陣加法。

LSTM內部主要有三個階段:

1. 忘記階段。這個階段主要是對上一個節點傳進來的輸入進行選擇性忘記。簡單來說就是會 “忘記不重要的,記住重要的”。

具體來說是通過計算得到的 z^f (f表示forget)來作為忘記門控,來控制上一個狀態的 c^{t-1} 哪些需要留哪些需要忘。

2. 選擇記憶階段。這個階段將這個階段的輸入有選擇性地進行“記憶”。主要是會對輸入 x^t 進行選擇記憶。哪些重要則著重記錄下來,哪些不重要,則少記一些。當前的輸入內容由前面計算得到的 z 表示。而選擇的門控訊號則是由 z^i (i代表information)來進行控制。

將上面兩步得到的結果相加,即可得到傳輸給下一個狀態的 c^t 。也就是上圖中的第一個公式。

3. 輸出階段。這個階段將決定哪些將會被當成當前狀態的輸出。主要是通過 z^o 來進行控制的。並且還對上一階段得到的 c^o 進行了放縮(通過一個tanh啟用函式進行變化)。

與普通RNN類似,輸出 y^t 往往最終也是通過 h^t 變化得到。

3. 總結

以上,就是LSTM的內部結構。通過門控狀態來控制傳輸狀態,記住需要長時間記憶的,忘記不重要的資訊;而不像普通的RNN那樣只能夠“呆萌”地僅有一種記憶疊加方式。對很多需要“長期記憶”的任務來說,尤其好用。

但也因為引入了很多內容,導致引數變多,也使得訓練難度加大了很多。因此很多時候我們往往會使用效果和LSTM相當但引數更少的GRU來構建大訓練量的模型。

對於GRU我在下面的文章中進行了相關介紹,有興趣的同學可以進去看看。