1. 程式人生 > >長短期記憶(LSTM)系列_LSTM的資料準備(5)——如何配置Keras中截斷反向傳播預測的輸入序列步長

長短期記憶(LSTM)系列_LSTM的資料準備(5)——如何配置Keras中截斷反向傳播預測的輸入序列步長

導讀:

這篇文章是介紹了BPTT的概念,說明了資料截斷的原因和方法,即提高網路的學習效率。以及如何找到最好的截斷方法,即利用網格搜尋。

文中都是一些概念介紹,這裡直接把原文貼上來了。

原文連結:https://machinelearningmastery.com/truncated-backpropagation-through-time-in-keras/

正文翻譯如下:

遞迴神經網路能夠在序列預測問題中學習跨越多個時間步的時間依賴性。

像長短期記憶或LSTM網路這樣的現代遞迴神經網路通過稱為反向傳播時間的反向傳播演算法的變體進行訓練。為了提高具有很長序列的序列預測問題的效率,該演算法已被進一步修改,並且被稱為截斷反向傳播。

使用Truncated Backpropagation Through Time訓練像LSTM這樣的遞迴神經網路時,一個重要的配置引數決定了使用多少次步進作為輸入。也就是說,如何將非常長的輸入序列分成子序列以獲得最佳效能。

在這篇文章中,您將發現6種不同的方法,您可以使用Keras在Python中使用Truncated Backpropagation Through Time來分割非常長的輸入序列以有效地訓練遞迴神經網路。

閱讀這篇文章後,你會知道:

  • 通過時間截斷反向傳播是什麼以及如何在Python深度學習庫Keras中實現它。
  • 輸入時間步數的選擇究竟如何影響遞迴神經網路中的學習。
  • 您可以使用6種不同的技術來分割非常長的序列預測問題,以充分利用截斷反向傳播時間訓練演算法。

讓我們開始吧。

 

通過時間截斷反向傳播

反向傳播是用於更新神經網路中的權重的訓練演算法,以便最小化預期輸出和給定輸入的預測輸出之間的誤差。

對於序列預測問題,其中觀察之間存在順序依賴性,使用遞迴神經網路代替經典前饋神經網路。使用Backpropagation演算法的變體訓練遞迴神經網路,該演算法稱為Backpropagation Through Time,簡稱BPTT。

實際上,BPTT展開遞迴神經網路並在整個輸入序列上向後傳播誤差,一次一個步驟。然後用累積的梯度更新權重。

對於輸入序列很長的問題,BPTT可能很難訓練復現神經網路。除了速度之外,在很多時間步長上累積梯度可能導致值縮小到零,或者最終溢位或爆炸的值增長。

BPTT的修改是限制在向後傳遞上使用的時間步數,並且實際上估計用於更新權重的梯度而不是完全計算它。

這種變化稱為截斷反向傳播時間或TBPTT。

TBPTT訓練演算法有兩個引數:

  • k1:定義正向傳遞上顯示給網路的時間步數。
  • k2:定義在向後傳遞上估計漸變時要檢視的時間步數。

因此,當考慮如何配置訓練演算法時,我們可以使用符號TBPTT(k1,k2),其中k1 = k2 = n,其中n是經典非截斷BPTT的輸入序列長度。

TBPTT配置對RNN序列模型的影響

像LSTM這樣的現代遞迴神經網路可以使用它們的內部狀態來記住非常長的輸入序列。如超過數千次的步驟。

這意味著TBPTT的配置不一定通過選擇時間步長來定義您正在優化的網路記憶體。您可以選擇何時將網路的內部狀態與用於更新網路權重的機制分開重置。

相反,TBPTT引數的選擇會影響網路如何估計用於更新權重的誤差梯度。更一般地,配置定義了可以考慮網路來對序列問題建模的時間步數。

我們可以正式說明這樣的事情:

yhat(t) = f(X(t), X(t-1), X(t-2), ... X(t-n))

它在概念上與在時間序列問題上訓練的多層感知器上的視窗大小相似(但在實踐中完全不同),或者與ARIMA等線性時間序列模型的p和q引數相似。TBPTT定義了訓練期間模型輸入序列的範圍。如果yhat是特定時間步的輸出,則f(...)是遞迴神經網路近似的關係,X(t)是特定時間步的觀測值。

 

Keras實施TBPTT

Keras深度學習庫提供TBPTT的實現,用於訓練復現神經網路。

實施比上面列出的一般版本更受限制。

具體地,k1和k2值彼此相等並固定。

  • TBPTT(k1,k2),其中k1 = k2

這是通過訓練諸如長短期記憶網路或LSTM之類的遞迴神經網路所需的固定大小的三維輸入來實現的。

LSTM期望輸入資料具有尺寸:樣本,時間步和特徵。

它是此輸入格式的第二個維度,即時間步長,用於定義用於序列預測問題的前向和後向傳遞的時間步數。

因此,在為Keras中的序列預測問題準備輸入資料時,必須仔細選擇指定的時間步數。

時間步的選擇將影響兩者:

  • 在前進過程中積累的內部狀態。
  • 梯度估計用於更新後向傳遞的權重。

