1. 程式人生 > >長短期記憶神經網路(LSTM)介紹以及簡單應用分析

長短期記憶神經網路(LSTM)介紹以及簡單應用分析

本文分為四個部分,第一部分簡要介紹LSTM的應用現狀;第二部分介紹LSTM的發展歷史,並引出了受眾多學者關注的LSTM變體——門控遞迴單元(GRU);第三部分介紹LSTM的基本結構,由基本迴圈神經網路結構引出LSTM的具體結構。第四部分,應用Keras框架提供的API,比較和分析簡單迴圈神經網路(SRN)、LSTM和GRU在手寫數字minist資料集上的表現。

 


應用現狀

       長短期記憶神經網路(LSTM)是一種特殊的迴圈神經網路(RNN)。原始的RNN在訓練中,隨著訓練時間的加長以及網路層數的增多,很容易出現梯度爆炸或者梯度消失的問題,導致無法處理較長序列資料,從而無法獲取長距離資料的資訊。

       LSTM應用的領域包括:文字生成、機器翻譯、語音識別、生成影象描述和視訊標記等。

      2009年, 應用LSTM搭建的神經網路模型贏得了ICDAR手寫識別比賽冠軍。

      2015年以來,在機械故障診斷和預測領域,相關學者應用LSTM來處理機械裝置的振動訊號。

      2016年, 谷歌公司應用LSTM來做語音識別和文字翻譯,其中Google翻譯用的就是一個7-8層的LSTM模型。

      2016年, 蘋果公司使用LSTM來優化Siri應用。

 


發展歷史

        1997年,Sepp Hochreiter 和 Jürgen Schmidhuber[1]提出了長短期記憶神經網路(LSTM),有效解決了RNN難以解決的人為延長時間任務的問題,並解決了RNN容易出現梯度消失的問題。

        1999年,Felix A. Gers等人[2]發現[1]中提出的LSTM在處理連續輸入資料時,如果沒有重置網路內部的狀態,最終會導致網路崩潰。因此,他們在文獻[1]基礎上引入了遺忘門機制,使得LSTM能夠重置自己的狀態。

         2000年,Felix A. Gers和Jiirgen Schmidhuber[3]發現,通過在LSTM內部狀態單元內新增窺視孔(Peephole)連線,可以增強網路對輸入序列之間細微特徵的區分能力。

         2005年,Alex Graves和Jürgen Schmidhuber[4]在文獻[1] [2] [3]的基礎上提出了一種雙向長短期記憶神經網路(BLSTM),也稱為vanilla LSTM,是當前應用最廣泛的一種LSTM模型。

         2005年-2015年期間,相關學者提出了多種LSTM變體模型,此處不多做描述。

         2016年,Klaus Greff 等人[5]回顧了LSTM的發展歷程,並比較分析了八種LSTM變體在語音識別、手寫識別和絃音樂建模方面的能力,實驗結果表明這些變體不能顯著改進標準LSTM體系結構,並證明了遺忘門和輸出啟用功能是LSTM的關鍵組成部分。在這八種變體中,vanilla LSTM的綜合表現能力最佳。另外,還探索了LSTM相關超引數的設定影響,實驗結果表明學習率是最關鍵的超引數,其次是網路規模(網路層數和隱藏層單元數),而動量梯度等設定對最終結果影響不大。

        下圖展示了Simple RNN(圖左)和vanilla LSTM(圖右,圖中藍色線條表示窺視孔連線)的基本單元結構圖[5]:

        在眾多LSTM變體中,2014年Kyunghyun Cho等人[6]提出的變體引起了眾多學者的關注。Kyunghyun Cho等人簡化了LSTM架構,稱為門控遞迴單元(GRU)。GRU擺脫了單元狀態,基本結構由重置門和更新門組成。LSTM和GRU的基本結構單元如下圖(具體可參考:Illustrated Guide to LSTM’s and GRU’s: A step by step explanation)。

 

        在GRU被提出後,Junyoung Chung等人[7]比較了LSTM和GRU在複音音樂和語音訊號建模方面的能力,實驗結果表明GRU和LSTM表現相當。

        GRU被提出至今(2019年),也只有幾年時間,關於它的一些應用利弊到目前還未探索清楚。不過,相對於LSTM架構,GRU的的引數較少,在資料量較大的情況下,其訓練速度更快。

         LSTM是深度學習技術中的一員,其基本結構比較複雜,計算複雜度較高,導致較難進行較深層次的學習,例如谷歌翻譯也只是應用7-8層的LSTM網路結構。另外,在訓練學習過程中有可能會出現過擬合,可以通過應用dropout來解決過擬合問題(這在Keras等框架中均有實現,具體可參考:LSTM原理與實踐,原來如此簡單)。

         LSTM在當前應用比較的結構是雙向LSTM或者多層堆疊LSTM,這兩種結構的實現在Keras等框架中均有對應的API可以呼叫。

        下圖展示一個堆疊兩層的LSTM結構圖(來源:運用TensorFlow處理簡單的NLP問題):

 

        下圖展示了一個雙向LSTM的結構圖(來源:雙向LSTM)

 

 

 

 


基本原理

        本節首先講解一下RNN的基本結構,然後說明LSTM的具體原理(下面要介紹的LSTM即為vanilla LSTM)。

        原始的RNN基本結構圖如下圖所示(原圖來源:Understanding LSTM Networks)。

       由上圖可知,RNN展開後由多個相同的單元連續連線。但是,RNN的實際結構確和上圖左邊的結構所示,是一個自我不斷迴圈的結構。即隨著輸入資料的不斷增加,上述自我迴圈的結構把上一次的狀態傳遞給當前輸入,一起作為新的輸入資料進行當前輪次的訓練和學習,一直到輸入或者訓練結束,最終得到的輸出即為最終的預測結果。

        LSTM是一種特殊的RNN,兩者的區別在於普通的RNN單個迴圈結構內部只有一個狀態。而LSTM的單個迴圈結構(又稱為細胞)內部有四個狀態。相比於RNN,LSTM迴圈結構之間保持一個持久的單元狀態不斷傳遞下去,用於決定哪些資訊要遺忘或者繼續傳遞下去。

        包含三個連續迴圈結構的RNN如下圖,每個迴圈結構只有一個輸出:

        包含三個連續迴圈結構的LSTM如下圖,每個迴圈結構有兩個輸出,其中一個即為單元狀態:

        一層LSTM是由單個迴圈結構結構組成,既由輸入資料的維度和迴圈次數決定單個迴圈結構需要自我更新幾次,而不是多個單個迴圈結構連線組成(此處關於這段描述,在實際操作的理解詳述請參考:Keras關於LSTM的units引數,還是不理解? ),即當前層LSTM的引數總個數只需計算一個迴圈單元就行,而不是計算多個連續單元的總個數。

       下面將由一組圖來詳細結構LSTM細胞的基本組成和實現原理。LSTM細胞由輸入門、遺忘門、輸出門和單元狀態組成。

  • 輸入門:決定當前時刻網路的輸入資料有多少需要儲存到單元狀態。
  • 遺忘門:決定上一時刻的單元狀態有多少需要保留到當前時刻。
  • 輸出門:控制當前單元狀態有多少需要輸出到當前的輸出值。

       下圖展示了應用上一個時刻的輸出h_t-1和當前的資料輸入x_t,通過遺忘門得到f_t的過程。(下面的一組原圖來源:Understanding LSTM Networks)

       下圖展示了應用上一個時刻的輸出h_t-1和當前的資料輸入x_t,通過輸入門得到i_t,以及通過單元狀態得到當前時刻暫時狀態C~t的過程。

       下圖展示了應用上一個細胞結構的單元狀態C_t-1、遺忘門輸出f_t、輸入門輸出i_t以及單元狀態的輸出C~t,得到當前細胞的狀態C_t的過程。

       下圖展示了應用上一個時刻的輸出h_t-1和當前的資料輸入x_t,通過輸出門得到o_t的過程,以及結合當前細胞的單元狀態C_t和o_t得到最終的輸出h_t的過程。

 

 


