1. 程式人生 > >論文筆記:Self-critical Sequence Training for Image Captioning

論文筆記:Self-critical Sequence Training for Image Captioning

引言

現在image caption主要存在的問題有:

  1. exposure bias:模型訓練的時候用的是叫“Teacher-Forcing”的方式:輸入RNN的上一時刻的單詞是來自訓練集的ground-truth單詞。而在測試的時候依賴的是自己生成的單詞,一旦生成得不好就會導致誤差的積累,導致後面的單詞也生成得不好。

  2. 模型訓練的時候用的是cross entropy loss,而evaluate的時候卻用的是BLEU、ROUGE、METEOR、CIDEr等metrics,存在不對應的問題。

由於生成單詞的操作是不可微的,無法通過反向傳播來直接優化這些metrics,因此很多工作開始使用強化學習來解決這些問題。

但強化學習在計算期望梯度時的方差會很大,通常來說是不穩定的。又有些研究通過引入一個baseline來進行bias correction。還有一些方法比如Actor-Critic,訓練了一個critic網路來估算生成單詞的value。

而本文的方法則沒有直接去估算reward,而是使用了自己在測試時生成的句子作為baseline。sample時,那些比baseline好的句子就會獲得正的權重,差的句子就會被抑制。具體做法會在後面展開。

Caption Models

本文分別使用了兩個caption model作為基礎,分別是

  1. FC model [3, 4]。就是以cross entropy loss訓練的image caption模型,公式基本跟
    show and tell
    裡的公式一樣。
  2. Attention Model(Att2in)。本文對原模型的結構進行了一些修改,只把attention feature輸入到LSTM的cell node,並且發現使用ADAM方法優化的時候,這種結構的表現優於其他結構。

FC models

公式就不打了,只需要知道LSTM最後輸出的是每個單詞的分數st,再通過softmax得到下一個單詞的概率分佈為wt

訓練目標是最小化cross entropy loss(XE):

L(θ)=t=1Tlog(pθ(w

t|w1,...,wt1))

θ是模型的引數,w1,...,wt1是訓練集中的語句,stpθ在後面Reinforcement Learning部分會被用到。

Attention Model(Att2in)

公式基本跟FC Model的一樣,只不過在cell node的公式里加了個attention項,其他部分以及loss function也跟上面一樣的。

Reinforcement Learning

把序列問題看作是一個RL的問題:

  • Agent: LSTM
  • Environment: words and image features
  • Action: prediction of the next word(模型的引數θ定義了一個policy pθ,也就是上面的pθ,從而導致了這個action)
  • State: cells and hidden states of the LSTM, attenion weights etc
  • Reward: CIDEr score r

訓練目標是最小化負的期望

L(θ)=Ewspθ[r(ws)]

=r(ws)pθ(ws)

ws=(w1s,...,wTs)是生成的句子。

實際上,ws可以依據pθ的概率來進行single sample(而不是選擇概率最大的那一個),L(θ)可以近似為:

L(θ)r(ws),wspθ

L關於θ的梯度為:

θL(θ)=Ewspθ[r(ws)θlogpθ(ws)]

推導過程:

這裡寫圖片描述
再引入一個baseline來減少方差: