1. 程式人生 > >深度學習----GAN(生成對抗神經網路)原理解析

深度學習----GAN(生成對抗神經網路)原理解析

一、原理部分

首先附上一張流程圖
這裡寫圖片描述

1.1、 GAN的原理:

GAN的主要靈感來源於博弈論中零和博弈的思想,應用到深度學習神經網路上來說,就是通過生成網路G(Generator)和判別網路D(Discriminator)不斷博弈,進而使G學習到資料的分佈,如果用到圖片生成上,則訓練完成後,G可以從一段隨機數中生成逼真的影象。G, D的主要功能是:

● G是一個生成式的網路,它接收一個隨機的噪聲z(隨機數),通過這個噪聲生成影象

● D是一個判別網路,判別一張圖片是不是“真實的”。它的輸入引數是x,x代表一張圖片,輸出D(x)代表x為真實圖片的概率,如果為1,就代表100%是真實的圖片,而輸出為0,就代表不可能是真實的圖片
這裡寫圖片描述

訓練過程中,生成網路G的目標就是儘量生成真實的圖片去欺騙判別網路D。而D的目標就是儘量辨別出G生成的假影象和真實的影象。這樣,G和D構成了一個動態的“博弈過程”,最終的平衡點即納什均衡點.

1.2、架構

這裡寫圖片描述

通過優化目標,使得我們可以調節概率生成模型的引數\theta,從而使得生成的概率分佈和真實資料分佈儘量接近。

那麼怎麼去定義一個恰當的優化目標或一個損失?傳統的生成模型,一般都採用資料的似然性來作為優化的目標,但GAN創新性地使用了另外一種優化目標。首先,它引入了一個判別模型(常用的有支援向量機和多層神經網路)。其次,它的優化過程就是在尋找生成模型和判別模型之間的一個納什均衡。

GAN所建立的一個學習框架,實際上就是生成模型和判別模型之間的一個模仿遊戲。生成模型的目的,就是要儘量去模仿、建模和學習真實資料的分佈規律;而判別模型則是要判別自己所得到的一個輸入資料,究竟是來自於真實的資料分佈還是來自於一個生成模型。通過這兩個內部模型之間不斷的競爭,從而提高兩個模型的生成能力和判別能力。

這裡寫圖片描述

當一個判別模型的能力已經非常強的時候,如果生成模型所生成的資料,還是能夠使它產生混淆,無法正確判斷的話,那我們就認為這個生成模型實際上已經學到了真實資料的分佈。

1.3、 GAN 的特點及優缺點:

特點

● 相比較傳統的模型,他存在兩個不同的網路,而不是單一的網路,並且訓練方式採用的是對抗訓練方式

● GAN中G的梯度更新資訊來自判別器D,而不是來自資料樣本

優點

(以下部分摘自ian goodfellow 在Quora的問答)

● GAN是一種生成式模型,相比較其他生成模型(玻爾茲曼機和GSNs)只用到了反向傳播,而不需要複雜的馬爾科夫鏈

● 相比其他所有模型, GAN可以產生更加清晰,真實的樣本

● GAN採用的是一種無監督的學習方式訓練,可以被廣泛用在無監督學習和半監督學習領域

● 相比於變分自編碼器, GANs沒有引入任何決定性偏置( deterministic bias),變分方法引入決定性偏置,因為他們優化對數似然的下界,而不是似然度本身,這看起來導致了VAEs生成的例項比GANs更模糊

● 相比VAE, GANs沒有變分下界,如果鑑別器訓練良好,那麼生成器可以完美的學習到訓練樣本的分佈.換句話說,GANs是漸進一致的,但是VAE是有偏差的

● GAN應用到一些場景上,比如圖片風格遷移,超解析度,影象補全,去噪,避免了損失函式設計的困難,不管三七二十一,只要有一個的基準,直接上判別器,剩下的就交給對抗訓練了。

缺點

● 訓練GAN需要達到納什均衡,有時候可以用梯度下降法做到,有時候做不到.我們還沒有找到很好的達到納什均衡的方法,所以訓練GAN相比VAE或者PixelRNN是不穩定的,但我認為在實踐中它還是比訓練玻爾茲曼機穩定的多

● GAN不適合處理離散形式的資料,比如文字

GAN存在訓練不穩定、梯度消失、模式崩潰的問題(目前已解決)

附:模式崩潰(model collapse)原因

一般出現在GAN訓練不穩定的時候,具體表現為生成出來的結果非常差,但是即使加長訓練時間後也無法得到很好的改善。