基於Keras框架的手寫數字識別實驗

        本節應用Keras提供的API,比較和分析Simple RNN、LSTM和GRU在手寫數字mnist資料集上的預測準確率。

應用Simple RNN進行手寫數字預測訓練的程式碼如下:

import keras
from keras.layers import LSTM , SimpleRNN, GRU
from keras.layers import Dense, Activation
from keras.datasets import mnist
from keras.models import Sequential
from keras.optimizers import Adam
learning_rate = 0.001 training_iters = 20 batch_size = 128 display_step = 10 n_input = 28 n_step = 28 n_hidden = 128 n_classes = 10 (x_train, y_train), (x_test, y_test) = mnist.load_data() x_train = x_train.reshape(-1, n_step, n_input) x_test = x_test.reshape(-1, n_step, n_input) x_train = x_train.astype('float32') x_test = x_test.astype('float32') x_train /= 255 x_test /= 255 y_train = keras.utils.to_categorical(y_train, n_classes) y_test = keras.utils.to_categorical(y_test, n_classes) model = Sequential() model.add(SimpleRNN(n_hidden, batch_input_shape=(None, n_step, n_input), unroll=True)) model.add(Dense(n_classes)) model.add(Activation('softmax')) adam = Adam(lr=learning_rate) model.summary() model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy']) history = model.fit(x_train, y_train, batch_size=batch_size, epochs=training_iters, verbose=1, validation_data=(x_test, y_test)) scores = model.evaluate(x_test, y_test, verbose=0) print('Simple RNN test score(loss value):', scores[0]) print('Simple RNN test accuracy:', scores[1])

訓練結果:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
simple_rnn_1 (SimpleRNN)     (None, 128)               20096     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
_________________________________________________________________
activation_1 (Activation)    (None, 10)                0         
=================================================================
Total params: 21,386
Trainable params: 21,386
Non-trainable params: 0
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
60000/60000 [==============================] - 3s 51us/step - loss: 0.4584 - acc: 0.8615 - val_loss: 0.2459 - val_acc: 0.9308
Epoch 2/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.1923 - acc: 0.9440 - val_loss: 0.1457 - val_acc: 0.9578
Epoch 3/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.1506 - acc: 0.9555 - val_loss: 0.1553 - val_acc: 0.9552
Epoch 4/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.1326 - acc: 0.9604 - val_loss: 0.1219 - val_acc: 0.9642
Epoch 5/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.1184 - acc: 0.9651 - val_loss: 0.1014 - val_acc: 0.9696
Epoch 6/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.1021 - acc: 0.9707 - val_loss: 0.1254 - val_acc: 0.9651
Epoch 7/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0987 - acc: 0.9708 - val_loss: 0.0946 - val_acc: 0.9733
Epoch 8/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0959 - acc: 0.9722 - val_loss: 0.1163 - val_acc: 0.9678
Epoch 9/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0888 - acc: 0.9742 - val_loss: 0.0983 - val_acc: 0.9718
Epoch 10/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0833 - acc: 0.9750 - val_loss: 0.1199 - val_acc: 0.9651
Epoch 11/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0814 - acc: 0.9750 - val_loss: 0.0939 - val_acc: 0.9722
Epoch 12/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0767 - acc: 0.9773 - val_loss: 0.0865 - val_acc: 0.9761
Epoch 13/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0747 - acc: 0.9778 - val_loss: 0.1077 - val_acc: 0.9697
Epoch 14/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0746 - acc: 0.9779 - val_loss: 0.1098 - val_acc: 0.9693
Epoch 15/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0671 - acc: 0.9799 - val_loss: 0.0776 - val_acc: 0.9771
Epoch 16/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0639 - acc: 0.9810 - val_loss: 0.0961 - val_acc: 0.9730
Epoch 17/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0701 - acc: 0.9792 - val_loss: 0.1046 - val_acc: 0.9713
Epoch 18/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0600 - acc: 0.9822 - val_loss: 0.0865 - val_acc: 0.9767
Epoch 19/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0635 - acc: 0.9813 - val_loss: 0.0812 - val_acc: 0.9790
Epoch 20/20
60000/60000 [==============================] - 3s 47us/step - loss: 0.0579 - acc: 0.9827 - val_loss: 0.0981 - val_acc: 0.9733
Simple RNN test score(loss value): 0.09805978989955037
Simple RNN test accuracy: 0.9733

        可知Simple RNN在測試集上的最終預測準確率為97.33%。

        只需修改下方程式碼中Simple RNN為LSTM,即可呼叫LSTM進行模型訓練:

model.add(SimpleRNN(n_hidden,
               batch_input_shape=(None, n_step, n_input),
               unroll=True))

改變為:

model.add(LSTM(n_hidden,
               batch_input_shape=(None, n_step, n_input),
               unroll=True))

訓練結果:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
lstm_1 (LSTM)                (None, 128)               80384     
_________________________________________________________________
dense_2 (Dense)              (None, 10)                1290      
_________________________________________________________________
activation_2 (Activation)    (None, 10)                0         
=================================================================
Total params: 81,674
Trainable params: 81,674
Non-trainable params: 0
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
60000/60000 [==============================] - 10s 172us/step - loss: 0.5226 - acc: 0.8277 - val_loss: 0.1751 - val_acc: 0.9451
Epoch 2/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.1474 - acc: 0.9549 - val_loss: 0.1178 - val_acc: 0.9641
Epoch 3/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.1017 - acc: 0.9690 - val_loss: 0.0836 - val_acc: 0.9748
Epoch 4/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0764 - acc: 0.9764 - val_loss: 0.0787 - val_acc: 0.9759
Epoch 5/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0607 - acc: 0.9811 - val_loss: 0.0646 - val_acc: 0.9813
Epoch 6/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0542 - acc: 0.9834 - val_loss: 0.0630 - val_acc: 0.9801
Epoch 7/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0452 - acc: 0.9859 - val_loss: 0.0603 - val_acc: 0.9803
Epoch 8/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0406 - acc: 0.9874 - val_loss: 0.0531 - val_acc: 0.9849
Epoch 9/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0345 - acc: 0.9888 - val_loss: 0.0540 - val_acc: 0.9834
Epoch 10/20
60000/60000 [==============================] - 8s 132us/step - loss: 0.0305 - acc: 0.9901 - val_loss: 0.0483 - val_acc: 0.9848
Epoch 11/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0281 - acc: 0.9913 - val_loss: 0.0517 - val_acc: 0.9843
Epoch 12/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0256 - acc: 0.9918 - val_loss: 0.0472 - val_acc: 0.9847
Epoch 13/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0229 - acc: 0.9929 - val_loss: 0.0441 - val_acc: 0.9874
Epoch 14/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0204 - acc: 0.9935 - val_loss: 0.0490 - val_acc: 0.9855
Epoch 15/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0192 - acc: 0.9938 - val_loss: 0.0486 - val_acc: 0.9851
Epoch 16/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0203 - acc: 0.9937 - val_loss: 0.0450 - val_acc: 0.9866
Epoch 17/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0160 - acc: 0.9948 - val_loss: 0.0391 - val_acc: 0.9882
Epoch 18/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0147 - acc: 0.9955 - val_loss: 0.0544 - val_acc: 0.9834
Epoch 19/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0147 - acc: 0.9953 - val_loss: 0.0456 - val_acc: 0.9880
Epoch 20/20
60000/60000 [==============================] - 8s 133us/step - loss: 0.0153 - acc: 0.9952 - val_loss: 0.0465 - val_acc: 0.9867
LSTM test score(loss value): 0.046479647984029725
LSTM test accuracy: 0.9867

       可知LSTM在測試集上的最終預測準確率為98.67%。

       採用同樣的思路,把Simple RNN改為GRU,即可呼叫GRU進行模型訓練。

