論文筆記:Self-critical Sequence Training for Image Captioning
引言
現在image caption主要存在的問題有:
exposure bias:模型訓練的時候用的是叫“Teacher-Forcing”的方式:輸入RNN的上一時刻的單詞是來自訓練集的ground-truth單詞。而在測試的時候依賴的是自己生成的單詞,一旦生成得不好就會導致誤差的積累,導致後面的單詞也生成得不好。
模型訓練的時候用的是cross entropy loss,而evaluate的時候卻用的是BLEU、ROUGE、METEOR、CIDEr等metrics,存在不對應的問題。
由於生成單詞的操作是不可微的,無法通過反向傳播來直接優化這些metrics,因此很多工作開始使用強化學習來解決這些問題。
但強化學習在計算期望梯度時的方差會很大,通常來說是不穩定的。又有些研究通過引入一個baseline來進行bias correction。還有一些方法比如Actor-Critic,訓練了一個critic網路來估算生成單詞的value。
而本文的方法則沒有直接去估算reward,而是使用了自己在測試時生成的句子作為baseline。sample時,那些比baseline好的句子就會獲得正的權重,差的句子就會被抑制。具體做法會在後面展開。
Caption Models
本文分別使用了兩個caption model作為基礎,分別是
- FC model [3, 4]。就是以cross entropy loss訓練的image caption模型,公式基本跟
- Attention Model(Att2in)。本文對原模型的結構進行了一些修改,只把attention feature輸入到LSTM的cell node,並且發現使用ADAM方法優化的時候,這種結構的表現優於其他結構。
FC models
公式就不打了,只需要知道LSTM最後輸出的是每個單詞的分數,再通過softmax得到下一個單詞的概率分佈為。
訓練目標是最小化cross entropy loss(XE):
是模型的引數,是訓練集中的語句,、在後面Reinforcement Learning部分會被用到。
Attention Model(Att2in)
公式基本跟FC Model的一樣,只不過在cell node的公式里加了個attention項,其他部分以及loss function也跟上面一樣的。
Reinforcement Learning
把序列問題看作是一個RL的問題:
- Agent: LSTM
- Environment: words and image features
- Action: prediction of the next word(模型的引數定義了一個policy ,也就是上面的,從而導致了這個action)
- State: cells and hidden states of the LSTM, attenion weights etc
- Reward: CIDEr score r
訓練目標是最小化負的期望
是生成的句子。
實際上,可以依據的概率來進行single sample(而不是選擇概率最大的那一個),可以近似為:
L關於的梯度為:
推導過程:
再引入一個baseline來減少方差: