1. 程式人生 > >實戰生成對抗網路[3]:DCGAN

實戰生成對抗網路[3]:DCGAN

在上一篇文章《實戰生成對抗網路[2]:生成手寫數字》中,我們使用了簡單的神經網路來生成手寫數字,可以看出手寫數字字形,但不夠完美,生成的手寫數字有些毛糙,邊緣不夠平滑。

生成對抗網路中,生成器和判別器是一對冤家。要提高生成器的水平,就要提高判別器的識別能力。在《一步步提高手寫數字的識別率(3)》系列文章中,我們探討了如何提高手寫數字的識別率,發現卷積神經網路在影象處理方面優勢巨大,最後採用卷積神經網路模型,達到一個不錯的識別率。自然的,為了提高生成對抗網路的手寫數字生成質量,我們是否也可以採用卷積神經網路呢?

答案是肯定的,不過和《一步步提高手寫數字的識別率(3)》中隨便採用一個卷積神經網路結構是不夠的,因為生成對抗網路中,有兩個神經網路模型互相對抗,隨便選擇網路結構,容易在迭代過程中引起振盪,難以收斂。

好在有專家學者進行了這方面的研究,下面就介紹一篇由Alec Radford、Luke Metz和Soumith Chintala合作完成的論文 arXiv: 1511.06434, 《利用深度卷積生成對抗網路進行無監督表徵學習(Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks)》。

論文給出了生成器的模型結構,如下圖所示:

從圖中可以看,該網路採用100x1噪聲向量(隨機輸入),表示為z,並將其對映到G(Z)輸出,即64x64x3,其變換過程為:

100x1 → 1024x4x4 → 512x8x8 → 256x16x16 → 128x32x32 → 64x64x3

如果採用keras實現上述模型,非常簡單。不過需要注意的是,在本文中探討的手寫數字生成,其最終輸出是28 x 28 x 1的灰度圖片,所以我們沿襲上面的模型架構,但在具體實現上做一些調整:

100x1 → 1024x1 → 128x7x7 → 128x14x14 → 14x14x64 → 28x28x64 → 8x28x1

程式碼如下:

def generator_model():
  model = Sequential()
  model.add(Dense(input_dim=100, output_dim=1024))
  model.add(Activation('tanh'
)) model.add(Dense(128 * 7 * 7)) model.add(BatchNormalization()) model.add(Activation('tanh')) model.add(Reshape((7, 7, 128), input_shape=(128 * 7 * 7,))) model.add(UpSampling2D(size=(2, 2))) model.add(Conv2D(64, (5, 5), padding='same')) model.add(Activation('tanh')) model.add(UpSampling2D(size=(2, 2))) model.add(Conv2D(1, (5, 5), padding='same')) model.add(Activation('tanh')) return model 複製程式碼

程式碼中引入了批量規則化(BatchNormalization),在實踐中被證實可以在許多場合提升訓練速度,減少初始化不佳帶來的問題並且通常能產生準確的結果。上取樣則是用來擴大維度。

判別器的實現差不多是將上述生成器模型倒過來實現,但使用最大池化代替了上取樣,程式碼如下:

def discriminator_model():
  model = Sequential()
  model.add(
      Conv2D(64, (5, 5),
             padding='same',
             input_shape=(28, 28, 1))
  )
  model.add(Activation('tanh'))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  model.add(Conv2D(128, (5, 5)))
  model.add(Activation('tanh'))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  model.add(Flatten())
  model.add(Dense(1024))
  model.add(Activation('tanh'))
  model.add(Dense(1))
  model.add(Activation('sigmoid'))
  return model
複製程式碼

在論文中,作者建議通過下面一些架構性的約束來固化網路:

  • 在判別器中使用跨步卷積取代池化層,在生成器中使用反捲積取代池化層。
  • 在生成器和判別器中使用批量規則化。
  • 消除架構中較深的全連線層。
  • 在生成器的輸出層使用Tanh,在其他層均使用ReLU啟用。
  • 在判別器的所有層中都使用LeakyReLU啟用。

上述程式碼並沒有完全遵守作者的建議,可見在面對不同的場景,開發者可以有自己的發揮。事實上,在GANs in Action這本書中,作者也給出了手寫數字生成的另外一種DCGAN模型,程式碼可參考:github.com/GANs-in-Act…

經過100個epoch的迭代,我們的程式碼生成的手寫數字如下圖所示,雖然有些數字生成得不太準確,不過相對於上一篇文章的輸出,邊緣還是要平滑一些,效果也有所改進:

本文所演示內容的完整程式碼,請參考:github.com/mogoweb/aie…

image