訓練結果:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
gru_1 (GRU)                  (None, 128)               60288     
_________________________________________________________________
dense_3 (Dense)              (None, 10)                1290      
_________________________________________________________________
activation_3 (Activation)    (None, 10)                0         
=================================================================
Total params: 61,578
Trainable params: 61,578
Non-trainable params: 0
_________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
60000/60000 [==============================] - 10s 166us/step - loss: 0.6273 - acc: 0.7945 - val_loss: 0.2062 - val_acc: 0.9400
Epoch 2/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.1656 - acc: 0.9501 - val_loss: 0.1261 - val_acc: 0.9606
Epoch 3/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.1086 - acc: 0.9667 - val_loss: 0.0950 - val_acc: 0.9697
Epoch 4/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0824 - acc: 0.9745 - val_loss: 0.0761 - val_acc: 0.9769
Epoch 5/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0644 - acc: 0.9797 - val_loss: 0.0706 - val_acc: 0.9793
Epoch 6/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0540 - acc: 0.9829 - val_loss: 0.0678 - val_acc: 0.9799
Epoch 7/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0479 - acc: 0.9854 - val_loss: 0.0601 - val_acc: 0.9811
Epoch 8/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0402 - acc: 0.9877 - val_loss: 0.0495 - val_acc: 0.9848
Epoch 9/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0346 - acc: 0.9895 - val_loss: 0.0591 - val_acc: 0.9821
Epoch 10/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0306 - acc: 0.9901 - val_loss: 0.0560 - val_acc: 0.9836
Epoch 11/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0290 - acc: 0.9910 - val_loss: 0.0473 - val_acc: 0.9857
Epoch 12/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0249 - acc: 0.9922 - val_loss: 0.0516 - val_acc: 0.9852
Epoch 13/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0222 - acc: 0.9930 - val_loss: 0.0448 - val_acc: 0.9863
Epoch 14/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0206 - acc: 0.9934 - val_loss: 0.0453 - val_acc: 0.9872
Epoch 15/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0178 - acc: 0.9944 - val_loss: 0.0559 - val_acc: 0.9833
Epoch 16/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0173 - acc: 0.9947 - val_loss: 0.0502 - val_acc: 0.9854
Epoch 17/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0150 - acc: 0.9955 - val_loss: 0.0401 - val_acc: 0.9880
Epoch 18/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0164 - acc: 0.9949 - val_loss: 0.0486 - val_acc: 0.9872
Epoch 19/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0133 - acc: 0.9960 - val_loss: 0.0468 - val_acc: 0.9882
Epoch 20/20
60000/60000 [==============================] - 8s 130us/step - loss: 0.0107 - acc: 0.9965 - val_loss: 0.0470 - val_acc: 0.9879
GRU test score(loss value): 0.04698457587567973
GRU test accuracy: 0.9879

       可知GRU在測試集上的最終預測準確率為98.79%。

      由上述實驗結果可知,LSTM和GRU的預測準確率要顯著高於Simple RNN,而LSTM和GRU的預測準確率相差較小。

 

 


參考文獻

[1] S. Hochreiter and J. Schmidhuber, “Long Short-Term Memory,” Neural Comput, vol. 9, no. 8, pp. 1735–1780, Nov. 1997.

[2] F. A. Gers, J. Schmidhuber, and F. A. Cummins, “Learning to Forget: Continual Prediction with LSTM,” Neural Comput., vol. 12, pp. 2451–2471, 2000.

[3] F. A. Gers and J. Schmidhuber, “Recurrent nets that time and count,” Proc. IEEE-INNS-ENNS Int. Jt. Conf. Neural Netw. IJCNN 2000 Neural Comput. New Chall. Perspect. New Millenn., vol. 3, pp. 189–194 vol.3, 2000.

[4] A. Graves and J. Schmidhuber, “Framewise phoneme classification with bidirectional LSTM and other neural network architectures,” Neural Netw., vol. 18, no. 5, pp. 602–610, Jul. 2005.

[5] K. Greff, R. K. Srivastava, J. Koutník, B. R. Steunebrink, and J. Schmidhuber, “LSTM: A Search Space Odyssey,” IEEE Trans. Neural Netw. Learn. Syst., vol. 28, no. 10, pp. 2222–2232, Oct. 2017.

[6] K. Cho et al., “Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation,” ArXiv14061078 Cs Stat, Jun. 2014.

[7] J. Chung, C. Gulcehre, K. Cho, and Y. Bengio, “Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling,” ArXiv14123555 Cs, Dec. 2014.

&n