具體原因可以解釋如下:
GAN採用的是對抗訓練的方式,G的梯度更新來自D,所以G生成的好不好,得看D怎麼說。具體就是G生成一個樣本,交給D去評判,D會輸出生成的假樣本是真樣本的概率(0-1),相當於告訴G生成的樣本有多大的真實性,G就會根據這個反饋不斷改善自己,提高D輸出的概率值。但是如果某一次G生成的樣本可能並不是很真實,但是D給出了正確的評價,或者是G生成的結果中一些特徵得到了D的認可,這時候G就會認為我輸出的正確的,那麼接下來我就這樣輸出肯定D還會給出比較高的評價,實際上G生成的並不怎麼樣,但是他們兩個就這樣自我欺騙下去了,導致最終生成結果缺失一些資訊,特徵不全。
這裡寫圖片描述

區域性極小值點

這裡寫圖片描述

鞍點

二、為什麼GAN中的優化器不常用SGD

  1. SGD容易震盪,容易使GAN訓練不穩定,

  2. GAN的目的是在高維非凸的引數空間中找到納什均衡點,GAN的納什均衡點是一個鞍點,但是SGD只會找到區域性極小值,因為SGD解決的是一個尋找最小值的問題,GAN是一個博弈問題。

三、為什麼GAN不適合處理文字資料

  1. 文字資料相比較圖片資料來說是離散的,因為對於文字來說,通常需要將一個詞對映為一個高維的向量,最終預測的輸出是一個one-hot向量,假設softmax的輸出是(0.2, 0.3, 0.1,0.2,0.15,0.05)那麼變為onehot是(0,1,0,0,0,0),如果softmax輸出是(0.2, 0.25, 0.2, 0.1,0.15,0.1 ),one-hot仍然是(0, 1, 0, 0, 0, 0),所以對於生成器來說,G輸出了不同的結果但是D給出了同樣的判別結果,並不能將梯度更新資訊很好的傳遞到G中去,所以D最終輸出的判別沒有意義。

  2. 另外就是GAN的損失函式是JS散度,JS散度不適合衡量不想交分佈之間的距離。

(WGAN雖然使用wassertein距離代替了JS散度,但是在生成文字上能力還是有限,GAN在生成文字上的應用有seq-GAN,和強化學習結合的產物)

四、訓練GAN的一些技巧

  1. 輸入規範化到(-1,1)之間,最後一層的啟用函式使用tanh(BEGAN除外)

  2. 使用wassertein GAN的損失函式,

  3. 如果有標籤資料的話,儘量使用標籤,也有人提出使用反轉標籤效果很好,另外使用標籤平滑,單邊標籤平滑或者雙邊標籤平滑

  4. 使用mini-batch norm, 如果不用batch norm 可以使用instance norm 或者weight norm

  5. 避免使用RELU和pooling層,減少稀疏梯度的可能性,可以使用leakrelu啟用函式

  6. 優化器儘量選擇ADAM,學習率不要設定太大,初始1e-4可以參考,另外可以隨著訓練進行不斷縮小學習率,

  7. 給D的網路層增加高斯噪聲,相當於是一種正則

GAN的變種

自從GAN出世後,得到了廣泛研究,先後幾百篇不同的GANpaper橫空出世,國外有大神整理了一個GAN zoo(GAN動物園),連結如下,感興趣的可以參考一下:

GitHub上已經1200+star了,順便附上一張GAN的成果圖,可見GAN的研究火熱程度:

五、GAN的廣泛應用

  1. GAN本身是一種生成式模型,所以在資料生成上用的是最普遍的,最常見的是圖片生成,常用的有DCGAN WGAN,BEGAN,個人感覺在BEGAN的效果最好而且最簡單。

  2. GAN本身也是一種無監督學習的典範,因此它在無監督學習,半監督學習領域都有廣泛的應用

  3. 不僅在生成領域,GAN在分類領域也佔有一席之地,簡單來說,就是替換判別器為一個分類器,做多分類任務,而生成器仍然做生成任務,輔助分類器訓練。

  4. GAN可以和強化學習結合,目前一個比較好的例子就是seq-GAN

  5. 目前比較有意思的應用就是GAN用在影象風格遷移,影象降噪修復,影象超解析度了,都有比較好的結果

  6. 目前也有研究者將GAN用在對抗性攻擊上,具體就是訓練GAN生成對抗文字,有針對或者無針對的欺騙分類器或者檢測系統等等,但是目前沒有見到很典範的文章。

參考文獻: