RNN 梯度消失和梯度爆炸
為什麼會梯度爆炸或梯度消失:
梯度爆炸指的是在訓練時,累計了很大的誤差導數,導致神經網路模型大幅更新。這樣模型會變得很不穩定,不能從訓練資料中很好的進行學習。極端情況下會得到nan.
會發生這個的原因是在神經網路層間不斷的以指數級在乘以導數。
補充:雅克比矩陣 -- 函式的一階偏導數以一定方式排列成的矩陣,舉個例子:
可以看到,除對角線元素外,其他元素都是0.而對角線上的元素值就是對應的y與x的一階偏導數值。
RNN部分:
在反向傳播時,要求誤差函式對W的導數
E是誤差函式,表示對所有時刻t的權重偏導數求和
使用鏈式推導展開可知,上式可以表示為:
說明: E就是輸出值y與真值之間的函式,而y又由ht線性變換然後過啟用函式得到;ht由輸入和之前的hk的函式得到,hk與藥訓練的引數W有函式關係
這裡面比較關鍵的就是ht與hk的偏導數關係
繼續使用鏈式推導可以知道,
對上式再展開一點,令
則
而其實就是W矩陣
所以(diag表示雅克比矩陣的對角線)
當序列長度越長,對一個序列反向傳播的每一步都要計算一個連乘項
也就是W的連乘
當W<1或W>1時,很容易因為連乘的指數增長而發生梯度消失和梯度爆炸
梯度消失與梯度爆炸和啟用函式:
常用的啟用函式sigmoid和tanh
在梯度很小火梯度很大時,函式都是很平滑的,很容易導致越往後訓練,梯度幾乎不變。因此產生了梯度消失或梯度爆炸的問題
解決梯度爆照和梯度消失問題:
幾個tricks:
1、gradient clipping:
2、逆置輸入
之前正序輸入的時候,整個句子輸入後,才開始decode第一個輸入的詞,所以每一個詞都有長距離的依賴。但是逆置輸入之後,每次decode的時候只有1個時間步之差,然後用這個資訊來處理句子後續的資訊,減少了過長的依賴。
3、identity initialization
恆等函式identity function f(x)=x是不擔心多次迭代的,如果計算接近恆等函式的話,就會相對比較穩定。identity RNN就是一種RNN模型,啟用函式全都是relu,中間的recuurent weight初始化為恆等矩陣
4、LSTM
使用LSTM可以更好的記住長時間前的資訊
5、weight regularization
就是正則化,在loss函式後面加L1或L2範數的懲罰