
A Simple Adversarial Network: Training on the MNIST Handwritten Digits

1. An adversarial network is a data-driven network; it requires relatively little human intervention.

The generator's loss is computed from the discriminator's output, while the discriminator's input consists of both the data produced by the generator and real data.

When the weights are updated, the two networks do not interfere with each other: each one updates only its own weights.
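For reference, the losses implemented in the code below are the standard non-saturating GAN losses. Writing D for the discriminator (whose logit output is passed through a sigmoid) and G for the generator, they can be written as:

\begin{aligned}
L_D &= -\mathbb{E}_{x \sim p_{\mathrm{data}}}\left[\log D(x)\right] - \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right] \\
L_G &= -\mathbb{E}_{z \sim p_z}\left[\log D(G(z))\right]
\end{aligned}

In the code these appear as sigmoid cross-entropy terms on the discriminator's logits: the discriminator is trained to label real images as 1 and generated images as 0, while the generator is trained so that its samples are labelled 1.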

Below is the implementation code for a simple GAN:

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os.path as op
import os

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

X_dim = 784
Z_dim = 100
batch_size = 128

# Placeholder for the real input images
X = tf.placeholder(tf.float32, shape=[None, X_dim])
# Placeholder for the input data fed to the generator
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])
# Placeholder for the perturbed samples
X_per = tf.placeholder(tf.float32, shape=[None, X_dim])
# Weights used by the discriminator
D_W1 = tf.Variable(tf.truncated_normal([X_dim, 100], stddev=0.1), dtype=tf.float32)
D_b1 = tf.Variable(tf.constant(0.1, shape=[100]))
D_W2 = tf.Variable(tf.truncated_normal([100, 1], stddev=0.1), dtype=tf.float32)
D_b2 = tf.Variable(tf.constant(0.1, shape=[1]))
D_var_list = [D_W1, D_b1, D_W2, D_b2]
# Weights used by the generator
G_W1 = tf.Variable(tf.truncated_normal([Z_dim, 100], stddev=0.1), dtype=tf.float32)
G_b1 = tf.Variable(tf.constant(0.1, shape=[100]))
G_W2 = tf.Variable(tf.truncated_normal([100, X_dim], stddev=0.1), dtype=tf.float32)
G_b2 = tf.Variable(tf.constant(0.1, shape=[X_dim]))
G_var_list = [G_W1, G_b1, G_W2, G_b2]
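# Shapes: the discriminator maps a flattened 28x28 image (784 values) through a 100-unit
# hidden layer to a single logit; the generator maps a 100-dimensional noise vector through
# a 100-unit hidden layer back to a 784-value image.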


# Define a function that plots the images produced by the generator
def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.5, hspace=0.5)
    for i, sample in enumerate(samples):
        plt.subplot(gs[i])
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig
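# plot() lays the generated samples out on a 4x4 grid of 28x28 grayscale images and returns
# the figure so the training loop can save it to disk.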


# Define the discriminator network
def discriminator(x):
    output1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    output2 = tf.matmul(output1, D_W2) + D_b2
    return output2
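# The discriminator returns raw logits rather than probabilities; the sigmoid is applied
# inside tf.nn.sigmoid_cross_entropy_with_logits when the losses are built below.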


# Define the generator network
def generator(x):
    output1 = tf.nn.relu(tf.matmul(x, G_W1) + G_b1)
    # The generator's output goes through a sigmoid so that pixel values lie in [0, 1]
    output2 = tf.matmul(output1, G_W2) + G_b2
    output3 = tf.nn.sigmoid(output2)
    return output3


def get_perturbed_batch(minibatch):
    return minibatch + 0.5 * np.random.random(minibatch.shape)
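# The perturbed batch (real images plus uniform noise in [0, 0.5)) is used below as the second
# endpoint of the interpolation for the gradient penalty: the discriminator's gradient norm at
# points between a real image and its noisy copy is pushed towards 1, which keeps the
# discriminator smooth in a neighbourhood of the real data.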


# Define the loss functions
G_samples = generator(Z)
D_Gsample_out = discriminator(G_samples)
D_X_out = discriminator(X)
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_X_out, labels=tf.ones_like(D_X_out)))
D_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Gsample_out, labels=tf.zeros_like(D_Gsample_out)))
# Gradient penalty on perturbed samples, added to the discriminator's training loss
alpha = tf.random_uniform(
    shape=[batch_size, 1],
    minval=0.,
    maxval=1.
)
differences = X_per - X
interpolates = X + (alpha * differences)
gradients = tf.gradients(discriminator(interpolates), [interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1))
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)


GP_WEIGHT = 10.0  # weight on the gradient penalty; 10 is the commonly used value
sum_D_loss = D_loss_real + D_loss_fake + GP_WEIGHT * gradient_penalty
G_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Gsample_out, labels=tf.ones_like(D_Gsample_out)))
# Define the optimizers
D_train_op = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5, beta2=0.9).minimize(sum_D_loss,
                                                                                         var_list=D_var_list)
G_train_op = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5, beta2=0.9).minimize(G_loss, var_list=G_var_list)
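# Each optimizer is given an explicit var_list: a discriminator step updates only D_W1, D_b1,
# D_W2, D_b2, and a generator step updates only the G_* variables, so the two networks never
# modify each other's weights.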
# Create a session and initialize all of the variables defined above
session = tf.Session()
session.run(tf.global_variables_initializer())

if not op.exists("outown/"):
    os.makedirs("outown/")

# Lists for recording the training losses
plotD = []
plotG = []
i = 0
# Training loop
for it in range(1, 200000):
    X_mb, _ = mnist.train.next_batch(batch_size)
    X_permb = get_perturbed_batch(X_mb)
    Z_mb = np.random.uniform(-1., 1., size=[batch_size, Z_dim])
    # Run one discriminator update
    _, D_curloss = session.run([D_train_op, sum_D_loss], feed_dict={X: X_mb, Z: Z_mb, X_per: X_permb})
    # Run one generator update
    _, G_curloss = session.run([G_train_op, G_loss], feed_dict={Z: Z_mb})
    plotG.append(G_curloss)
    plotD.append(D_curloss)
    if it % 1000 == 0:
        plt.subplot()
        plotnD = np.array(plotD)
        plt.plot(plotnD)
        plotnG = np.array(plotG)
        plt.plot(plotnG)
        plt.show()
        showG = np.random.uniform(-1., 1., size=[16, Z_dim])
        samples = session.run(G_samples, feed_dict={Z: showG})
        curfig = plot(samples)
        curfig.savefig("outown/{}.png".format(str(i).zfill(4)))
        print("iterate:{} ,D_loss{:.4},G_loss{:.4}".format(i, G_curloss, D_curloss))
        i += 1
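One practical note about the loop above: plt.show() pauses training every 1000 iterations until the plot window is closed. If the run should continue unattended, a small variation (my own sketch, not part of the original code) is to save the loss curves to disk inside the if it % 1000 == 0: block, mirroring how the generated samples are already saved:

        # Save the loss curves instead of displaying them (uses plotD, plotG and i from the loop)
        plt.figure()
        plt.plot(np.array(plotD), label="D loss")
        plt.plot(np.array(plotG), label="G loss")
        plt.legend()
        plt.savefig("outown/loss_{}.png".format(str(i).zfill(4)))
        plt.close()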

The network's loss curves during training are shown below:

The images it finally produces are shown below:

The digits look reasonably convincing.