1. 程式人生 > >RNN 梯度消失和梯度爆炸

RNN 梯度消失和梯度爆炸

為什麼會梯度爆炸或梯度消失:

梯度爆炸指的是在訓練時,累計了很大的誤差導數,導致神經網路模型大幅更新。這樣模型會變得很不穩定,不能從訓練資料中很好的進行學習。極端情況下會得到nan.

會發生這個的原因是在神經網路層間不斷的以指數級在乘以導數。

補充:雅克比矩陣 -- 函式的一階偏導數以一定方式排列成的矩陣,舉個例子:

可以看到,除對角線元素外,其他元素都是0.而對角線上的元素值就是對應的y與x的一階偏導數值。

 

 

RNN部分:

在反向傳播時,要求誤差函式對W的導數

E是誤差函式,表示對所有時刻t的權重偏導數求和

使用鏈式推導展開可知,上式可以表示為:

說明: E就是輸出值y與真值之間的函式,而y又由ht線性變換然後過啟用函式得到;ht由輸入和之前的hk的函式得到,hk與藥訓練的引數W有函式關係

這裡面比較關鍵的就是ht與hk的偏導數關係

繼續使用鏈式推導可以知道,

對上式再展開一點,令

其實就是W矩陣

所以(diag表示雅克比矩陣的對角線)

當序列長度越長,對一個序列反向傳播的每一步都要計算一個連乘項

也就是W的連乘

當W<1或W>1時,很容易因為連乘的指數增長而發生梯度消失和梯度爆炸

 

梯度消失與梯度爆炸和啟用函式:

常用的啟用函式sigmoid和tanh

在梯度很小火梯度很大時,函式都是很平滑的,很容易導致越往後訓練,梯度幾乎不變。因此產生了梯度消失或梯度爆炸的問題

解決梯度爆照和梯度消失問題:

幾個tricks:

1、gradient clipping:

2、逆置輸入

之前正序輸入的時候,整個句子輸入後,才開始decode第一個輸入的詞,所以每一個詞都有長距離的依賴。但是逆置輸入之後,每次decode的時候只有1個時間步之差,然後用這個資訊來處理句子後續的資訊,減少了過長的依賴。

3、identity initialization

恆等函式identity function f(x)=x是不擔心多次迭代的,如果計算接近恆等函式的話,就會相對比較穩定。identity RNN就是一種RNN模型,啟用函式全都是relu,中間的recuurent weight初始化為恆等矩陣

4、LSTM

使用LSTM可以更好的記住長時間前的資訊

5、weight regularization

就是正則化,在loss函式後面加L1或L2範數的懲罰