1. 程式人生 > >迴圈神經網路(RNN)模型與前向反向傳播演算法

迴圈神經網路(RNN)模型與前向反向傳播演算法

    在前面我們講到了DNN,以及DNN的特例CNN的模型和前向反向傳播演算法,這些演算法都是前向反饋的,模型的輸出和模型本身沒有關聯關係。今天我們就討論另一類輸出和模型間有反饋的神經網路:迴圈神經網路(Recurrent Neural Networks ,以下簡稱RNN),它廣泛的用於自然語言處理中的語音識別,手寫書別以及機器翻譯等領域。

1. RNN概述

    在前面講到的DNN和CNN中,訓練樣本的輸入和輸出是比較的確定的。但是有一類問題DNN和CNN不好解決,就是訓練樣本輸入是連續的序列,且序列的長短不一,比如基於時間的序列:一段段連續的語音,一段段連續的手寫文字。這些序列比較長,且長度不一,比較難直接的拆分成一個個獨立的樣本來通過DNN/CNN進行訓練。

    而對於這類問題,RNN則比較的擅長。那麼RNN是怎麼做到的呢?RNN假設我們的樣本是基於序列的。比如是從序列索引1到序列索引$\tau$的。對於這其中的任意序列索引號$t$,它對應的輸入是對應的樣本序列中的$x^{(t)}$。而模型在序列索引號$t$位置的隱藏狀態$h^{(t)}$,則由$x^{(t)}$和在$t-1$位置的隱藏狀態$h^{(t-1)}$共同決定。在任意序列索引號$t$,我們也有對應的模型預測輸出$o^{(t)}$。通過預測輸出$o^{(t)}$和訓練序列真實輸出$y^{(t)}$,以及損失函式$L^{(t)}$,我們就可以用DNN類似的方法來訓練模型,接著用來預測測試序列中的一些位置的輸出。

    下面我們來看看RNN的模型。

2. RNN模型

    RNN模型有比較多的變種,這裡介紹最主流的RNN模型結構如下:

    上圖中左邊是RNN模型沒有按時間展開的圖,如果按時間序列展開,則是上圖中的右邊部分。我們重點觀察右邊部分的圖。

    這幅圖描述了在序列索引號$t$附近RNN的模型。其中:

    1)$x^{(t)}$代表在序列索引號$t$時訓練樣本的輸入。同樣的,$x^{(t-1)}$和$x^{(t+1)}$代表在序列索引號$t-1$和$t+1$時訓練樣本的輸入。

    2)$h^{(t)}$代表在序列索引號$t$時模型的隱藏狀態。$h^{(t)}$由$x^{(t)}$和$h^{(t-1)}$共同決定。

    3)$o^{(t)}$代表在序列索引號$t$時模型的輸出。$o^{(t)}$只由模型當前的隱藏狀態$h^{(t)}$決定。

    4)$L^{(t)}$代表在序列索引號$t$時模型的損失函式。

    5)$y^{(t)}$代表在序列索引號$t$時訓練樣本序列的真實輸出。

    6)$U,W,V$這三個矩陣是我們的模型的線性關係引數,它在整個RNN網路中是共享的,這點和DNN很不相同。 也正因為是共享了,它體現了RNN的模型的“迴圈反饋”的思想。  

3. RNN前向傳播演算法

    有了上面的模型,RNN的前向傳播演算法就很容易得到了。

    對於任意一個序列索引號$t$,我們隱藏狀態$h^{(t)}$由$x^{(t)}$和$h^{(t-1)}$得到:$$h^{(t)} = \sigma(z^{(t)}) = \sigma(Ux^{(t)} + Wh^{(t-1)} +b )$$

    其中$\sigma$為RNN的啟用函式,一般為$tanh$, $b$為線性關係的偏倚。

    序列索引號$t$時模型的輸出$o^{(t)}$的表示式比較簡單:$$o^{(t)} = Vh^{(t)} +c $$

    在最終在序列索引號$t$時我們的預測輸出為:$$\hat{y}^{(t)} = \sigma(o^{(t)})$$

    通常由於RNN是識別類的分類模型,所以上面這個啟用函式一般是softmax。

    通過損失函式$L^{(t)}$,比如對數似然損失函式,我們可以量化模型在當前位置的損失,即$\hat{y}^{(t)}$和$y^{(t)}$的差距。

