1. 程式人生 > >RNN 與 LSTM 的原理詳解

RNN 與 LSTM 的原理詳解

本文主要講解了 RNN 和 LSTM 的結構、前饋、反饋的原理,參考了https://www.jianshu.com/p/f3bde26febed/https://www.jianshu.com/p/9dc9f41f0b29https://blog.csdn.net/zhaojc1995/article/details/80572098等文章,並糾正了公式的錯誤、更新了無效的論文連結。

RNN(Recurrent Neural Network)是一類用於處理序列資料的神經網路。

什麼是序列呢?序列是一串有順序的資料,比如某一條資料為 [

x 1 , x 2 , x 3
, x 4 ] [x_1, x_2, x_3, x_4] ,其中每個元素可以是一個字元、一個單詞、一個向量,甚至是一個聲音。比如:

  • 自然語言處理問題。 x
    1 x_1
    可以看做是第一個單詞, x 2 x_2 可以看做是第二個單詞,依次類推。
  • 語音處理。此時,每個元素是每幀的聲音訊號。
  • 時間序列問題。例如每天的股票價格等。

RNN 處理這種序列資料,在結構上具有天然的優勢(相對於普通的神經網路而言,如全連線、CNN等)。

RNN 的結構

我們從基礎的神經網路中知道,神經網路包含輸入層、隱層、輸出層,通過啟用函式控制輸出,層與層之間通過權值連線。啟用函式是事先確定好的,那麼神經網路模型通過訓練“學“到的東西就蘊含在“權值“中。單層的神經網路如圖:
在這裡插入圖片描述

其中,x為輸入,W為權重矩陣,b為偏置,f為啟用函式,如 sigmoid 等,y 為輸出。這樣,就建立了輸入與輸出之間的關聯。

基礎的神經網路只在層與層之間建立了權連線,RNN最大的不同之處就是在層之間的神經元之間也建立的權連線。如圖。

RNN的結構及展開

這是一個標準的RNN結構圖,圖中每個箭頭代表做一次變換,也就是說箭頭連線帶有權值。左側是摺疊起來的樣子,右側是展開的樣子,左側中h旁邊的箭頭代表此結構中的“迴圈“體現在隱層。

在展開結構中我們可以觀察到,在標準的RNN結構中,隱層的神經元之間也是帶有權值的。也就是說,隨著序列的不斷推進,前面的隱層將會影響後面的隱層。圖中O代表輸出,y代表樣本給出的確定值,L代表損失函式,我們可以看到,“損失“也是隨著序列的推薦而不斷積累的。

除上述特點之外,標準RNN的還有以下特點:

  • 權值共享,圖中的W全是相同的,U和V也一樣。
  • 每一個輸入值都只與它本身的那條路線建立權連線,不會和別的神經元連線。

以上是RNN的標準結構,屬於多輸入多輸出,並且輸入與輸出的個數是相同的,即每次輸入都會對應一個輸出。

然而在實際中這一種結構並不能解決所有問題,常見的變種有:

1、多輸入單輸出

有的時候,我們要處理的問題輸入是一個序列,輸出是一個單獨的值而不是序列,應該怎樣建模呢?實際上,我們只在最後一個h上進行輸出變換就可以了:
在這裡插入圖片描述

這種結構通常用來處理序列分類問題。如輸入一段文字判別它所屬的類別,輸入一個句子判斷其情感傾向,輸入一段視訊並判斷它的類別等等。

2、單輸入多輸出

輸入不是序列而輸出為序列的情況怎麼處理?我們可以只在序列開始進行輸入計算,其餘只需要隱層狀態進行傳遞。
單輸入多輸出(只輸入第一次迴圈)

還有一種結構是把輸入資訊X作為每個階段的輸入:
單輸入多輸出(每次迴圈都輸入同一個值)

這種單輸入多輸出的結構可以處理的問題有:

  • 從影象生成文字(image caption),此時輸入的X就是影象的特徵,而輸出的y序列就是一段句子
  • 從類別生成語音或音樂等

3、多輸入多輸出(輸入和輸出個數不同)

實際中,還有另外一種多輸入多輸出的結構,其輸入與輸出並不是一一對應的,如圖:
多輸入多輸出(輸入和輸出個數不同)

這種結構又叫Encoder-Decoder模型,也可以稱之為Seq2Seq模型。

Encoder-Decoder結構先將輸入資料編碼成一個上下文向量c。得到c有多種方式,最簡單的方法就是把Encoder的最後一個隱狀態賦值給c,還可以對最後的隱狀態做一個變換得到c,也可以對所有的隱狀態做變換。

拿到c之後,就用另一個RNN網路對其進行解碼,這部分RNN網路被稱為Decoder。具體做法就是將c當做之前的初始狀態h0輸入到Decoder中。

還有另外一種 Decoder,是將c當做每一步的輸入:
多輸入多輸出(輸入和輸出個數不同)

由於這種Encoder-Decoder結構不限制輸入和輸出的序列長度,因此應用的範圍非常廣泛,比如:

  • 機器翻譯。Encoder-Decoder的最經典應用,事實上這一結構就是在機器翻譯領域最先提出的
  • 文字摘要。輸入是一段文字序列,輸出是這段文字序列的摘要序列。
  • 閱讀理解。將輸入的文章和問題分別編碼,再對其進行解碼得到問題的答案。
  • 語音識別。輸入是語音訊號序列,輸出是文字序列。

RNN的前向輸出流程

下面對多輸入多輸出(一一對應)的經典結構作分析:
RNN的結構及展開
其中,x是輸入,h是隱層單元,o為輸出,L為損失函式,y為訓練集的標籤。這些元素右上角帶的t代表t時刻的狀態,其中需要注意的是,隱層單元h在t時刻的表現不僅由此刻的輸入決定,還受t時刻之前時刻的影響。V、W、U是權值,同一型別的權連線權值相同。

前向傳播演算法其實非常簡單,對於t時刻,隱層單元為:
h ( t ) = f ( U x ( t ) + W h ( t 1 ) + b ) h^{(t)}=f(Ux^{(t)}+Wh^{(t-1)}+b)

其中,f 為啟用函式,如 sigmoid、tanh 等,b 為偏置。

t時刻的輸出為:
o ( t ) = V h ( t ) + c o^{(t)}=Vh^{(t)}+c

RNN的訓練方法

BPTT(back-propagation through time)演算法是常用的訓練RNN的方法,其實本質還是BP演算法,只不過RNN處理時間序列資料,所以要基於時間反向傳播,故叫隨時間反向傳播。BPTT的中心思想和BP演算法相同,沿著需要優化的引數的負梯度方向不斷尋找更優的點直至收斂。綜上所述,BPTT演算法本質還是BP演算法,BP演算法本質還是梯度下降法,那麼求各個引數的梯度便成了此演算法的核心。

RNN的結構及展開

再次拿出這個結構圖觀察,需要尋優的引數有三個,分別是U、V、W。與BP演算法不同的是,其中W和U兩個引數的尋優過程需要追溯之前的歷史資料,引數V相對簡單隻需關注目前,那麼我們就來先求解引數V的偏導數。
L ( t ) V = L ( t ) o ( t ) o ( t ) V \frac{\partial L^{(t)}}{\partial V}=\frac{\partial L^{(t)}}{\partial o^{(t)}} \frac{\partial o^{(t)}}{\partial V}
RNN的損失也是會隨著時間累加的,所以需要求出所有時刻的偏導然後求和:
L = t = 1 n L ( t ) L=\sum_{t=1}^n L^{(t)}
L V = t = 1 n L ( t ) o ( t ) o ( t ) V \frac{\partial L}{\partial V}=\sum_{t=1}^n\frac{\partial L^{(t)}}{\partial o^{(t)}} \frac{\partial o^{(t)}}{\partial V}

W和U的偏導的求解由於需要涉及到歷史資料,其偏導求起來相對複雜,我們先假設只有三個時刻,那麼在第三個時刻 L對W的偏導數為:
L ( 3 ) W = L ( 3 ) o ( 3 ) o ( 3 ) h ( 3 ) h ( 3 ) W + L ( 3 ) o ( 3 ) o ( 3 ) h ( 3 ) h ( 3 ) h ( 2 ) h ( 2 ) W + L ( 3 ) o ( 3 ) o ( 3 ) h ( 3 ) h ( 3 ) h ( 2 ) h ( 2 ) h ( 1 ) h ( 1 ) W \frac{\partial L^{(3)}}{\partial W}=\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial W}+\frac{\partial L^{(3)}}{\partial o^{(3)}}\frac{\partial o^{(3)}}{\partial h^{(3)}}\frac{\partial h^{(3)}}{\partial h^{(2)}}\frac{\partial h^{(2)}}{\partial h^{(1)}}\frac{\partial h^{(1)}}{\partial W}