1. 程式人生 > >【神經網路】GAN原理總結,CatGAN

【神經網路】GAN原理總結,CatGAN

定義及原理:    

       生成器 (G)generator:接收一個隨機的噪聲z(隨機數),通過這個噪聲生成影象。G的目標就是儘量生成真實的圖片去欺騙判別網路D。

       判別器(D) discriminator:對接收的圖片進行真假判別。它的輸入引數是x,x代表一張圖片,輸出D(x)代表x為真實圖片的概率,如果為1,就代表100%是真實的圖片,而輸出為0,就代表不可能是真實的圖片。D的目標就是儘量辨別出G生成的假影象和真實的影象。

       GAN的主要靈感來源於博弈論中零和博弈的思想,應用到深度學習神經網路上來說,就是通過G和D不斷博弈,進而使G學習到資料的分佈,如果用到圖片生成上,則訓練完成後,G可以從一段隨機數中生成逼真的影象。

      訓練過程中,G和D構成了一個動態的“博弈過程”,最終的平衡點即納什均衡點:生成器生成的影象接近於真實影象分佈,而判別器識別不出真假影象,對於給定影象的預測為真的概率基本接近 0.5(相當於隨機猜測類別)

過程

  1. 第一代的Generator,然後他產生一些圖片
  2. 訓練產生第一代discriminator,能夠區分人工產生的和真實的圖片
  3. 訓練第二代Generator,使其產生的圖片騙過第一代discriminator
  4. 以此類推。。。

優點

  1. 只用到了反向傳播
  2. 相比其他所有模型, GAN可以產生更加清晰,真實的樣本
  3. GAN應用到一些場景上,比如圖片風格遷移,超解析度,影象補全,去噪,避免了損失函式設計的困難,不管三七二十一,只要有一個的基準,直接上判別器,剩下的就交給對抗訓練了

缺點

  1. 訓練GAN需要達到納什均衡,有時候可以用梯度下降法做到,有時候做不到.我們還沒有找到很好的達到納什均衡的方法,所以訓練GAN相比VAE或者PixelRNN是不穩定的,但我認為在實踐中它還是比訓練玻爾茲曼機穩定的多
  2. GAN不適合處理離散形式的資料,比如文字
  3. GAN存在訓練不穩定、梯度消失、模式崩潰的問題(目前已解決)

應用

  1. 圖片生成
  2. 替換判別器為一個分類器,做多分類任務,而生成器仍然做生成任務,輔助分類器訓練
  3. 和強化學習結合,目前一個比較好的例子就是seq-GAN

CatGAN

無監督的分類會被轉化為一個聚類問題,通常是以某種距離作為度量準則,從而將資料劃分為多個類別,而本文則是採用資料的熵來作為衡量標準構建來CatGAN (ICLR-2016) 。具體來說,對於真實的資料

,模型希望判別器不僅能具有較大的確信度將其劃分為真實樣本,同時還有較大的確信度將資料劃分到某一個現有的類別中去;而對於生成資料卻不是十分確定要將其劃分到哪一個現有的類別,也就是這個不確信度比較大,從而生成器的目標即為產生出那些“將其劃分到某一類別中去”的確信度較高的樣本,嘗試騙過判別器。接下來,為了衡量這個確信程度,作者用熵來表示,熵值越大,即為越不確定;而熵值越小,則表示越確定。然後,將該確信度目標與原始GAN的真偽鑑別的優化目標結合,即得到了CatGAN的最終優化目標。

對於半監督的情況,對有標籤資料計算交叉熵損失,而對無標籤資料計算上面的基於熵的損失,然後在原來的目標函式的基礎上進行疊加即得,當用該半監督方法進行目標識別與分類時,其效果雖然相對較優,但相對當下state-of-the-art的方法並沒有比較明顯的提升。但其基於熵損失的無監督訓練方法卻表現較好,其實驗效果如下圖所示,可以看到,對於如下的典型環形資料,CatGAN可以較好地找到兩者的分類面,實現無監督聚類的功能。

GAN of Salimans et al. (2016)

GAN網路使用梯度下降的方法只會找到低的損失,不能找到真正的納什均衡。本論文中,作者通過引入了一些方法,提高網路的收斂。

原始的GAN網路的目標函式需要最大化判別網路的輸出。作者提出了新的目標函式,motivation就是讓生成網路產生的圖片,經過判別網路後的中間層的feature 和真實圖片經過判別網路的feature儘可能相同。

相比原先的方式,生成網路G產生的資料更符合資料的真實分佈。作者雖然不保證能夠收斂到納什均衡點,但是在傳統GAN不能穩定收斂的情況下,新的目標函式仍然有效。

判別網路從輸入到輸出逐層卷積,pooling,圖片資訊逐漸損失,因此中間層能夠比輸出層得到更好的原始圖片的分佈資訊,拿中間層的feature作為目標函式比輸出層的結果,能夠生成圖片資訊更多,生成的圖片會效果會更好。

  • Semi-supervised learning

對於GAN網路,可以把生成網路的輸出作為第K+1類,相應的判別網路變為K+1類的分類問題。用Pmodel(y=K+1|x)Pmodel(y=K+1|x)表示生成網路的圖片為假