4. RNN反向傳播演算法推導

    有了RNN前向傳播演算法的基礎,就容易推匯出RNN反向傳播演算法的流程了。RNN反向傳播演算法的思路和DNN是一樣的,即通過梯度下降法一輪輪的迭代,得到合適的RNN模型引數$U,W,V,b,c$。由於我們是基於時間反向傳播,所以RNN的反向傳播有時也叫做BPTT(back-propagation through time)。當然這裡的BPTT和DNN也有很大的不同點,即這裡所有的$U,W,V,b,c$在序列的各個位置是共享的,反向傳播時我們更新的是相同的引數。

    為了簡化描述,這裡的損失函式我們為對數損失函式,輸出的啟用函式為softmax函式,隱藏層的啟用函式為tanh函式。

    對於RNN,由於我們在序列的每個位置都有損失函式,因此最終的損失$L$為:$$L = \sum\limits_{t=1}^{\tau}L^{(t)}$$

    其中$V,c,$的梯度計算是比較簡單的:$$\frac{\partial L}{\partial c} = \sum\limits_{t=1}^{\tau}\frac{\partial L^{(t)}}{\partial c} = \sum\limits_{t=1}^{\tau}\frac{\partial L^{(t)}}{\partial o^{(t)}} \frac{\partial o^{(t)}}{\partial c} = \sum\limits_{t=1}^{\tau}\hat{y}^{(t)} - y^{(t)}$$$$\frac{\partial L}{\partial V} =\sum\limits_{t=1}^{\tau}\frac{\partial L^{(t)}}{\partial V} = \sum\limits_{t=1}^{\tau}\frac{\partial L^{(t)}}{\partial o^{(t)}} \frac{\partial o^{(t)}}{\partial V} = \sum\limits_{t=1}^{\tau}(\hat{y}^{(t)} - y^{(t)}) (h^{(t)})^T$$

    但是$W,U,b$的梯度計算就比較的複雜了。從RNN的模型可以看出,在反向傳播時,在在某一序列位置t的梯度損失由當前位置的輸出對應的梯度損失和序列索引位置$t+1$時的梯度損失兩部分共同決定。對於$W$在某一序列位置t的梯度損失需要反向傳播一步步的計算。我們定義序列索引$t$位置的隱藏狀態的梯度為:$$\delta^{(t)} = \frac{\partial L}{\partial h^{(t)}}$$

    這樣我們可以像DNN一樣從$\delta^{(t+1)} $遞推$\delta^{(t)}$ 。$$\delta^{(t)} =\frac{\partial L}{\partial o^{(t)}} \frac{\partial o^{(t)}}{\partial h^{(t)}} + \frac{\partial L}{\partial h^{(t+1)}}\frac{\partial h^{(t+1)}}{\partial h^{(t)}} = V^T(\hat{y}^{(t)} - y^{(t)}) + W^T\delta^{(t+1)}diag(1-(h^{(t+1)})^2)$$

    對於$\delta^{(\tau)} $,由於它的後面沒有其他的序列索引了,因此有:$$\delta^{(\tau)} =\frac{\partial L}{\partial o^{(\tau)}} \frac{\partial o^{(\tau)}}{\partial h^{(\tau)}} = V^T(\hat{y}^{(\tau)} - y^{(\tau)})$$

    有了$\delta^{(t)} $,計算$W,U,b$就容易了,這裡給出$W,U,b$的梯度計算表示式:$$\frac{\partial L}{\partial W} =  \sum\limits_{t=1}^{\tau}\frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}}{\partial W} = \sum\limits_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)}(h^{(t-1)})^T$$$$\frac{\partial L}{\partial b}= \sum\limits_{t=1}^{\tau}\frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}}{\partial b} = \sum\limits_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)}$$$$\frac{\partial L}{\partial U} = \sum\limits_{t=1}^{\tau}\frac{\partial L}{\partial h^{(t)}} \frac{\partial h^{(t)}}{\partial U} = \sum\limits_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)}(x^{(t)})^T$$

    除了梯度表示式不同,RNN的反向傳播演算法和DNN區別不大,因此這裡就不再重複總結了。

5. RNN小結

    上面總結了通用的RNN模型和前向反向傳播演算法。當然,有些RNN模型會有些不同,自然前向反向傳播的公式會有些不一樣,但是原理基本類似。

    RNN雖然理論上可以很漂亮的解決序列資料的訓練,但是它也像DNN一樣有梯度消失時的問題,當序列很長的時候問題尤其嚴重。因此,上面的RNN模型一般不能直接用於應用領域。在語音識別,手寫書別以及機器翻譯等NLP領域實際應用比較廣泛的是基於RNN模型的一個特例LSTM,下一篇我們就來討論LSTM模型。

(歡迎轉載,轉載請註明出處。歡迎溝通交流:[email protected]) 

參考資料:

2) Deep Learning, book by Ian Goodfellow, Yoshua Bengio, and Aaron Courville