1. 程式人生 > >一種基於均值不等式的Listwise損失函式

一種基於均值不等式的Listwise損失函式

## 一種基於均值不等式的Listwise損失函式 ### 1 前言 #### 1.1 Learning to Rank 簡介 Learning to Rank (LTR) , 也被叫做排序學習, 是搜尋中的重要技術, 其目的是根據候選文件和查詢語句的相關性對候選文件進行排序, 或者選取topk文件. 比如在搜尋引擎中, 需要根據使用者問題選取最相關的搜尋結果展示到首頁. 下圖是搜尋引擎的搜尋結果 ![search_result.jpg](https://i.loli.net/2020/10/06/YnmJTPO9BSGhtK5.jpg) #### 1.2 LTR演算法分類 根據損失函式可把LTR分為三種: 1. **Pointwise**, 該型別演算法將LTR任務作為迴歸任務來訓練, 即嘗試訓練一個為文件和查詢語句的打分器, 然後根據打分進行排序. 2. **Pairwise**, 該型別演算法的損失函式考慮了兩個候選文件, 學習目標是把相關性高的文件排在前面, **triplet loss** 就屬於**Pairwise**, 它的損失函式是 $$loss = max(0, score_{neg}-score_{pos}+margin)$$ 可以看出該損失函式一次考慮兩個候選文件. 3. **Listwise**, 該型別演算法的損失函式會考慮多個候選文件, 這是本文的重點, 下面會詳細介紹. #### 1.3 本文主要內容 本文主要介紹了本人在學習研究過程中發明的一種新的**Listwise**損失函式, 以及該損失函式的使用效果. 如果讀者對LTR任務及其演算法還不夠熟悉, 建議先去學習LTR相關知識, 同時本人博文[自然語言處理中的負樣本挖掘 (分類與排序任務中如何選擇負樣本)](https://www.cnblogs.com/infgrad/p/13664315.html) 也和本文關係較大, 可以先進行閱讀. ### 2 預備知識 #### 2.1 數學符號定義 $q$代表使用者搜尋問題, 比如"如何成為宇航員", $D$代表候選文件集合,$d^+$代表和$q$相關的文件,$d^-$代表和$q$不相關的文件, $d^+_i$代表第$i$個和$q$相關的文件, LTR的目標就是根據$q$找到最相關的文件$d$ #### 2.2 學習目標 本次學習目標是訓練一個打分器 scorer, 它可以衡量q和d的相關性, $scorer(q, d)$就是相關性分數,分值越大越相關. 當前主流方法下, scorer一般選用深度神經網路模型. #### 2.3訓練資料分類 損失函式不同, 構造訓練資料的方法也會不同: -**Pointwise**, 可以構造迴歸資料集, 相關的資料設為1, 不相關設為0. -**Pairwise**, 可構造triplet型別的資料集, 形如($q,d^+, d^-$) -**Listwise**, 可構造這種型別的訓練集: ($q,d^+_1,d^+_2..., d^+_n , d^-_1, d^-_2, ..., d^-_{n+m}$), 一個正例還是多個正例也會影響到損失函式的構造, 本文提出的損失函式是針對多正例多負例的情況. ### 3 基於均值不等式的Listwise損失函式 #### 3.1 損失函式推導過程 在上一小結我們可以知道,訓練集是如下形式 ($q,d^+_1,d^+_2..., d^+_n , d^-_1, d^-_2, ..., d^-_{n+m}$), 對於一個q, 有n個相關的文件和m個不相關的文件, 那麼我們一共可以獲取m+n個分值:$(score_1,score_2,...,score_n,...,score_{n+m})$, 我們希望打分器對相關文件打分趨近於正無窮, 對不相關文件打分趨近於負無窮. 對m+n個分值做一個softmax得到$p_1,p_2,...,p_n,...,p_{n+m}$, 此時$p_i$可以看作是第i個候選文件與q相關的概率, 顯然我們希望$p_1,p_2,...,p_n$越大越好, $p_{n+1},...,p_{m+n}$越小越好, 即趨近於0. 因此我們暫時的優化目標是$\sum_{i=1}^{n}{p_i} \rightarrow 1$. 但是這個優化目標是不合理的, 假設$p_1=1$, 其他值全為0, 雖然滿足了上面的要求, 但這並不是我們想要的. 因為我們不僅希望$\sum_{i=1}^{n}{p_i} \rightarrow 1$, 還希望相關候選文件的每一個p值都要足夠大, 即**我們希望: n個候選文件都與q相關的概率是最大的**, 所以我們真正的優化目標是: $$\max(\prod_{i=1}^{n}{p_i} ) , \sum_{i=1}^{n}{p_i} = 1$$ 當前情況下, 損失函式已經可以通過程式碼實現了, 但是我們還可以做一些化簡工作, $\prod_{i=1}^{n}{p_i}$是存在最大值的, 根據均值不等式可得: $$\prod_{i=1}^{n}{p_i} \leq (\frac{\sum_{i=1}^{n}{p_i}}{n})^n$$ 對兩邊取對數: $$\sum_{i=1}^{n}{log(p_i)} \leq -nlog(n)$$ 這樣是不是感覺清爽多了, 然後我們把它轉換成損失函式的形式: $$ loss = -nlog(n) - \sum_{i=1}^{n}{log(p_i)}$$ 所以我們的訓練目標就是$\min{(loss)}$ #### 3.2 使用pytorch實現該損失函式 在獲取到最終的損失函式後, 我們還需要用程式碼來實現, 實現程式碼如下: ```python # A simple example for my listwise loss function # Assuming that n=3, m=4 # In[1] # scores scores = torch.tensor([[3,4.3,5.3,0.5,0.25,0.25,1]]) print(scores) print(scores.shape) ''' tensor([[0.3000, 0.3000, 0.3000, 0.0250, 0.0250, 0.0250, 0.0250]]) torch.Size([1, 7]) ''' # In[2] # log softmax log_prob = torch.nn.functional.log_softmax(scores,dim=1) print(log_prob) ''' tensor([[-2.7073, -1.4073, -0.4073, -5.2073, -5.4573, -5.4573, -4.7073]]) ''' # In[3] # compute loss n = 3. mask = torch.tensor([[1,1,1,0,0,0,0]]) # number of 1 is n loss = -1*n*torch.log(torch.tensor([[n]])) - torch.sum(log_prob*mask,dim=1,keepdim=True) print(loss) loss = loss.mean() print(loss) ''' tensor([[1.2261]]) tensor(1.2261) ''' ``` 該示例程式碼僅展現了batch_size為1的情況, 在batch_size大於1時, 每一條資料都有不同的m和n, 為了能一起送入模型計算分值, 需要靈活的使用mask. 本人在實際使用該損失函式時,一共使用了兩種mask, 分別mask每條資料所有候選文件和每條資料的相關文件, 供大家參考使用. #### 3.3 效果評估和使用經驗 由於評測資料使用的是內部資料, 程式碼和資料都無法公開, 因此只能對使用效果做簡單總結: 1. 效果優於**Pointwise**和**Pairwise**, 但差距不是特別大 2. 相比**Pairwise**收斂速度極快, 訓練一輪基本就可以達到最佳效果 下面是個人使用經驗: 1. 該損失函式比較佔用視訊記憶體, 實際的batch_size是batch_size*(m+n), 建議視訊記憶體在12G以上 2. 負例數量越多,效果越好, 收斂也越快 3. 用pytorch實現log_softmax時, 不要自己實現, 直接使用torch中的log_softmax函式, 它的效率更高些. 4. 只有一個正例, 還可以考慮轉為分類問題,使用交叉熵做優化, 效果同樣較好 ### 4 總結 該損失函式還是比較簡單的, 只需要簡單的數學知識就可以自行推導, 在實際使用中也取得了較好的效果, 希望也能夠幫助到大家. 如果大家有更好的做法歡迎告訴我. **文章可以轉載, 但請註明出處:** - [本人簡書社群主頁](https://www.jianshu.com/u/2609876244a6) - [本人部落格園社群主頁](https://home.cnblogs.com/u/infgrad/) - [本人知乎主頁](https://www.zhihu.com/people/zdd-44-59) - [本人Medium社群主頁](https://medium.com/@dunnzh