一種基於均值不等式的Listwise損失函式
阿新 • • 發佈:2020-10-06
## 一種基於均值不等式的Listwise損失函式
### 1 前言
#### 1.1 Learning to Rank 簡介
Learning to Rank (LTR) , 也被叫做排序學習, 是搜尋中的重要技術, 其目的是根據候選文件和查詢語句的相關性對候選文件進行排序, 或者選取topk文件. 比如在搜尋引擎中, 需要根據使用者問題選取最相關的搜尋結果展示到首頁. 下圖是搜尋引擎的搜尋結果
![search_result.jpg](https://i.loli.net/2020/10/06/YnmJTPO9BSGhtK5.jpg)
#### 1.2 LTR演算法分類
根據損失函式可把LTR分為三種:
1. **Pointwise**, 該型別演算法將LTR任務作為迴歸任務來訓練, 即嘗試訓練一個為文件和查詢語句的打分器, 然後根據打分進行排序.
2. **Pairwise**, 該型別演算法的損失函式考慮了兩個候選文件, 學習目標是把相關性高的文件排在前面, **triplet loss** 就屬於**Pairwise**, 它的損失函式是
$$loss = max(0, score_{neg}-score_{pos}+margin)$$
可以看出該損失函式一次考慮兩個候選文件.
3. **Listwise**, 該型別演算法的損失函式會考慮多個候選文件, 這是本文的重點, 下面會詳細介紹.
#### 1.3 本文主要內容
本文主要介紹了本人在學習研究過程中發明的一種新的**Listwise**損失函式, 以及該損失函式的使用效果. 如果讀者對LTR任務及其演算法還不夠熟悉, 建議先去學習LTR相關知識, 同時本人博文[自然語言處理中的負樣本挖掘 (分類與排序任務中如何選擇負樣本)](https://www.cnblogs.com/infgrad/p/13664315.html) 也和本文關係較大, 可以先進行閱讀.
### 2 預備知識
#### 2.1 數學符號定義
$q$代表使用者搜尋問題, 比如"如何成為宇航員", $D$代表候選文件集合,$d^+$代表和$q$相關的文件,$d^-$代表和$q$不相關的文件, $d^+_i$代表第$i$個和$q$相關的文件, LTR的目標就是根據$q$找到最相關的文件$d$
#### 2.2 學習目標
本次學習目標是訓練一個打分器 scorer, 它可以衡量q和d的相關性, $scorer(q, d)$就是相關性分數,分值越大越相關. 當前主流方法下, scorer一般選用深度神經網路模型.
#### 2.3訓練資料分類
損失函式不同, 構造訓練資料的方法也會不同:
-**Pointwise**, 可以構造迴歸資料集, 相關的資料設為1, 不相關設為0.
-**Pairwise**, 可構造triplet型別的資料集, 形如($q,d^+, d^-$)
-**Listwise**, 可構造這種型別的訓練集: ($q,d^+_1,d^+_2..., d^+_n , d^-_1, d^-_2, ..., d^-_{n+m}$), 一個正例還是多個正例也會影響到損失函式的構造, 本文提出的損失函式是針對多正例多負例的情況.
### 3 基於均值不等式的Listwise損失函式
#### 3.1 損失函式推導過程
在上一小結我們可以知道,訓練集是如下形式 ($q,d^+_1,d^+_2..., d^+_n , d^-_1, d^-_2, ..., d^-_{n+m}$), 對於一個q, 有n個相關的文件和m個不相關的文件, 那麼我們一共可以獲取m+n個分值:$(score_1,score_2,...,score_n,...,score_{n+m})$, 我們希望打分器對相關文件打分趨近於正無窮, 對不相關文件打分趨近於負無窮.
對m+n個分值做一個softmax得到$p_1,p_2,...,p_n,...,p_{n+m}$, 此時$p_i$可以看作是第i個候選文件與q相關的概率, 顯然我們希望$p_1,p_2,...,p_n$越大越好, $p_{n+1},...,p_{m+n}$越小越好, 即趨近於0. 因此我們暫時的優化目標是$\sum_{i=1}^{n}{p_i} \rightarrow 1$.
但是這個優化目標是不合理的, 假設$p_1=1$, 其他值全為0, 雖然滿足了上面的要求, 但這並不是我們想要的. 因為我們不僅希望$\sum_{i=1}^{n}{p_i} \rightarrow 1$, 還希望相關候選文件的每一個p值都要足夠大, 即**我們希望: n個候選文件都與q相關的概率是最大的**, 所以我們真正的優化目標是:
$$\max(\prod_{i=1}^{n}{p_i} ) , \sum_{i=1}^{n}{p_i} = 1$$
當前情況下, 損失函式已經可以通過程式碼實現了, 但是我們還可以做一些化簡工作, $\prod_{i=1}^{n}{p_i}$是存在最大值的, 根據均值不等式可得:
$$\prod_{i=1}^{n}{p_i} \leq (\frac{\sum_{i=1}^{n}{p_i}}{n})^n$$
對兩邊取對數:
$$\sum_{i=1}^{n}{log(p_i)} \leq -nlog(n)$$
這樣是不是感覺清爽多了, 然後我們把它轉換成損失函式的形式:
$$ loss = -nlog(n) - \sum_{i=1}^{n}{log(p_i)}$$
所以我們的訓練目標就是$\min{(loss)}$
#### 3.2 使用pytorch實現該損失函式
在獲取到最終的損失函式後, 我們還需要用程式碼來實現, 實現程式碼如下:
```python
# A simple example for my listwise loss function
# Assuming that n=3, m=4
# In[1]
# scores
scores = torch.tensor([[3,4.3,5.3,0.5,0.25,0.25,1]])
print(scores)
print(scores.shape)
'''
tensor([[0.3000, 0.3000, 0.3000, 0.0250, 0.0250, 0.0250, 0.0250]])
torch.Size([1, 7])
'''
# In[2]
# log softmax
log_prob = torch.nn.functional.log_softmax(scores,dim=1)
print(log_prob)
'''
tensor([[-2.7073, -1.4073, -0.4073, -5.2073, -5.4573, -5.4573, -4.7073]])
'''
# In[3]
# compute loss
n = 3.
mask = torch.tensor([[1,1,1,0,0,0,0]]) # number of 1 is n
loss = -1*n*torch.log(torch.tensor([[n]])) - torch.sum(log_prob*mask,dim=1,keepdim=True)
print(loss)
loss = loss.mean()
print(loss)
'''
tensor([[1.2261]])
tensor(1.2261)
'''
```
該示例程式碼僅展現了batch_size為1的情況, 在batch_size大於1時, 每一條資料都有不同的m和n, 為了能一起送入模型計算分值, 需要靈活的使用mask. 本人在實際使用該損失函式時,一共使用了兩種mask, 分別mask每條資料所有候選文件和每條資料的相關文件, 供大家參考使用.
#### 3.3 效果評估和使用經驗
由於評測資料使用的是內部資料, 程式碼和資料都無法公開, 因此只能對使用效果做簡單總結:
1. 效果優於**Pointwise**和**Pairwise**, 但差距不是特別大
2. 相比**Pairwise**收斂速度極快, 訓練一輪基本就可以達到最佳效果
下面是個人使用經驗:
1. 該損失函式比較佔用視訊記憶體, 實際的batch_size是batch_size*(m+n), 建議視訊記憶體在12G以上
2. 負例數量越多,效果越好, 收斂也越快
3. 用pytorch實現log_softmax時, 不要自己實現, 直接使用torch中的log_softmax函式, 它的效率更高些.
4. 只有一個正例, 還可以考慮轉為分類問題,使用交叉熵做優化, 效果同樣較好
### 4 總結
該損失函式還是比較簡單的, 只需要簡單的數學知識就可以自行推導, 在實際使用中也取得了較好的效果, 希望也能夠幫助到大家. 如果大家有更好的做法歡迎告訴我.
**文章可以轉載, 但請註明出處:**
- [本人簡書社群主頁](https://www.jianshu.com/u/2609876244a6)
- [本人部落格園社群主頁](https://home.cnblogs.com/u/infgrad/)
- [本人知乎主頁](https://www.zhihu.com/people/zdd-44-59)
- [本人Medium社群主頁](https://medium.com/@dunnzh