1. 程式人生 > >【機器學習】【RNN中的梯度消失與梯度爆炸】

【機器學習】【RNN中的梯度消失與梯度爆炸】

學習speech synthesis的Tacotron模型,而Tacotron是基於seq2seq attention,RNN中的一類。所以得先學習RNN,以及RNN的變種LSTM和GRU。

RNN的詳細我這裡不再介紹了,許多神犇的部落格及網上免費的課程講得都很詳細。這裡僅說明RNN中的梯度消失與梯度爆炸。文章若有錯誤,煩請大家批評指正。

以經典RNN為例,

假設我們的時間序列只有三段,S0為給定值,則RNN的前向傳播過程:

S1=tanh(Wx*X1+Ws*S0+b1),O1=Wy*S1+b2,y1=g(O1)=g(Wy*S1+b2)

S2=tanh(Wx*X2+Ws*S1+b1),O2=Wy*S2+b2,y2=g(O2)=g(Wy*S2+b2)

S3=tanh(Wx*X3+Ws*S2+b1),O3=Wy*S3+b2,y3=g(O3)=g(Wy*S3+b2)

其中Wx為處理輸入的引數,Wy為處理輸出的引數,Ws為處理前一個時間序列的引數。

假設損失函式為L=1/2*(Y-y)^2,即在t=3時刻,損失函式為L3=1/2*(Y3-y3)^2

對於每一次訓練,損失函式為L=∑(t=0,T)Lt,即每一時刻損失值的累加。

我們訓練RNN的目的就是不斷調整引數,即Wx、Ws、Wy和b1,b2,使得它們讓L儘可能達到最小。

假設我們的三段時間序列為t1,t2,t3。

我們考慮t3時刻,對t3時刻的Wx、Ws、Wy求偏導:

可以看出,時間序列對Wy沒有長期依賴,而對Wx和Ws的偏導會隨著時間序列的增加,中間的求積過程就會不斷增加。

因此,根據上面的求偏導公式,可以得到任意時刻對Wx的求偏導公式:

任意時刻對Ws的的求偏導公式和上面類似。

而其中,Sj對Sj-1的偏導數,就是

啟用函式tanh和它的導數影象如下:(引用自zhihu)

可以看出,啟用函式tanh的導數是小於等於1的,訓練的過程中大部分情況下也小於1,因為很少出現WxXj+WsSj-1+b1=0的情況。如果Ws是一個大於0小於1的值,那麼當t很大時,就會無窮小,即趨於0;當Ws很大時,則會趨於無窮。

因此,梯度消失和梯度爆炸的根本原因就是這一坨連乘,我們要儘量去掉這一坨連乘,一種辦法就是使另一種辦法就是使其實這就是LSTM做的事情。