請注意,預設情況下,每次批處理後都會重置網路的內部狀態,但可以通過使用所謂的有狀態LSTM並手動呼叫重置操作來實現對內部狀態重置的更明確控制。

有關Keras中有狀態LSTM的更多資訊,請參閱帖子:

在Keras中準備TBPTT的序列資料

分解序列資料的方式將定義BPTT前向和後向傳遞中使用的時間步數。

因此,您必須仔細考慮如何準備訓練資料。

本節列出了您可以考慮的6種技術。

1.使用資料原樣

如果每個序列中的時間步數是適度的,例如幾十或幾百步,則可以按原樣使用輸入序列。

已經提出TBPTT的實際限制為約200至400倍步長。

如果序列資料小於或等於此範圍,則可以將序列觀察值重新整形為輸入資料的時間步長。

例如,如果您有一個包含25個時間步長的100個單變數序列的集合,則可以將其重新整形為100個樣本,25個時間步長和1個特徵或[100,25,1]。

2.樸素的資料拆分

如果您有很長的輸入序列,例如數千個時間步長,則可能需要將長輸入序列分成多個連續的子序列。

這將需要在Keras中使用有狀態LSTM,以便在子序列的輸入上保持內部狀態,並且僅在真正更充分的輸入序列的末尾處重置。

例如,如果您有100個輸入序列的50,000個步驟,則每個輸入序列可以分為100個子步驟,500個步驟。一個輸入序列將變為100個樣本,因此100個原始樣本將變為10,000。Keras輸入的維數為10,000個樣本,500個步驟和1個特徵或[10000,500,1]。需要注意保持每100個子序列的狀態,並在每100個樣本之後明確地或通過使用100的批量大小重置內部狀態。

將整個序列整齊地劃分為固定大小的子序列的劃分是優選的。全序列因子(子序列長度)的選擇是任意的,因此名稱為“天真資料分裂”。

將序列分成子序列不考慮關於用於估計用於更新權重的誤差梯度的合適數量的時間步的域資訊。

3.特定於域的資料拆分

可能很難知道提供錯誤梯度的有用估計所需的正確時間步數。

我們可以使用天真的方法(上面)快速獲得模型,但模型可能遠未優化。

或者,我們可以使用特定於域的資訊來估計在學習問題時與模型相關的時間步數。

例如,如果序列問題是迴歸時間序列,則可能對自相關和部分自相關圖的檢查可以通知選擇時間步數。

如果序列問題是自然語言處理問題,則輸入序列可以按句子劃分,然後填充到固定長度,或者根據域中的平均句子長度進行劃分。

廣泛思考並考慮您可以使用哪些特定於您的域的知識將序列拆分為有意義的塊。

4.系統資料拆分(例如網格搜尋)

您可以系統地為序列預測問題評估一組不同的子序列長度,而不是猜測適當數量的時間步長。

您可以對每個子序列長度執行網格搜尋,並採用導致平均效能最佳的模型的配置。

如果您正在考慮這種方法,請注意一些注意事項:

  • 從作為整個序列長度因子的子序列長度開始。
  • 如果探索不是整個序列長度因子的子序列長度,則使用填充和掩蔽。
  • 考慮使用略微過度規定的網路(更多的儲存單元和更多的訓練時期)來解決問題,以幫助排除網路容量作為實驗的限制。
  • 獲取每個不同配置的多次執行(例如30)的平均效能。

如果計算資源不是限制,則建議對不同時間步數進行系統調查。

5.使用TBPTT嚴重依靠內部狀態(1,1)

您可以將序列預測問題重新表述為每個時間步一個輸入和一個輸出。

例如,如果您有100個50次步長的序列,則每個時間步長將成為新的樣本。100個樣本將變為5,000。三維輸入將變為5,000個樣本,1個步驟和1個特徵,或[5000,1,1]。

同樣,這將要求在序列的每個時間步長內保留內部狀態,並在每個實際序列的末尾重置(50個樣本)。

這將把學習序列預測問題的負擔放在遞迴神經網路的內部狀態上。根據問題的型別,它可能不僅僅是網路可以處理的,而且預測問題可能無法學習。

個人經驗表明,這種表述可能適用於需要對序列進行記憶的預測問題,但是當結果是過去觀察的複雜函式時表現不佳。

6.解耦前向和後向序列長度

Keras深度學習庫用於支援通過時間截斷反向傳播的前向和後向傳遞的解耦的時間步長數。

本質上,k1引數可以通過輸入序列上的時間步數來指定,並且k2引數可以通過LSTM層上的“truncate_gradient”引數來指定。

這不再受支援,但有一些願望將此功能重新新增到庫中。雖然有證據表明它是出於效率原因完成的,但目前還不清楚為何被刪除

你可以在Keras探索這種方法。一些想法包括:

  • 安裝並使用支援“truncate_gradient”引數的舊版Keras庫(大約2015年)。
  • 在Keras中擴充套件LSTM層實現以支援“truncate_gradient”型別行為。

也許有可用於Keras的第三方擴充套件支援此行為。