1. 程式人生 > >LSTM網路層詳解及其應用例項

LSTM網路層詳解及其應用例項

上一節我們介紹了RNN網路層的記憶性原理,同時使用了keras框架聽過的SimpleRNN網路層到實際運用中。然而使用的效果並不理想,主要是因為simpleRNN無法應對過長單詞串的輸入,在理論上,當它接收第t個輸入時,它應該能把前面好幾個單詞的處理資訊記錄下來,但實際上它無法把前面已經處理過的單詞資訊保留到第t個單詞輸入的時刻。

出現這種現象的原因叫”Vanishing gradian problem”,我們以前說要更新某個鏈路權重中,需要對它求偏導數,但在某種情況下,我們求得的偏導數可能接近於0,這樣一來鏈路權重就得不到有效更新,因為當權重加上一個很接近於0的偏導數時,它不會產生顯著的變化。這種現象也會出現在feed forward網路,當網路有很多層時,我們會把誤差進行回傳,但層次過多時,回傳的誤差會不斷的被“沖淡”,直到某個神經元接收到回傳的誤差是,該誤差的值幾乎與0差不多大小,這樣求出的偏導數也接近與0,因此鏈路權重就得不到有效的更新。

這種現象被人工置頂的三位大牛Hochreiter,Schmidhuber,Bengio深入研究後,他們提出一種新型網路層叫LSTM和GRU以便接近偏導數接近於0使得鏈路權重得不到有效更新的問題。LSTM的全稱是Long Short term memory,也就是長短程記憶,它其實是我們上節使用的simpleRNN變種,設想當單詞一個個輸入網路時,旁邊還有一條傳送帶把相關資訊也輸入網路,如下圖:

螢幕快照 2018-09-07 下午6.10.13.png

這裡我們多增加一個變數C來記錄每一個單詞被網路處理後遺留下來的資訊,網路的啟用函式還是不變,但是我們要增加多幾個變數來計算變數C:
i_t = activation(dot(state_t, Ui) + dot(input_t, wi) + bi)
f_t =activation(dot(state_t, Uf) + dot(input_t, wf) + bf)
k_t=activation(dot(state_t, Uk) + dot(input_t, wk) + bk)
那麼C的更新方式為:
C = i_t * k_t + C*f_t
初看起來,邏輯很難理解,為何我們要增加這些不知所云的步驟呢,它蘊含著較為複雜的設計原理和數學原理,簡單來說C*f_t目的是增加一些噪音,讓網路適當的“忘記”以前計算留下了的資訊,i_t*k_t是讓網路增強最近幾次計算所遺留下來的資訊 ,這裡我們不深究,只要囫圇吞棗,知道新增加的變數C是用來幫助網路增強對以前資訊處理的記憶,並指導該變數如何更新就好,接下來我們看看LSTM網路的具體應用例項:

from keras.layers import LSTM

model = Sequential()
model.add(Embedding(max_features, 32))
model.add(LSTM(32))
model.add(Dense(1, activation='sigmoid'))

model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
history = model.fit(input_train, y_train, epochs=10, batch_size=128
, validation_split=0.2)

我們繼續使用上一節的資料作為網路輸入,上面程式碼執行後,我們再將它的訓練結果繪製出來,結果如下:

螢幕快照 2018-09-11 下午4.26.15.png

上一節我們使用SimpleRNN網路層時,網路對校驗資料的判斷準確率為85%左右,這裡我們使用LSTM網路層,網路對校驗資料的準確率可以提升到89%,這是因為LSTM比SimpleRNN對網路以期出來過的資料有更好的“記憶”功能,更能將以前處理過的單詞與現在處理的單詞關聯起來。

更詳細的講解和程式碼除錯演示過程,請點選連結

更多技術資訊,包括作業系統,編譯器,面試演算法,機器學習,人工智慧,請關照我的公眾號:
這裡寫圖片描述