1. 程式人生 > >RNN與LSTM

RNN與LSTM

前言

迴圈神經網路,迴圈神經網路與卷積神經網路有很大的不同。就是有“記憶暫存”功能,可以把過去的輸入內容產生的影響量化後與當前時間輸入一起反應到網路中參與訓練。

RNN理解

個人理解,RNN還是在模仿人類。在模仿人類的學習的過程。根據前言所述,當前輸入對未來可能會有影響。那麼當前的輸出應該也跟之前的輸入和當前的輸入有關係。
舉個生活中的例子:例子背景就是在學校唸書讀小學。在這裡把神經元比做自己,輸入Xt比做當前年紀學習的知識,那麼X就分為X1,X2、X3、X4、X5、X6。(小學有6個年級)。輸出Y比做是一個學年的期末考試。那麼Y也有Y1,Y2、Y3、Y4、Y5、Y6。就假裝我們上三年級,那麼輸入就是X3,要想知道Y3。意思就是說我們學完了三年級要去考試了,那麼考試的結果不應該只與我們三年級學習的知識有關,還應該跟二年級乃至一年級有關係。接著這個故事背景來看一下具體的RNN演算法。

RNN與BPTT

RNN一般是用來作NLP的,因為語言相關的問題往往跟前後輸入有關,比如說輸入法的自動補充。我們簡單輸入幾個字,輸入法會幫我們猜測下一個字可能是什麼。

RNN網路結構

直接看圖
image.png
這裡的W、U、V是權重向量。而且整個RNN中共享這三個變數。X是輸入、Y為輸出,S為記憶神經元,下標都為時間戳。我們很容易看出輸出Yt與前一個記憶神經元和當前輸入有關係。
image.png
其中的 f()是一個啟用函式,一般用tanh啟用函式。這個函式可以把連續函式離散到0~1的值。
仔細觀察,看到Yt只與記憶體有關。因為這個輸出可能是每個詞的概率,所以這裡一般使用sigmax啟用函式(sigmax啟用函式有個特性,輸出的種類概率的和是1)
image.png

誤差傳遞

BPTT演算法跟BP神經網路是一個思路,就是先正向傳遞然後得到預測值與真實值計算誤差然後進行反向傳播進行調整!(細節不做贅述)這裡要明確的是,RNN神經網路是共享三個引數變數(W,V,U)。而且誤差跟時間 有關係。
image.png
簡化一下就是這樣的。其中E就是誤差。
用E2做樣例分析(使用平方損失函式)
image.png

再重複一次
W 是記憶體的權重。
U 是輸入的權重。
V 是輸出的權重。
那麼求 輸入部分的權重偏導。
image.png
這好像並不難,每次的loss跟輸入有關係。

但是我們再看看對W求偏導,簡直是個災難。
image.png

這裡面竟然有個S1 ,因為S1=f(WS0+UX1)這個公式得來的,S1還有W,根據鏈式法則還要對S1求導。這樣看求E2的loss還好,但是要對E1000求導呢。那就。。
需要求出1000前面所有的偏W,這簡直會爆炸的。。
所以,RNN在理論上說得通,在實際中效果肯定不是很理想的。

LSTM演算法

LSTM模型

其中A為記憶體 ,這裡稱為Cell(細胞)。放大一點可以看到內部結構。
image.png
這與電子電路中的電路板有點像,有點像閘電路那種東西。所以也有地方稱這是LSTM 是忘記門,其實這個描述很形象,LSTM其實就是每次傳播的時候就會忘記一些沒有用的東西。
還是拿考試做比方,不是學的每個知識點都要考,我們腦容量有限,可以忘掉一些知識從而學習到更多的知識來應付考試(應試教育都這樣)。構成忘記門最重要的單元就是:
image.png
其中黃色的就是sigmoid啟用函式,sigmoid會產生一個概率值s, 通過概率值確定留下多少。比如s=0.8 那麼輸出留下輸入的0.8 。如果s=0就說這部分輸入不要。如果是s=1代表十分重要,要全部保留。(X 就是乘法,+ 是加法)

LSTM傳輸過程

Step 1

淺顯的理解就是根據當前的輸入,和上一個時序的輸出選擇性忘記一點東西。例子就是,根據初一的期末考試和初二學習的,忘記一點初一的內容。
image.png
其中引數:
ht-1 就是上一次的輸出。
xt 就是當前的輸入
[ht-1,xt] 就是矩陣的合併。
bf 是偏置項
Wf 是權重項
ft 輸出的概率值

Step 2

淺顯的理解就是當前輸入的資訊什麼重要,我們決定留下什麼。例子就是,根據初一的期末考試和初二學習的,選擇性的忘記一點初二的內容。
image.png
引數:
不用解釋了吧。
其中it是一個概率值,Ct是一個矩陣。

Step 3

淺顯的理解就是,把上一個記憶體中剩下的和新資訊剩下的相加形成一個新的記憶體 就是初一的重點+初二的重點 就可以去參加初二的考試了。
image.png

Step 4

淺顯的理解就是要根據當前的狀態在記憶體中選擇要輸出的內容。 要考試了,應該從腦海中拿出什麼呢,還是得看考試問題問什麼吧。
image.png
其中的ht就是輸出,同樣的作為下一層的輸入。

總結

反向傳播比較麻煩,還沒有弄清楚,這裡給出一篇別人理解的部落格。https://www.jianshu.com/p/4e285112b988。不過現在的tensorflow 已經幫忙弄好了,用不是問題了。這只是對LSTM淺顯的理解。路還很長,加油。