1. 程式人生 > >在TensorFlow中對比兩大生成模型:VAE與GAN

在TensorFlow中對比兩大生成模型:VAE與GAN

變分自編碼器(VAE)與生成對抗網路(GAN)是複雜分佈上無監督學習最具前景的兩類方法。本文中,作者在 MNIST 上對這兩類生成模型的效能進行了對比測試。

本專案總結了使用變分自編碼器(Variational Autoencode,VAE)和生成對抗網路(GAN)對給定資料分佈進行建模,並且對比了這些模型的效能。你可能會問:我們已經有了數百萬張影象,為什麼還要從給定資料分佈中生成影象呢?正如 Ian Goodfellow 在 NIPS 2016 教程中指出的那樣,實際上有很多應用。我覺得比較有趣的一種是使用 GAN 模擬可能的未來,就像強化學習中使用策略梯度的智慧體那樣。

 本文組織架構:

  •  變分自編碼器(VAE)
  •  生成對抗網路(GAN)
  •  訓練普通 GAN 的難點
  •  訓練細節
  •  在 MNIST 上進行 VAE 和 GAN 對比實驗

        在無標籤的情況下訓練 GAN 判別器


        在有標籤的情況下訓練 GAN 判別器

  •  在 CIFAR 上進行 VAE 和 GAN 實驗
  •  延伸閱讀

VAE

 變分自編碼器可用於對先驗資料分佈進行建模。從名字上就可以看出,它包括兩部分:編碼器和解碼器。編碼器將資料分佈的高階特徵對映到資料的低階表徵,低階表徵叫作本徵向量(latent vector)。解碼器吸收資料的低階表徵,然後輸出同樣資料的高階表徵。

 從數學上來講,讓 X 作為編碼器的輸入,z 作為本徵向量,X′作為解碼器的輸出。

 圖 1 是 VAE 的視覺化圖。

06859image%20(12).png

圖 1:VAE 的架構

 這與標準自編碼器有何不同?關鍵區別在於我們對本徵向量的約束。如果是標準自編碼器,那麼我們主要關注重建損失(reconstruction loss),即:

6636320171023093000.png

而在變分自編碼器的情況中,我們希望本徵向量遵循特定的分佈,通常是單位高斯分佈(unit Gaussian distribution),使下列損失得到優化:

9857720171023093026.png

p(z′)∼N(0,I) 中 I 指單位矩陣(identity matrx),q(z∣X) 是本徵向量的分佈,其中29636%E5%B1%8F%E5%B9%95%E5%BF%AB%E7%85%A7%202017-10-23%20%E4%B8%8A%E5%8D%8810.15.42.png62957%E5%B1%8F%E5%B9%95%E5%BF%AB%E7%85%A7%202017-10-23%20%E4%B8%8A%E5%8D%8810.15.59.png81364%E5%B1%8F%E5%B9%95%E5%BF%AB%E7%85%A7%202017-10-23%20%E4%B8%8A%E5%8D%8810.16.06.png由神經網路來計算。KL(A,B) 是分佈 B 到 A 的 KL 散度。

 由於損失函式中還有其他項,因此存在模型生成影象的精度和本徵向量的分佈與單位高斯分佈的接近程度之間存在權衡(trade-off)。這兩部分由兩個超引數λ_1 和λ_2 來控制。

GAN

 GAN 是根據給定的先驗分佈生成資料的另一種方式,包括同時進行的兩部分:判別器和生成器。

 判別器用於對「真」影象和「偽」影象進行分類,生成器從隨機噪聲中生成影象(隨機噪聲通常叫作本徵向量或程式碼,該噪聲通常從均勻分佈(uniform distribution)或高斯分佈中獲取)。生成器的任務是生成可以以假亂真的影象,令判別器也無法區分出來。也就是說,生成器和判別器是互相對抗的。判別器非常努力地嘗試區分真偽影象,同時生成器盡力生成更加逼真的影象,使判別器將這些影象也分類為「真」影象。

 圖 2 是 GAN 的典型結構。

20162image%20(13).png

圖 2:GAN

 生成器包括利用程式碼輸出影象的解卷積層。圖 3 是生成器的架構圖。

86657image%20(14).png

圖 3:典型 GAN 的生成器圖示(影象來源:OpenAI)

訓練 GAN 的難點

 訓練 GAN 時我們會遇到一些挑戰,我認為其中最大的挑戰在於本徵向量/程式碼的取樣。程式碼只是從先驗分佈中對本徵變數的噪聲取樣。有很多種方法可以克服該挑戰,包括:使用 VAE 對本徵變數進行編碼,學習資料的先驗分佈。這聽起來要好一些,因為編碼器能夠學習資料分佈,現在我們可以從分佈中進行取樣,而不是生成隨機噪聲。

訓練細節

 我們知道兩個分佈 p(真實分佈)和 q(估計分佈)之間的交叉熵通過以下公式計算:

