A Simple Adversarial Network: Training on MNIST Handwritten Digits
1. An adversarial network is data-driven, with little human intervention. The generator's loss is built from the discriminator's output, while the discriminator's input consists of both the generator's samples and real data. The two networks' weight updates do not interfere with each other: each optimizer updates only its own network's weights.
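For reference, the losses this setup optimizes are the standard non-saturating GAN losses (Goodfellow et al., 2014); with discriminator D and generator G they can be written as

$$L_D = -\,\mathbb{E}_{x\sim p_{\mathrm{data}}}\!\left[\log D(x)\right] \;-\; \mathbb{E}_{z\sim p_z}\!\left[\log\left(1 - D(G(z))\right)\right]$$

$$L_G = -\,\mathbb{E}_{z\sim p_z}\!\left[\log D(G(z))\right]$$

In the code below, both are implemented with tf.nn.sigmoid_cross_entropy_with_logits: all-ones labels against real logits plus all-zeros labels against fake logits give L_D, and all-ones labels against fake logits give L_G.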
Below is the implementation code for a simple GAN:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os.path as op
import os

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

X_dim = 784
Z_dim = 100
batch_size = 128

# Placeholder for the real input images
X = tf.placeholder(tf.float32, shape=[None, X_dim])
# Placeholder for the generator's input noise
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])
# Placeholder for perturbed samples (used for the gradient penalty)
X_per = tf.placeholder(tf.float32, shape=[None, X_dim])

# Discriminator weights
D_W1 = tf.Variable(tf.truncated_normal([X_dim, 100], stddev=0.1), dtype=tf.float32)
D_b1 = tf.Variable(tf.constant(0.1, shape=[100]))
D_W2 = tf.Variable(tf.truncated_normal([100, 1], stddev=0.1), dtype=tf.float32)
D_b2 = tf.Variable(tf.constant(0.1, shape=[1]))
D_var_list = [D_W1, D_b1, D_W2, D_b2]

# Generator weights
G_W1 = tf.Variable(tf.truncated_normal([Z_dim, 100], stddev=0.1), dtype=tf.float32)
G_b1 = tf.Variable(tf.constant(0.1, shape=[100]))
G_W2 = tf.Variable(tf.truncated_normal([100, X_dim], stddev=0.1), dtype=tf.float32)
G_b2 = tf.Variable(tf.constant(0.1, shape=[X_dim]))
G_var_list = [G_W1, G_b1, G_W2, G_b2]

# Helper that draws a 4x4 grid of generated images
def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.5, hspace=0.5)
    for i, sample in enumerate(samples):
        plt.subplot(gs[i])
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
    return fig

# Discriminator network: returns raw logits
def discriminator(x):
    output1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    output2 = tf.matmul(output1, D_W2) + D_b2
    return output2

# Generator network
def genemator(x):
    output1 = tf.nn.relu(tf.matmul(x, G_W1) + G_b1)
    # The generator needs a sigmoid on its output so pixel values fall in [0, 1]
    output2 = tf.matmul(output1, G_W2) + G_b2
    output3 = tf.nn.sigmoid(output2)
    return output3

# Add uniform noise to a minibatch to obtain perturbed samples
def get_perterbed_batch(minibatch):
    return minibatch + 0.5 * np.random.random(minibatch.shape)

# Define the losses
G_samples = genemator(Z)
D_Gsample_out = discriminator(G_samples)
D_X_out = discriminator(X)
D_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_X_out, labels=tf.ones_like(D_X_out)))
D_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Gsample_out, labels=tf.zeros_like(D_Gsample_out)))

# Add a perturbed-sample (gradient penalty) loss term for training the discriminator
alpha = tf.random_uniform(shape=[batch_size, 1], minval=0., maxval=1.)
differences = X_per - X
interpolates = X + alpha * differences
gradients = tf.gradients(discriminator(interpolates), [interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=1))
gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
# Include the penalty in the total discriminator loss (the original listing
# computed it but never added it -- fixed here)
sum_D_loss = D_loss_real + D_loss_fake + gradient_penalty
G_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Gsample_out, labels=tf.ones_like(D_Gsample_out)))

# Optimizers: each one updates only its own network's variables
D_train_op = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5, beta2=0.9).minimize(
    sum_D_loss, var_list=D_var_list)
G_train_op = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5, beta2=0.9).minimize(
    G_loss, var_list=G_var_list)

# Create a session and initialize all variables
session = tf.Session()
session.run(tf.global_variables_initializer())

if not op.exists("outown/"):
    os.makedirs("outown/")

# Lists that record the loss history
plotD = []
plotG = []
i = 0

# Training loop
for it in range(1, 200000):
    X_mb, _ = mnist.train.next_batch(batch_size)
    X_permb = get_perterbed_batch(X_mb)
    Z_mb = np.random.uniform(-1., 1., size=[batch_size, Z_dim])
    # Run the discriminator update
    _, D_curloss = session.run([D_train_op, sum_D_loss],
                               feed_dict={X: X_mb, Z: Z_mb, X_per: X_permb})
    # Run the generator update
    _, G_curloss = session.run([G_train_op, G_loss], feed_dict={Z: Z_mb})
    plotG.append(G_curloss)
    plotD.append(D_curloss)
    if it % 1000 == 0:
        # Plot the loss curves so far (blocks until the window is closed)
        plotnD = np.array(plotD)
        plt.plot(plotnD)
        plotnG = np.array(plotG)
        plt.plot(plotnG)
        plt.show()
        # Sample 16 images from the generator and save them as a grid
        showG = np.random.uniform(-1., 1., size=[16, Z_dim])
        samples = session.run(G_samples, feed_dict={Z: showG})
        curfig = plot(samples)
        curfig.savefig("outown/{}.png".format(str(i).zfill(4)))
        plt.close(curfig)  # free the figure so memory does not grow over 200 saves
        # (the original print swapped the two losses -- fixed here)
        print("iterate:{}, D_loss:{:.4}, G_loss:{:.4}".format(i, D_curloss, G_curloss))
        i += 1
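The perturbed-sample term in the listing is a gradient penalty in the style of DRAGAN (interpolating between real images and noise-perturbed copies of them), closely related to the WGAN-GP penalty: the discriminator is evaluated at random interpolation points and the squared deviation of its gradient norm from 1 is penalized:

$$\hat{x} = x + \alpha\,(x_{\mathrm{per}} - x), \qquad \alpha \sim U(0, 1)$$

$$L_{\mathrm{GP}} = \mathbb{E}_{\hat{x}}\!\left[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\right]$$

This keeps the discriminator roughly 1-Lipschitz in a neighbourhood of the data and tends to stabilize training.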
The network's loss curves are shown below:
The images it generates at the end of training are shown below:
The generated digits look quite plausible.
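One practical note: the listing never saves the trained weights, so the generator is lost when the session ends. A minimal sketch of checkpointing with tf.train.Saver, assuming the session and variable lists from the listing above (the checkpoint path generator.ckpt is my own choice, not part of the original):

# Hypothetical addition, not part of the original listing.
saver = tf.train.Saver(var_list=G_var_list)  # save only the generator's variables
saver.save(session, "outown/generator.ckpt")

# Later, after rebuilding the same graph in a fresh session:
# saver.restore(session, "outown/generator.ckpt")
# samples = session.run(G_samples,
#                       feed_dict={Z: np.random.uniform(-1., 1., size=[16, Z_dim])})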