1. 程式人生 > >NLP 相關演算法 LSTM 演算法流程

NLP 相關演算法 LSTM 演算法流程

LSTM希望通過改進的RNN內部計算方法來應對普通RNN經常面臨的梯度消失和梯度爆炸。基本思路是通過改變逆向傳播求導時單純的偏導連乘關係,從而避免較小的sigmoid或relu啟用函式偏導連乘現象。
RNN網路unfold以後,將按時間t展開為若干個結構相同的計算單元,每個計算單元在利用當前時間的輸入以外,還需要之前時間的輸出。以下將展示每個計算單元的內部計算流程,假設當前的計算單元對應時間為t。
每個計算單元內由input gateforget gateoutput gate三個“閘門”結構依先後順序構成。在每一個gate內部,相關的輸入都匹配專門的權重矩陣,各個輸入相加後都將匹配專門的bias向量,總體求和後需要通過專門的啟用函式進行處理形成輸出。 設定當期(即t期)輸入為 x

t x_t ,前一期輸出為 o t 1
o_{t-1}

input gate

input gate實際上是類似於一個filter,即用sigmoid啟用函式的啟用值過濾或加權實際的input。實際的input為:
i = t a n

h ( x t W i x + o t 1 W i o + b i ) i=tanh(x_t W_{i}^x+o_{t-1} W_{i}^o+b_{i})
sigmoid啟用函式filter為:
I G = s i g m o i d ( x t W I G x + o t 1 W I G o + b I G ) IG=sigmoid(x_t W_{IG}^x+o_{t-1} W_{IG}^o+b_{IG})
input gate層的最終輸出就是 I I I G IG 的點乘,即元素層面的對應相乘。
I o u t = i I G I_{out}=i \circ IG

inner state s t s_t

LSTM較於普通RNN網路增加了一個內部狀態量 s t s_t . 記憶的控制就是通過forget gate對於 s t 1 s_{t-1} 的過濾而發揮作用。

forget gate

與input gate相同,forget gate也是一個sigmoid啟用函式啟用值形成的filter,用於對上一期的狀態量 s t 1 s_{t-1} 進行過濾。
F G = s i g m o i d ( x t W F G x + o t 1 W F G o + b F G ) FG=sigmoid(x_t W_{FG}^x+o_{t-1} W_{FG}^o+b_{FG})
當期的狀態量 s t s_t 就是input gate層的輸出值與IG過濾後的上一期狀態量的簡單相加的結果。注意這裡的操作僅為簡單的相加,並沒有加入權重,不存在相乘,也沒有使用新的啟用函式,這一步驟是消除RNN反向傳播網路梯度消失或梯度爆炸的關鍵:
s t = s t 1 F G + I o u t s_t=s_{t-1} \circ FG + I_{out}

output gate

同之前的兩個gate類似,output gate也是一個sigmoid啟用函式filter,對當期的狀態量 s t s_t 進行過濾。 s t s_t 在接受過濾前,先使用tanh啟用函式進行區間壓縮:
O G = s i g m o i d ( x t W O G x + o t 1 W O G o + b O G ) OG=sigmoid(x_t W_{OG}^x+o_{t-1} W_{OG}^o+b_{OG})
以此對壓縮後的 s t s_t 進行過濾,形成最終當期計算單元的最終輸出:
o t = t a n h ( s t ) O G o_t=tanh(s_t) \circ OG
o t o_t s t s_t 將可用於下一期(t+1)計算單元的內部計算。