5029220171023093127.png

對於二元分類,

6747520171023093159.png

對於 GAN,我們假設分佈的一半來自真實資料分佈,一半來自估計分佈,因此:

8384220171023093238.png

訓練 GAN 需要同時優化兩個損失函式。

按照極小極大值演算法,

2360320171023093303.png

這裡,判別器需要區分影象的真偽,不管影象是否包含真實物體,都沒有注意力。當我們在 CIFAR 上檢查 GAN 生成的影象時會明顯看到這一點。

 我們可以重新定義判別器損失目標,使之包含標籤。這被證明可以提高主觀樣本的質量。

 如:在 MNIST 或 CIFAR-10(兩個資料集都有 10 個類別)。

上述 Python 損失函式在 TensorFlow 中的實現:

          def VAE_loss(true_images, logits, mean, std):
      """
        Args:
          true_images : batch of input images
          logits      : linear output of the decoder network (the constructed images)
          mean        : mean of the latent code
          std         : standard deviation of the latent code
      """
      imgs_flat    = tf.reshape(true_images, [-1, img_h*img_w*img_d])
      encoder_loss = 0.5 * tf.reduce_sum(tf.square(mean)+tf.square(std)
                     -tf.log(tf.square(std))-1, 1)
      decoder_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                     logits=logits, labels=img_flat), 1)
      return tf.reduce_mean(encoder_loss + decoder_loss)
  
  
  def GAN_loss_without_labels(true_logit, fake_logit):
      """
        Args:
          true_logit : Given data from true distribution,
                      `true_logit` is the output of Discriminator (a column vector)
          fake_logit : Given data generated from Generator,
                      `fake_logit` is the output of Discriminator (a column vector)
      """

      true_prob = tf.nn.sigmoid(true_logit)
      fake_prob = tf.nn.sigmoid(fake_logit)
      d_loss = tf.reduce_mean(-tf.log(true_prob)-tf.log(1-fake_prob))
      g_loss = tf.reduce_mean(-tf.log(fake_prob))
      return d_loss, g_loss
  
  
  def GAN_loss_with_labels(true_logit, fake_logit):
      """
        Args:
          true_logit : Given data from true distribution,
                      `true_logit` is the output of Discriminator (a matrix now)
          fake_logit : Given data generated from Generator,
                      `fake_logit` is the output of Discriminator (a matrix now)
      """
      d_true_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=self.labels, logits=self.true_logit, dim=1)
      d_fake_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=1-self.labels, logits=self.fake_logit, dim=1)
      g_loss = tf.nn.softmax_cross_entropy_with_logits(
                    labels=self.labels, logits=self.fake_logit, dim=1)

      d_loss = d_true_loss + d_fake_loss
      return tf.reduce_mean(d_loss), tf.reduce_mean(g_loss)
  
      

在 MNIST 上進行 VAE 與 GAN 對比實驗

#1 不使用標籤訓練判別器

我在 MNIST 上訓練了一個 VAE。程式碼地址:https://github.com/kvmanohar22/Generative-Models

實驗使用了 MNIST 的 28×28 影象,下圖中:

  • 左側:資料分佈的 64 張原始影象
  • 中間:VAE 生成的 64 張影象
  • 右側:GAN 生成的 64 張影象

第 1 次迭代

2553420171023093407.png

第 2 次迭代

4630920171023093431.png

第 3 次迭代

7198320171023093455.png

第 4 次迭代

4931120171023093522.png

第 100 次迭代

6122220171023093547.png

VAE(125)和 GAN(368)訓練的最終結果

8013820171023093611.png

以下動圖展示了 GAN 生成影象的過程(模型訓練了 368 個 epoch)。

69350gan.gif

顯然,VAE 生成的影象與 GAN 生成的影象相比,前者更加模糊。這個結果在預料之中,因為 VAE 模型生成的所有輸出都是分佈的平均。為了減少影象的模糊,我們可以使用 L1 損失來代替 L2 損失。

99500gan_gif.gif

在第一個實驗後,作者還將在近期研究使用標籤訓練判別器,並在 CIFAR 資料集上測試 VAE 與 GAN 的效能。

使用

  • 下載 MNIST 和 CIFAR 資料集

使用 MNIST 訓練 VAE 請執行:

        python main.py --train --model vae --dataset mnist
      

使用 MNIST 訓練 GAN 請執行:

        python main.py --train --model gan --dataset mnist
      

想要獲取完整的命令列選項,請執行:

        python main.py --help
      

該模型由 generate_frq 決定生成圖片的頻率,預設值為 1。

GAN 在 MNIST 上的訓練結果

MNIST 資料集中的樣本影象:

52213target_mnist.jpg

上方是 VAE 生成的影象,下方的動圖展示了 GAN 生成影象的過程:

12004vae.jpg

39190gan.gif