NIPS 2018:作為多目標優化的多工學習:尋找帕累托最優解
多工學習本質上是一個多目標問題,因為不同任務之間可能產生衝突,需要對其進行取捨。本文明確將多工學習視為多目標優化問題,以尋求帕累托最優解。而經過實驗證明,本文提出的方法可以在現實假設下得到帕累托最優解。
統計學中最令人震驚的結論之一是 Stein 悖論。Stein(1956)認為,若要估計高斯隨機變數,最好是從所有樣本中估計三個或三個以上變數的均值,而不是分別單獨進行估計,即使這些高斯分佈是相互獨立的。Stein 悖論是探索多工學習(MTL)(Caruana,1997)的早期動機。多工學習是一種學習正規化,其中來自多工的資料被用來獲得優於獨立學習每個任務的效能。MTL 的潛在優勢超出了 Stein 悖論的直接含義,因為即便是真實世界中看似無關的任務也因資料共享的過程而存在很強的依賴性。例如,儘管自動駕駛和目標操縱看似無關,但相同的光學規律、材料屬性以及動力學都對基礎資料產生了影響。這啟發人們在學習系統中使用多工作為歸納偏好。
典型的 MTL 系統被給定一組輸入點和每點各種任務的目標集。設定跨任務的歸納偏好的常用方法是設計一個引數化假設類,它會在不同任務中共享一些引數。一般而言,可以通過為每個任務最小化經驗風險的加權和這種優化問題來學習這些引數。但是,只有當一個引數組在所有任務中都有效時,這樣的線性組合公式才有意義。換言之,只有當任務之間不存在競爭關係時,最小化經驗風險的加權和才有效,但這種情況比較少有。目標衝突的 MTL 需要對任務之間的權衡進行模型,但這已經超出了線性組合能夠實現的範圍。
MTL 的另一個目標是找到不受任何其它方案主導的解決方案。據說這種方案就是帕累托最優(Pareto optimal)。本文從尋找帕累托最優解的角度出發探尋 MTL 的目標。
在給定多個標準的情況下,尋找帕累托最優解的問題也被稱為多目標優化。目前已有多種多目標優化演算法,其中一種叫多梯度下降演算法(MGDA),使用基於梯度的優化,證明了帕累託集合上的點是收斂的(Désidéri,2012)。MGDA 非常時候具有深層網路的多工學習。它可以用每個任務的梯度解決優化問題來更新共享引數。但有兩個技術性的問題阻礙了 MGDA 的大規模應用。(i)基本的優化問題無法擴充套件到高維度梯度,而後者會自然出現在深度網路中。(ii)該演算法要求明確計算每個任務的梯度,這就導致反向迭代的次數會被線性縮放,訓練時間大致會乘以任務數量。
我們在本文中開發了基於 Frank-Wolfe 且可以擴充套件到高維問題的優化器。此外,我們還給 MGDA 優化目標提供了一個上界,並表明可以在沒有明確特定任務梯度的情況下通過單次反向迭代來計算該優化目標,這使得該方法的計算成本小到可以忽略不計。本文證明,用我們的上界可以在現實假設情況下得到帕累托最優解。最終我們得到了一個針對深度網路多目標優化問題的精確演算法,計算開銷可以忽略不計。
我們在三個不同的問題上對提出的方法進行了實證評估。首先,我們在 MultiMNIST(Sabour 等人,2017)上做了多數字分類的延伸評估。其次,我們將多標籤分類作為 MTL,並在 CelebA 資料集(Liu 等人,2015b)上進行了實驗。最後,我們將本文所述方法應用於場景理解問題中。具體而言,我們在 Cityscapes 資料集(Cordts 等人,2016)上做了聯合語義分割、例項分割以及深度估計。在我們的評估中,任務數量從 2 到 40 不等。我們的方法明顯優於所有基線。
論文:Multi-Task Learning as Multi-Objective Optimization
論文連結:https://arxiv.org/pdf/1810.04650v1.pdf
摘要:在多工學習中,要聯合處理多個任務,且這些任務間共享歸納偏好。多工學習本質上是一個多目標問題,因為不同任務之間可能會產生衝突,需要進行取捨。常用的折中方法是優化代理物件(proxy objective),該物件使每個任務損失的加權線性組合最小化。但這種方法只有在任務間不存在競爭關係時才有效,而這種情況是很少發生的。在本文中,我們明確將多工學習視為多目標優化,最終目標是找到帕累托最優解。為此,本文使用了基於梯度的多目標優化文獻中開發的演算法。這些演算法不能直接應用於大規模學習問題中,因為隨著梯度維度和任務數量的增加,演算法結果會變得很差。因此我們為多目標損失提出了一個上界,實驗結果表明這樣可以有效對其進行優化。本文進一步證明,對這個上界進行優化可以在現實假設下得到帕累托最優解。我們將本文提出的方法應用於各種多工深度學習任務中,包括數字分類、場景理解(聯合語義分割、例項分割和深度估計)以及多標籤分類,結果表明該方法產生的模型效能比多工學習公式或單任務訓練產生的模型效能更好。
圖 2:在 CelebA(Liu 等人,2015b)上得到的每個屬性的百分比誤差雷達圖。該值越低越好。為了方便閱讀,將這些屬性分為兩組:左圖容易一些,右圖難一些。可放大檢視細節。
表 1:CelebA(Liu 等人,2015b)上多標籤分類中每類 MTL 演算法的誤差均值。
表 2:MGDA-UB 近似的效果。在使用和不使用該近似的情況下本文所述模型的最終準確率和訓練時間。
圖 3:MultiMNIST準確率概況。我們繪製了在檢測所有基線左右兩邊數字時獲得的準確率。網格搜尋結果表明不同任務在競爭模型容量。我們的方法得到的結果和為每個任務單獨訓練模型的效果一樣好。右上角區域表示結果更好。
表 4:Cityscapes 效能概況。我們繪製了針對語義分割、例項分割和深度估計等任務的所有基線的效能。儘管繪製了視覺化的成對投影,但是圖中的每個點都處理了所有任務。右上角區域表示結果更好。