1. 程式人生 > >Wasserstein Generative Adversarial Nets(WGAN)

Wasserstein Generative Adversarial Nets(WGAN)

權重 https mathjax blank min www. margin xmlns nbsp

GAN目前是機器學習中非常受歡迎的研究方向。主要包括有兩種類型的研究,一種是將GAN用於有趣的問題,另一種是試圖增加GAN的模型穩定性。

事實上,穩定性在GAN訓練中是非常重要的。起初的GAN模型在訓練中存在一些問題,e.g., 模式塌陷生成器演化成非常窄的分布,只覆蓋數據分布中的單一模式)。模式塌陷的含義是發生器只能產生非常相似的樣本(例如MNIST中的單個數字),即所產生的樣本不是多樣的。這當然違反了GAN初衷

GAN中的另一個問題是沒有指很好的指標或度量說明模型的收斂性生成器鑒別器損失並沒有告訴我們關於這方面的任何信息。當然,我們可以通過查看生成器產生的數據來監控訓練過程。但是,這是一個愚蠢的手動過程。所以,我們需要一個可解釋

指標告訴我們訓練過程的好壞。

Wasserstein GAN

Wasserstein GAN(WGAN)是一種新提出的GAN算法,可以在一定程度解決上述兩個問題。對於WGAN背後的直覺和理論背景,可以查看相關資料。

整個算法的偽代碼如下:

技術分享圖片

我們可以看到該算法與原始GAN算法非常相似。 但是,對於WGAN,我們根據上面的代碼需要註意到下幾點:
  1. 損失函數中沒有log。判別器D(X)的輸出不再是一個概率(標量),同時也就意味著沒有sigmoid激活函數
  2. 對於判別器D(X)的權重W進行裁剪
  3. 訓練判別器的次數生成器
  4. 采用RMSProp優化器,代替原先的ADAM
    優化器
  5. 非常低的learning rate, α=0.00005

WGAN TensorFlow implementation

GAN的基本實現可以在上一篇文章中介紹過。 我們只需要稍微修改下傳統的GAN。 首先,讓我們更新我們的判別器D(X)

技術分享圖片
""" Vanilla GAN """
def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    out = tf.matmul(D_h1, D_W2) + D_b2
    return tf.nn.sigmoid(out)

""" WGAN """ def discriminator(x): D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1) out = tf.matmul(D_h1, D_W2) + D_b2 return out
View Code

接下來,修改loss函數,去掉log

技術分享圖片
""" Vanilla GAN """
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))

""" WGAN """
D_loss = tf.reduce_mean(D_real) - tf.reduce_mean(D_fake)
G_loss = -tf.reduce_mean(D_fake)
View Code

在每次梯度下降更新後,裁剪判別器D(X)的權重:

# theta_D is list of D‘s params
clip_D = [p.assign(tf.clip_by_value(p, -0.01, 0.01)) for p in theta_D]

然後,只需要訓練更多次的判別器D(X)就行了

技術分享圖片
D_solver = (tf.train.RMSPropOptimizer(learning_rate=5e-5)
            .minimize(-D_loss, var_list=theta_D))
G_solver = (tf.train.RMSPropOptimizer(learning_rate=5e-5)
            .minimize(G_loss, var_list=theta_G))

for it in range(1000000):
    for _ in range(5):
        X_mb, _ = mnist.train.next_batch(mb_size)

        _, D_loss_curr, _ = sess.run([D_solver, D_loss, clip_D], feed_dict={X: X_mb, z: sample_z(mb_size, z_dim)})

    _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={z: sample_z(mb_size, z_dim)})
View Code

Wasserstein Generative Adversarial Nets(WGAN)