1. 程式人生 > >從迴圈神經網路(RNN)到LSTM網路

從迴圈神經網路(RNN)到LSTM網路

  通常,資料的存在形式有語音、文字、影象、視訊等。因為我的研究方向主要是影象識別,所以很少用有“記憶性”的深度網路。懷著對迴圈神經網路的興趣,在看懂了有關它的理論後,我又看了Github上提供的tensorflow實現,覺得收穫很大,故在這裡把我的理解記錄下來,也希望對大家能有所幫助。本文將主要介紹RNN相關的理論,並引出LSTM網路結構(關於對tensorflow實現細節的理解,有時間的話,在下一篇博文中做介紹)。

迴圈神經網路

  RNN,也稱作迴圈神經網路(還有一種深度網路,稱作遞迴神經網路,讀者要區別對待)。因為這種網路有“記憶性”,所以主要是應用在自然語言處理(NLP)和語音領域。與傳統的Neural network不同,RNN能利用上”序列資訊”。從理論上講,它可以利用任意長序列的資訊,但由於該網路結構存在“消失梯度”問題,所以在實際應用中,它只能回溯利用與它接近的time steps上的資訊。

1. 網路結構

  常見的神經網路結構有卷積網路、迴圈網路和遞迴網路,棧式自編碼器和玻爾茲曼機也可以看做是特殊的卷積網路,區別是它們的損失函式定義成均方誤差函式。遞迴網路類似於資料結構中的樹形結構,且其每層之間會有共享引數。而最為常用的迴圈神經網路,它的每層的結構相同,且每層之間引數完全共享。RNN的縮圖和展開圖如下,

  儘管RNN的網路結構看上去與常見的前饋網路不同,但是它的展開圖中資訊流向也是確定的,沒有環流,所以也屬於forward network,故也可以使用反向傳播(back propagation)演算法來求解引數的梯度。另外,在RNN網路中,可以有單輸入、多輸入、單輸出、多輸出,視具體任務而定。

2. 損失函式

  在輸出層為二分類或者softmax多分類的深度網路中,代價函式通常選擇交叉熵(cross entropy)損失函式,前面的博文中證明過,在分類問題中,交叉熵函式的本質就是似然損失函式。儘管RNN的網路結構與分類網路不同,但是損失函式也是有相似之處的。
  假設我們採用RNN網路構建“語言模型”,“語言模型”其實就是看“一句話說出來是不是順口”,可以應用在機器翻譯、語音識別領域,從若干候選結果中挑一個更加靠譜的結果。通常每個sentence長度不一樣,每一個word作為一個訓練樣例,一個sentence作為一個Minibatch,記sentence的長度為T。為了更好地理解語言模型中損失函式的定義形式,這裡做一些推導,根據全概率公式,則一句話是“自然化的語句”的概率為

p(w1,w2,...,wT)=p(w1)×p(w2|w1)×...×p(wT|w1,w2,...,wT1)   所以語言模型的目標就是最大化 P(w1,w2,...,wT) 。而損失函式通常為最小化問題,所以可以定義 Loss(w1,w2,...,wT|θ)=logP(w1,w2,...,wT|θ)   那麼公式展開可得 Loss(w1,w2,...,wT|θ)=(logp(w1)+logp(w2|w1)+...+logp(wT|w1,w2,...,wT1))   展開式中的每一項為一個softmax分類模型,類別數為所採用的詞庫大小(vocabulary size),相信大家此刻應該就明白了,為什麼使用RNN網路解決語言模型時,輸入序列和輸出序列錯了一個位置了。

3. 梯度求解

  在訓練任何深度網路模型時,求解損失函式關於模型引數的梯度,應該算是最為核心的一步了。在RNN模型訓練時,採用的是BPTT(back propagation through time)演算法,這個演算法其實實質上就是樸素的BP演算法,也是採用的“鏈式法則”求解引數梯度,唯一的不同在於每一個time step上引數共享。從數學的角度來講,BP演算法就是一個單變數求導過程,而BPTT演算法就是一個複合函式求導過程。接下來以損失函式展開式中的第3項為例,推導其關於網路引數U、W、V的梯度表示式(總損失的梯度則是各項相加的過程而已)。
  為了簡化符號表示,記 E3=logp(w3|w1,w2) ,則根據RNN的展開圖可得,

s3=tanh(U×x3+W×s2)  s2=tanh(U×x2+W×s1)s1=tanh(U×x1+W×s0)  s0=tanh(U×x0+W×s1)(1)

  所以,

s3W=s3W1+s3s2×s2Ws2W=s2W1+s2s1×s1Ws1W=s1W0+s1s0×s0Ws0W=s0W1(2)

  說明一下,為了更好地體現複合函式求導的思想,公式(2)中引入了變數 W1 ,可以把 W1 看作關於W的函式,即 W1=W 。另外,因為 s1 表示RNN網路的初始狀態,為一個常數向量,所以公式(2)中第4個表示式展開後只有一項。所以由公式(2)可得,

s3W=s3W1+s3s2×s2W1+s3s2×s2s1×s1W1+s3s2×s2s1×s1s0×s0W1(3)

  簡化得下式,

s3W=s3W1+s3s2×s2W1+s3S1×s1W1+s3s0×s0W1(4)

  繼續簡化得下式,

s3W=i=03s3si×siW1(5)

3.1 E3 關於引數V的偏導數

  記t=3時刻的softmax神經元的輸入為 a3 ,輸出為 y3 ,網路的真實標籤為 y(1)3 。根據函式求導的“鏈式法則”,所以有下式成立,

E3V=E3a3×a3V=(y(1)3y3)s3(6)

3.2 E3 關於引數W的偏導數

  關於引數W的偏導數,就要使用到上面關於複合函式的推導過程了,記 zi 為t=i時刻隱藏層神經元的輸入,則具體的表示式簡化過程如下,

E3W=E3s3×s3W=E3a3×a3s3×s3W=k=03<