1. 程式人生 > >【NLP】Attention原理和原始碼解析

【NLP】Attention原理和原始碼解析

對attention一直停留在淺層的理解,看了幾篇介紹思想及原理的文章,也沒實踐過,今天立個Flag,一天深入原理和原始碼!如果你也是處於attention model level one的狀態,那不妨好好看一下啦。


內容:

  1. 核心思想
  2. 原理解析(圖解+公式)
  3. 模型分類
  4. 優缺點
  5. TF原始碼解析

P.S. 拒絕長篇大論,適合有基礎的同學快速深入attention,不明白的地方請留言諮詢~

1. 核心思想

Attention的思想理解起來比較容易,就是在decoding階段對input中的資訊賦予不同權重。在nlp中就是針對sequence的每個time step input,在cv中就是針對每個pixel。

2. 原理解析

針對Seq2seq翻譯來說,rnn-based model差不多是圖1的樣子:

圖1 傳統rnn-based model

而比較基礎的加入attention與rnn結合的model是下面的樣子(也叫soft attention):

其中 \alpha_{0}^1h_{0}^1 對應的權重,算出所有權重後會進行softmax和加權,得到 c^0

可以看到Encoding和decoding階段仍然是rnn,但是decoding階使用attention的輸出結果 c^0, c^1 作為rnn的輸入。

那麼重點來了, 權重 \alpha 是怎麼來的呢?常見有三種方法:

  • \alpha_{0}^1=cos\_sim(z_0, h_1)
  • \alpha_0 =neural\_network(z_0, h)
  • \alpha_0 = h^TWz_0

思想就是根據當前解碼“狀態”判斷輸入序列的權重分佈。


如果把attention剝離出來去看的話,其實是以下的機制:

輸入是query(Q), key(K), value(V),輸出是attention value。如果與之前的模型對應起來的話,query就是 z_0, z_1 ,key就是 h_1, h_2, h_3, h_4 ,value也是h_1, h_2, h_3, h_4。模型通過Q和K的匹配計算出權重,再結合V得到輸出:

Attention(Q, K, V) = softmax(sim(Q, K))V \\

再深入理解下去,這種機制其實做的是定址(addressing),也就是模仿中央處理器與儲存互動的方式將儲存的內容讀出來,可以看一下李巨集毅老師的課程

3. 模型分類

3.1 Soft/Hard Attention

soft attention:傳統attention,可被嵌入到模型中去進行訓練並傳播梯度

hard attention:不計算所有輸出,依據概率對encoder的輸出取樣,在反向傳播時需採用蒙特卡洛進行梯度估計

3.2 Global/Local Attention

global attention:傳統attention,對所有encoder輸出進行計算

local attention:介於soft和hard之間,會預測一個位置並選取一個視窗進行計算

3.3 Self Attention

傳統attention是計算Q和K之間的依賴關係,而self attention則分別計算Q和K自身的依賴關係。具體的詳解會在下篇文章給出~

4. 優缺點

優點:

  • 在輸出序列與輸入序列“順序”不同的情況下表現較好,如翻譯、閱讀理解
  • 相比RNN可以編碼更長的序列資訊

缺點:

  • 對序列順序不敏感
  • 通常和RNN結合使用,不能並行化

5. TF原始碼解析

發現已經有人解析得很明白了,即使TF程式碼有更新,原理應該還是差不多的,直接放上來吧:

顧秀森:Tensorflow原始碼解讀(一):AttentionSeq2Seq模型zhuanlan.zhihu.com圖示


【參考資料】:

  1. 李巨集毅老師的課程
  2. 知乎:目前主流的attention方法都有哪些?
  3. 模型彙總24 - 深度學習中Attention Mechanism詳細介紹:原理、分類及應用