這份攻略幫你「穩住」反覆無常的 GAN
生成對抗網路 (GAN)是一類非常強大的神經網路,具有非常廣闊的應用前景。GAN 本質上是由兩個相互競爭的神經網路(生成器和判別器)組成的系統。
GAN 的工作流程示意圖。
給定一組目標樣本,生成器會試圖生成一些人造的樣本,這些生成的樣本能夠欺騙判別器將其視為真實的目標樣本,達到「以假亂真」的目的。而判別器則會試圖將真實的(目標)樣本與虛假的(生成)樣本區分開來。通過這樣迴圈往復的訓練方法,我們最終可以得到一個能夠很好地生成與目標樣本相似的樣本的生成器。
由於 GAN 幾乎可以學會模擬出所有型別的資料分佈,它有著非常廣泛的應用場景。通常,GAN 被用來去除圖片中的人為影響、超解析度、姿勢遷移以及任何型別的影象轉換,如下所示:
使用 GAN 完成的影象變換。
然而,由於 GAN 的訓練穩定性反覆無常,使用 GAN 是十分困難的。誠然,許多研究人員已經提出了很好的解決方案來緩解 GAN 訓練中涉及的一些問題。然而,這一領域的研究進展是如此之快,以至於人們很難跟上這些最新的有趣的想法。本文列出了一些常用的使 GAN 訓練穩定的技術。
使用 GAN 的弊端
由於一系列原因,想要使用 GAN 是十分困難的。本節列舉出了其中的一些原因:
1. 模式崩潰
自然的資料分佈是極其複雜的多模態函式(也稱多峰函式)。也就是說,資料分佈有許多「峰」或「模式」。每個模態代表相似的資料樣本聚集在一起,但是與其它的模態並不相同。
在模式崩潰的情況下,生成器會生成從屬於有限模態集集合的樣本。當生成器認為它可以通過生成單一模式的樣本來欺騙鑑別器時,就會發生這種情況。也就是說,生成器只從這種模式生成樣本。
上面一排圖片表示沒有發生模式崩潰的情況下 GAN 輸出的樣本。下面一排圖片表示發生模式崩潰時 GAN 輸出的樣本。
判別器最終會發現這種模式是人為生成的。結果,生成器會直接轉而生成另一種模式。這樣的情況會無限迴圈下去,從本質上限制了生成樣本的多樣性。詳細解釋請參考部落格《Mode collapse in GANs》(http://aiden.nibali.org/blog/2017-01-18-mode-collapse-gans/)
2. 收斂性
在 GAN 的訓練過程中,一個普遍的問題就是「何時停止訓練 GAN 模型?」由於在判別器損失降低的同時生成器的損失會增高(反之亦然),我們並不能基於損失函式的值就來判別 GAN 的收斂性。下圖說明了這一點:
一張典型的 GAN 損失函式示意圖。請注意,此圖無法說明 GAN 的收斂性。
3. 質量
和前面提到的問題一樣,很難定量地判斷生成器何時能生成高質量的樣本。向損失函式中加入額外的感知正則化項可以在一定程度上幫助我們緩解這種情況。
4. 評價標準
GAN 的目標函式說明了生成器(G)與判別器(D)這一對相互博弈的模型相對於其對手的效能,但卻不能代表輸出樣本的質量或多樣性。因此,我們需要能夠在目標函式相同的情況下進行度量的獨特的評價標準。
術語
在我們深入研究可能有助於提升 GAN 模型效能的技術之前,讓我們回顧一些術語。
1. 下確界及上確界
簡而言之,下確界是集合的最大下界,上確界是集合的最小上界,「上確界、下確界」與「最小值、最大值」的區別在於下確界和上確界不一定屬於集合。
2. 散度度量
散度度量代表了兩個分佈之間的距離。傳統的 GAN 本質上是最小化了真實資料分佈和生成的資料分佈之間的 Jensen Shannon 散度( JS 散度 )。GAN 的損失函式可以被改寫為最小化其它的散度度量,例如:Kulback Leibler 散度(KL 散度)或全變分距離。通常,Wasserstein GAN 最小化了推土機距離。
3. Kantorovich Rubenstein 對偶性
我們很難使用一些散度度量的原始形式進行優化。然而,它們的對偶形式(用上確界替換下確界,反之亦然)可能就較為容易優化。對偶原理為將一種形式轉化為另一種形式提供了框架。詳細解釋請參考部落格:《Wasserstein GAN and the Kantorovich-Rubinstein Duality》(https://vincentherrmann.github.io/blog/wasserstein/)
4. LiPSCHITZ 連續性
一個 Lipschitz 連續函式的變化速度是有限的。對具備 Lipschitz 連續性的函式來說,函式曲線上任一點的斜率的絕對值不能超過實數 K。這樣的函式也被稱為 K-Lipschitz 連續函式。
Lipschitz 連續性是 GAN 所期望滿足的,因為它們會限制判別器的梯度,從而從根本上避免了梯度爆炸問題。另外,Kantorovich-Rubinstein 對偶性要求 Wasserstein GAN 也滿足 Lipschitz 連續性,詳細解釋請參考部落格:《Spectral Normalization Explained》(https://christiancosgrove.com/blog/2018/01/04/spectral-normalization-explained.html)。
用於提升模型效能的技術
有許多技巧和技術可以被用來使 GAN 更加穩定和強大。為了保證本文的簡潔性,我僅僅解釋了一些相對來說較新或較為複雜的技術。在本節的末尾,我列舉出了其它的各種各樣的技巧和技術。
1. 替換 損失函式
針對 GAN 存在的的缺點,最流行的修正方法之一是使用「Wasserstein GAN」。它本質上是使用「推土機距離」(Wasserstein-1 距離或 EM 距離)代替傳統 GAN 的 Jensen Shannon 散度。然而,EM 距離的原始形式是難以進行優化的,因此我們使用它的對偶形式(通過 Kantorovich Rubenstein 對偶性計算得出)。這要求判別器滿足「1-Lipschitz」,我們是通過裁剪判別器的權重來保證這一點的。
使用推土機距離的優點是,即使真實的和生成的樣本的資料分佈沒有交集,推土機距離也是「連續的」,這與 JS 或 KL 散度不同。此外,此時生成影象的質量與損失函式值之間存在相關性。而使用推土機距離的缺點是,我們需要在每次更新生成器時更新好幾個判別器(對於原始實現的每次生成器更新也是如此)。此外,作者聲稱,權值裁剪是一種糟糕的確保 1-Lipschitz 約束的方法。
與 Jensen Shannon 散度(如右圖所示)不同,即使資料分佈不是連續的,推土機距離(如左圖所示)也是連續的。詳細的解釋請參閱論文《Wasserstein GAN》(https://arxiv.org/pdf/1701.07875.pdf)
另一種有趣的解決方案是採用均方損失而非對數損失。LSGAN 的作者認為,傳統的 GAN 損失函式並沒有提供足夠的刺激來「拉動」生成的資料分佈逼近真實的資料分佈。
原始 GAN 損失函式中的對數損失並不影響生成資料與決策邊界之間的距離(決策邊界將真實資料和生成的資料分開)。另一方面,LSGAN 對遠離決策邊界的生成樣本進行懲罰,本質上將生成的資料分佈「拉向」實際的資料分佈。它通過使用均方損失替代對數損失來做到這一點。詳細解釋請參考部落格:《 Least Squares GAN 》。
2. 兩個時間尺度上的更新規則(TTUR)
在此方法中,我們為判別器和生成器使用了不同的學習率。通常,生成器使用較慢的更新規則,而判別器使用較快的更新規則。通過使用這種方法,我們只需對學習率進行微調,就可以以 1:1 的比例執行生成器和判別器的更新。值得注意的是,SAGAN 的實現就使用了這個方法。
3. 梯度懲罰
在論文「Improved Training of WGANs」中,作者聲稱權值裁剪(正如在原始的 WGAN 中執行的那樣)導致一些優化問題的產生。作者認為權重裁剪迫使神經網路去學習「較為簡單的近似」從而得到最優的資料分佈,這導致 GAN 得到的最終結果質量變低。他們還聲稱,如果 WGAN 的超引數設定不正確,權重裁剪會導致梯度爆炸或梯度消失的問題。作者在損失函式中引入了一個簡單的梯度懲罰規則,從而緩解了上述問題。除此之外,正如在原始的 WGAN 實現中那樣,這樣做還保證了 1-Lipschitz 連續性。
正如在原始的 WGAN-GP 論文中提到的,將梯度懲罰作為正則化項加入。
DRAGAN 的作者聲稱,當 GAN 中進行的博弈(即判別器和生成器互相進行對抗)達到了「區域性均衡狀態」時,模式崩潰現象就會發生。他們還聲稱,此時由判別器所貢獻的梯度是非常「尖銳的」。使用這樣的梯度懲罰能夠很自然地幫助我們避開這些狀態,大大提高訓練的穩定性,並減少模式崩潰現象的發生。
4. 譜歸一化
譜歸一化是一種通常在判別器中使用的權值歸一化技術,它能夠優化訓練過程(使訓練過程更穩定),從本質上保證了判別器滿足「K-Lipschitz 連續性」。
SAGAN 等實現也在生成器中使用了譜歸一化技術。博文《 Spectral Normalization Explained 》也指出,譜歸一化比梯度懲罰的計算效率更高。
5. 展開和打包
正如博文《Mode collapse in GANs》所描述的,一個阻止模式崩潰發生的方法是在更新引數時預測「對策」。當判別器有機會對生成器的結果做出反應時(考慮到對策,就像 min-max 方法),展開(unrolled)的 GAN 就可以讓生成器騙過判別器。
另一個阻止模式崩潰發生的方式是將屬於同一類的一些樣本「打包」,然後將其傳給判別器。這種方法被 PacGAN 所採用,該論文聲稱它們減少了模式崩潰的發生。
6. 堆疊 GAN
單個的 GAN 可能不夠強大,無法有效地處理某些任務。因此,我們可以使用連續放置的多個 GAN,其中每個 GAN 可以解決一個簡化的問題模組。例如,FashionGAN 使用了兩個 GAN 處理區域性的影象轉換任務。
FashionGAN 使用了兩個 GAN 來執行區域性的影象轉換。
將這種情況推到極致,可以逐步提高 GAN 模型所面臨問題的難度。例如,Progressive GAN(ProGAN)可以生成解析度超高的高質量影象。
7. 相對 GAN
傳統的 GAN 會度量生成資料是真實資料的概率。相對 GAN(Relativistic GAN)則會去度量生成資料比真實資料「更加真實」的概率。正如 RGAN 相關論文《The relativistic discriminator: a key element missing from standard GAN》中提到的那樣,我們可以使用一個合適的距離來度量這種「相對真實性」。
圖 B 為我們使用標準 GAN 損失得到的判別器的輸出。圖 C 為輸出的曲線實際的樣子。圖 A 為 JS 散度的最優解。
作者還提到,當判別器達到最優狀態時,其輸出的概率 D(x)應該收斂到 0.5。然而,傳統的 GAN 訓練演算法會迫使判別器為任何影象輸出「真實」(即概率為 1)的結果。這在某種程度上阻止了判別器的輸出概率達到其最優值。相對 GAN 也解決了這個問題,並且如下圖所示,取得了非常顯著的效果。
在 5000 輪迭代後,標準 GAN 得到的輸出(左圖),以及相對 GAN 得到的輸出(右圖)。
8. 自注意力機制
自注意力 GAN 的作者聲稱,用於生成影象的卷積操作關注的是區域性傳播的資訊。也就是說,由於它們的感受野(restrictive receptive field)有限,它們忽略了在全域性傳播的關係。
將注意力對映(由黃色方框中的網路計算得出)加入到標準的卷積運算中。
自注意力生成對抗網路使影象生成任務能夠進行注意力機制驅動的遠距離依賴建模。自注意力機制是對於常規的卷積運算的補充。全域性資訊(遠距離依賴)有助於生成更高質量的影象。網路可以選擇忽略注意力機制,或將其與常規的卷積運算一同進行考慮。要想更細緻地瞭解自注意力機制,請參閱論文《 Self-Attention Generative Adversarial Networks 》。
9. 其它各種各樣的技術
下面是其它的一些被用來提升 GAN 模型效能的技術(不完全統計!):
-
特徵匹配
-
使用 Mini Batch 技術優化的判別器
-
歷史平均
-
單邊標籤平滑法
-
虛擬批量歸一化
你可以通過論文《Improved Techniques for Training GANs》以及博文《From GAN to WGAN》瞭解更多關於這些技術的資訊。在下面的 GitHub 程式碼倉庫中列舉出了更多的技術:https://github.com/soumith/ganhacks。
評價指標
到目前為止,讀者已經瞭解了提升 GAN 訓練效果的方法,我們需要使用一些指標來量化證明這些方法有效。下面,本文將列舉出一些常用的 GAN 模型的效能評價指標。
1. Inception(GoogleNet)得分
Inception 得分可以度量生成資料有多「真實」。
Inception Score 的計算方法。
上面的方程由兩個部分(p(y|x) 和 p(y))組成。在這裡,x 代表由生成器生成的影象,p(y|x) 是將影象 x 輸入給一個預訓練好的 Inception 網路(正如在原始實現中使用 ImageNet 資料集進行預訓練,https://arxiv.org/pdf/1801.01973.pdf)時得到的概率分佈。同時,p(y) 是邊緣概率分佈,可以通過對生成影象 x 的一些不同的樣本求 p(y|x) 平均值計算得出。這兩項代表了真實影象所需要滿足的兩種特性:
-
生成影象應該包含「有意義」的目標(清晰、不模糊的目標)。這就意味著 p(y|x) 應該具有「較小的熵」。也就是說,我們的 Inception 網路必須非常有把握地確定生成的影象從屬於某個特定的類。
-
生成的影象應該要「多樣」。這就意味著 p(y) 應該有「較大的熵」。換句話說,生成器應該在生成影象時使得每張影象代表不同類的標籤(理想情況下)。
理想狀況下 p(y|x) 和 p(y) 的示意圖。這種情況下,二者的 KL 散度非常大。
如果一個隨機變數是高度可預測的,那麼它的熵就很小(即,p(y) 應該是有一個尖峰的分佈)。相反,如果隨機變數是不可預測的,其熵就應該很大(即 p(y|x) 應該是一個均勻分佈)。如果這兩個特性都得到了滿足,我們應該認為 p(y|x) 和 p(y) 的 KL 散度很大。自然,Inception 得分(IS)越大越好。如果讀者想要了解對 Inception 得分更加深入的分析,請參閱論文《A Note on the Inception Score》(https://arxiv.org/pdf/1801.01973.pdf)。
2. Fréchet Inception 距離(FID)
Inception 得分的一個不足之處在於,並沒有對真實資料和生成資料的統計量(如均值和方差)進行比較。Fréchet 距離通過對比真實影象和生成影象的均值和方差解決了這個問題。Fréchet Inception 距離(FID)執行了與 Inception 得分相同的分析過程,但是它是在通過向預訓練好的 Inception-v3 網路傳入真實的和生成的影象後得到的特徵圖上完成的。FID 的公式如下所示:
FID 得分對比了真實的資料分佈和生成資料分佈的均值和方差。「Tr」代表矩陣的「跡」。
FID 得分越低越好,因為此時它表明生成影象的統計量與真實影象非常接近。
結語
為了克服 GAN 訓練中的種種弊端,研究社群提出了許多解決方案和方法。然而,由於大量湧現的新研究成果,很難跟進所有有意義的新工作。因此,本文分享的細節是不詳盡的,並且可能在不久的將來就會過時。但是,筆者希望本文可以為那些想要提高 GAN 模型效能的人提供一定的指導。