初始 GAN
本文大約 3800 字,閱讀大約需要 8 分鐘
要說最近幾年在深度學習領域最火的莫過於生成對抗網路,即 Generative Adversarial Networks(GANs)了。它是 Ian Goodfellow 在 2014 年發表的,也是這四年來出現的各種 GAN 的變種的開山鼻祖了,下圖表示這四年來有關 GAN 的論文的每個月發表數量,可以看出在 2014 年提出後到 2016 年相關的論文是比較少的,但是從 2016 年,或者是 2017 年到今年這兩年的時間,相關的論文是真的呈現井噴式增長。

那麼,GAN 究竟是什麼呢,它為何會成為這幾年這麼火的一個研究領域呢?
GAN,即生成對抗網路,是 一個生成模型,也是半監督和無監督學習模型,它可以在不需要大量標註資料的情況下學習深度表徵。最大的特點就是提出了一種讓兩個深度網路對抗訓練的方法。
目前機器學習按照資料集是否有標籤可以分為三種,監督學習、半監督學習和無監督學習,發展最成熟,效果最好的目前還是監督學習的方法,但是在資料集數量要求更多更大的情況下,獲取標籤的成本也更加昂貴了,因此越來越多的研究人員都希望能夠在無監督學習方面有更好的發展,而 GAN 的出現,一來它是不太需要很多標註資料,甚至可以不需要標籤,二來它可以做到很多事情,目前對它的應用包括影象合成、影象編輯、風格遷移、影象超解析度以及影象轉換等。
比如圖片的轉換, pix2pixGAN(https://github.com/affinelayer/pix2pix-tensorflow) 就可以做到,其結果如下圖所示,分割圖變成真實照片,從黑白圖變成彩色圖,從線條畫變成富含紋理、陰影和光澤的圖等等,這些都是這個 pix2pixGAN 實現的結果。

CycleGAN(https://github.com/junyanz/CycleGAN) 則可以做到風格遷移,其實現結果如下圖所示,真實照片變成印象畫,普通的馬和斑馬的互換,季節的變換等。

上述是 GAN 的一些應用例子,接下來會簡單介紹 GAN 的原理以及其優缺點,當然也還有為啥等它提出兩年後才開始有越來越多的 GAN 相關的論文發表。
1. 基本原理
GAN 的思想其實非常簡單,就是 生成器網路和判別器網路的彼此博弈。
GAN 主要就是兩個網路組成,生成器網路(Generator)和判別器網路(Discriminator),通過這兩個網路的互相博弈,讓生成器網路最終能夠學習到輸入資料的分佈,這也就是 GAN 想達到的目的-- 學習輸入資料的分佈 。其基本結構如下圖所示,從下圖可以更好理解G 和 D 的功能,分別為:
-
D 是判別器,負責對輸入的真實資料和由 G 生成的假資料進行判斷,其輸出是 0 和 1,即它本質上是一個二值分類器,目標就是對輸入為真實資料輸出是 1,對假資料的輸入,輸出是 0;
-
G 是生成器,它接收的是一個隨機噪聲,並生成影象。
在訓練的過程中,G 的目標是儘可能生成足夠真實的資料去迷惑 D,而 D 就是要將 G 生成的圖片都辨別出來,這樣兩者就是互相博弈,最終是要達到一個平衡,也就是納什均衡。

2. 優點
(以下優點和缺點主要來自 Ian Goodfellow 在 Quora 上的回答,以及知乎上的回答)
-
GAN 模型只用到了反向傳播,而不需要馬爾科夫鏈
-
訓練時不需要對隱變數做推斷
-
理論上,只要是可微分函式都可以用於構建 D 和 G ,因為能夠與深度神經網路結合做深度生成式模型
-
G 的引數更新不是直接來自資料樣本,而是使用來自 D 的反向傳播
-
相比其他生成模型(VAE、玻爾茲曼機),可以生成更好的生成樣本
-
GAN 是一種 半監督學習模型 ,對訓練集不需要太多有標籤的資料;
-
沒有必要遵循任何種類的因子分解去設計模型,所有的生成器和鑑別器都可以正常工作
3. 缺點
-
可解釋性差,生成模型的分佈
Pg(G)
沒有顯式的表達 -
比較難訓練, D 與 G 之間需要很好的同步,例如 D 更新 k 次而 G 更新一次
-
訓練 GAN 需要達到納什均衡,有時候可以用梯度下降法做到,有時候做不到.我們還沒有找到很好的達到納什均衡的方法,所以訓練 GAN 相比 VAE 或者 PixelRNN 是不穩定的,但我認為在實踐中它還是比訓練玻爾茲曼機穩定的多.
-
它很難去學習生成離散的資料,就像文字
-
相比玻爾茲曼機,GANs 很難根據一個畫素值去猜測另外一個畫素值,GANs 天生就是做一件事的,那就是一次產生所有畫素,你可以用 BiGAN 來修正這個特性,它能讓你像使用玻爾茲曼機一樣去使用 Gibbs 取樣來猜測缺失值
-
訓練不穩定, G 和 D 很難收斂
-
訓練還會遭遇 梯度消失、模式崩潰 的問題
-
缺乏比較有效的直接可觀的 評估模型生成效果的方法
3.1 為什麼訓練會出現梯度消失和模式奔潰
GAN 的本質就是 G 和 D 互相博弈並最終達到一個納什平衡點,但這只是一個理想的情況,正常情況是容易出現一方強大另一方弱小,並且一旦這個關係形成,而沒有及時找到方法平衡,那麼就會出現問題了。而梯度消失和模式奔潰其實就是這種情況下的兩個結果,分別對應 D 和 G 是強大的一方的結果。
首先對於梯度消失的情況是 D 越好,G 的梯度消失越嚴重 ,因為 G 的梯度更新來自 D,而在訓練初始階段,G 的輸入是隨機生成的噪聲,肯定不會生成很好的圖片,D 會很容易就判斷出來真假樣本,也就是 D 的訓練幾乎沒有損失,也就沒有有效的梯度資訊回傳給 G 讓 G 去優化自己。這樣的現象叫做 gradient vanishing,梯度消失問題。
其次,對於模式奔潰(mode collapse)問題,主要就是 G 比較強,導致 D 不能很好區分出真實圖片和 G 生成的假圖片,而如果此時 G 其實還不能完全生成足夠真實的圖片的時候,但 D 卻分辨不出來,並且給出了正確的評價,那麼 G 就會認為這張圖片是正確的,接下來就繼續這麼輸出這張或者這些圖片,然後 D 還是給出正確的評價,於是兩者就是這麼相互欺騙,這樣 G 其實就只會輸出固定的一些圖片,導致的結果除了生成圖片不夠真實,還有就是多樣性不足的問題。
更詳細的解釋可以參考 令人拍案叫絕的Wasserstein GAN(https://zhuanlan.zhihu.com/p/25071913),這篇文章更詳細解釋了原始 GAN 的問題,主要就是出現在 loss 函式上。
3.2 為什麼GAN不適合處理文字資料
-
文字資料相比較圖片資料來說是離散的,因為對於文字來說,通常需要將一個詞對映為一個高維的向量,最終預測的輸出是一個one-hot向量,假設 softmax 的輸出是
(0.2, 0.3, 0.1,0.2,0.15,0.05)
,那麼變為 onehot是(0,1,0,0,0,0),如果softmax輸出是(0.2, 0.25, 0.2, 0.1,0.15,0.1 ),one-hot 仍然是(0, 1, 0, 0, 0, 0)
,所以對於生成器來說,G 輸出了不同的結果, 但是 D 給出了同樣的判別結果,並不能將梯度更新資訊很好的傳遞到 G 中去,所以 D 最終輸出的判別沒有意義。 -
GAN 的損失函式是 JS 散度,JS 散度不適合衡量不想交分佈之間的距離。(WGAN 雖然使用 wassertein 距離代替了 JS 散度,但是在生成文字上能力還是有限,GAN 在生成文字上的應用有 seq-GAN,和強化學習結合的產物)
3.3 為什麼GAN中的優化器不常用SGD
-
SGD 容易震盪,容易使 GAN 的訓練更加不穩定,
-
GAN 的目的是在高維非凸的引數空間中找到 納什均衡點 ,GAN 的納什均衡點是一個 鞍點 ,但是 SGD 只會找到 區域性極小值 ,因為 SGD 解決的是一個尋找最小值的問題,但 GAN 是一個博弈問題。
對於鞍點,來自百度百科的解釋是:
鞍點(Saddle point)在微分方程中,沿著某一方向是穩定的,另一條方向是不穩定的奇點,叫做鞍點。在泛函中,既不是極大值點也不是極小值點的臨界點,叫做鞍點。在矩陣中,一個數在所在行中是最大值,在所在列中是最小值,則被稱為鞍點。在物理上要廣泛一些,指在一個方向是極大值,另一個方向是極小值的點。
鞍點和區域性極小值點、區域性極大值點的區別如下圖所示:

4. 訓練的技巧
訓練的技巧主要來自 Tips and tricks to make GANs work--https://github.com/soumith/ganhacks。
1. 對輸入進行規範化
-
將輸入規範化到 -1 和 1 之間
-
G 的輸出層採用
Tanh
啟用函式
2. 採用修正的損失函式
在原始 GAN 論文中,損失函式 G 是要 , 但實際使用的時候是採用
Line"/> ,作者給出的原因是前者會導致梯度消失問題。
但實際上,即便是作者提出的這種實際應用的損失函式也是存在問題,即模式奔潰的問題,在接下來提出的 GAN 相關的論文中,就有不少論文是針對這個問題進行改進的,如 WGAN 模型就提出一種新的損失函式。
3. 從球體上取樣噪聲

-
更多細節可以參考 Tom White's 的論文 Sampling Generative Networks(https://arxiv.org/abs/1609.04468) 以及程式碼 https://github.com/dribnet/plat
4. BatchNorm
-
採用 mini-batch BatchNorm,要保證每個 mini-batch 都是同樣的真實圖片或者是生成圖片
-
不採用 BatchNorm 的時候,可以採用 instance normalization(對每個樣本的規範化操作)
-
可以使用 虛擬批量歸一化 (virtural batch normalization):開始訓練之前預定義一個 batch R,對每一個新的 batch X,都使用 R+X 的級聯來計算歸一化引數
5. 避免稀疏的梯度:Relus、MaxPool
-
稀疏梯度會影響 GAN 的穩定性
-
在 G 和 D 中採用 LeakyReLU 代替 Relu 啟用函式
-
對於下采樣操作,可以採用平均池化(Average Pooling) 和 Conv2d+stride 的替代方案
-
對於上取樣操作,可以使用 PixelShuffle(https://arxiv.org/abs/1609.05158), ConvTranspose2d + stride
6. 標籤的使用
-
標籤平滑。也就是如果有兩個目標標籤,假設真實圖片標籤是 1,生成圖片標籤是 0,那麼對每個輸入例子,如果是真實圖片,採用 0.7 到 1.2 之間的一個隨機數字來作為標籤,而不是 1;一般是採用單邊標籤平滑
-
在訓練 D 的時候,偶爾翻轉標籤
-
有標籤資料就儘量使用標籤
7. 使用 Adam 優化器
8. 儘早追蹤失敗的原因
-
D 的 loss 變成 0,那麼這就是訓練失敗了
-
檢查規範的梯度:如果超過 100,那出問題了
-
如果訓練正常,那麼 D loss 有低方差並且隨著時間降低
-
如果 g loss 穩定下降,那麼它是用糟糕的生成樣本欺騙了 D
9. 不要通過統計學來平衡 loss
10. 給輸入新增噪聲
-
給 D 的輸入新增人為的噪聲
-
http://www.inference.vc/instance-noise-a-trick-for-stabilising-gan-training/
-
https://openreview.net/forum?id=Hk4_qw5xe
-
給 G 的每層都新增高斯噪聲
11. 對於 Conditional GANs 的離散變數
-
使用一個 Embedding 層
-
對輸入圖片新增一個額外的通道
-
保持 embedding 低維並通過上取樣操作來匹配影象的通道大小
12 在 G 的訓練和測試階段使用 Dropouts
-
以 dropout 的形式提供噪聲(50%的概率)
-
訓練和測試階段,在 G 的幾層使用
-
https://arxiv.org/pdf/1611.07004v1.pdf
參考文章:
-
Goodfellow et al., “Generative Adversarial Networks”. ICLR 2014.--https://arxiv.org/abs/1406.2661
-
ofollow,noindex" target="_blank">GAN系列學習(1)——前生今世
-
OA==&mid=2652692740&idx=1&sn=f1b134f63eb0bf5e4d6759db4d740e58&scene=21#wechat_redirect" rel="nofollow,noindex" target="_blank">乾貨 | 深入淺出 GAN·原理篇文字版(完整)
-
令人拍案叫絕的Wasserstein GAN--https://zhuanlan.zhihu.com/p/25071913
-
生成對抗網路(GAN)相比傳統訓練方法有什麼優勢?--https://www.zhihu.com/question/56171002/answer/148593584
-
https://github.com/hindupuravinash/the-gan-zoo
-
What-is-the-advantage-of-generative-adversarial-networks-compared-with-other-generative-models--https://www.quora.com/What-is-the-advantage-of-generative-adversarial-networks-compared-with-other-generative-models
-
What-are-the-pros-and-cons-of-using-generative-adversarial-networks-a-type-of-neural-network-Could-they-be-applied-to-things-like-audio-waveform-via-RNN-Why-or-why-not--https://www.quora.com/What-are-the-pros-and-cons-of-using-generative-adversarial-networks-a-type-of-neural-network-Could-they-be-applied-to-things-like-audio-waveform-via-RNN-Why-or-why-not
-
https://github.com/soumith/ganhacks
注:配圖來自網路和參考文章
以上就是本文的主要內容和總結,因為我還沒有開通留言功能,另外公眾號不能新增外鏈,可以點選左下角原文檢視可以點選連結的文章,並且還可以留言給出你對本文的建議和看法。
同時也歡迎關注我的微信公眾號--機器學習與計算機視覺或者掃描下方的二維碼,和我分享你的建議和看法,指正文章中可能存在的錯誤,大家一起交流,學習和進步!
