DenoisingAutoencoder(影象去噪自動編碼器)
作者:阿新 · 發佈於:2018-12-25
本文主要介紹使用TensorFlow實現DenoisingAutoencoder(影象去噪自動編碼器)。
下面是示例程式碼:
# 匯入相關模組 import numpy as np import sys import tensorflow as tf import matplotlib.pyplot as plt ''' IPython有一組預定義的“魔術函式”,您可以使用命令列樣式語法呼叫 它們。有兩種魔法,一種是線導向(line-oriented),另一種是單元 導向(cell-oriented)。line magics以%字元作為字首,其工作方式 與作業系統命令列呼叫非常相似:它們作用於整行,line magics可以返 回結果,也可以進行賦值使用;cell magics是以%%開頭,它需要出現 在單元的第一行,而且是作用於整個單元。 使用此方法時,繪製命令的輸出將在前端顯示,就像Jupyter筆記本一樣 ,直接顯示在生成命令的程式碼單元格的下方,生成的繪圖也將儲存在筆記 本文件中。不過這個方法好像只適用於Jupyter notebook和Jupyter QtConsole。 ''' %matplotlib inline # 匯入資料集 from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True) inputs_ = tf.placeholder(tf.float32, [None, 28, 28, 1]) targets_ = tf.placeholder(tf.float32, [None, 28, 28, 1]) def lrelu(x, alpha=0.1): return tf.maximum(alpha * x, x) ### Encoder with tf.name_scope('en-convolutions'): conv1 = tf.layers.conv2d(inputs_, filters=32, kernel_size=(3, 3), strides=(1, 1), padding='SAME', use_bias=True, activation=lrelu,) # now 28x28x32 with tf.name_scope('en-pooling'): maxpool1 = tf.layers.max_pooling2d(conv1, pool_size=(2, 2), strides=(2,2),) # now 14x14x32 with tf.name_scope('en-convolutions'): conv2 = tf.layers.conv2d(maxpool1, filters=32, kernel_size=(3, 3), strides=(1,1), padding='SAME', use_bias=True, activation=lrelu,) # now 14x14x32 with tf.name_scope('encoding'): encoded = tf.layers.max_pooling2d(conv2, pool_size=(2,2), strides=(2,2),) # now 7x7x32 ### Decoder with tf.name_scope('decoder'): conv3 = tf.layers.conv2d(encoded, filters=32, kernel_size=(3, 3), strides=(1,1), padding='SAME', use_bias=True, activation=lrelu) # 7x7x32 upsamples1 = tf.layers.conv2d_transpose(conv3, filters=32, kernel_size=3, padding='SAME', strides=2, name='upsample1') # now 14x14x32 upsamples2 = tf.layers.conv2d_transpose(upsamples1, filters=32, kernel_size=3, padding='SAME', strides=2, name='upsamples2') # now 28x28x32 logits = tf.layers.conv2d(upsamples2, filters=1, kernel_size=(3, 3), strides=(1, 1), name='logits', padding='SAME', use_bias=True) # now 28x28x1 # 
# Pass the logits through a sigmoid to obtain the reconstructed image.
decoded = tf.sigmoid(logits, name='recon')

# Loss and optimizer: per-pixel sigmoid cross-entropy between the raw
# logits and the clean targets, averaged and minimized with Adam.
loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                               labels=targets_)
learning_rate = tf.placeholder(tf.float32)
cost = tf.reduce_mean(loss)
opt = tf.train.AdamOptimizer(learning_rate).minimize(cost)

# Training
sess = tf.Session()

saver = tf.train.Saver()
# BUGFIX: the original reused the name `loss` for this history list,
# shadowing the cross-entropy tensor defined above; use distinct names.
train_loss = []
valid_loss = []

display_step = 1
epochs = 25
batch_size = 64
lr = 1e-5
sess.run(tf.global_variables_initializer())
writer = tf.summary.FileWriter('./graphs', sess.graph)
for e in range(epochs):
    total_batch = int(mnist.train.num_examples / batch_size)
    for ibatch in range(total_batch):
        batch_x = mnist.train.next_batch(batch_size)
        batch_test_x = mnist.test.next_batch(batch_size)
        # BUGFIX: the validation images were taken from the *training*
        # batch (batch_x) even though a test batch had just been fetched,
        # making the reported validation loss meaningless; use the test
        # batch instead.
        imgs_test = batch_test_x[0].reshape((-1, 28, 28, 1))
        noise_factor = 0.5
        # Corrupt the images with additive Gaussian noise, then clip back
        # into the valid [0, 1] pixel range.
        x_test_noisy = imgs_test + noise_factor * np.random.normal(
            loc=0.0, scale=1.0, size=imgs_test.shape)
        x_test_noisy = np.clip(x_test_noisy, 0., 1.)
        imgs = batch_x[0].reshape((-1, 28, 28, 1))
        x_train_noisy = imgs + noise_factor * np.random.normal(
            loc=0.0, scale=1.0, size=imgs.shape)
        x_train_noisy = np.clip(x_train_noisy, 0., 1.)
        # One optimization step on the noisy/clean training pair.
        batch_cost, _ = sess.run([cost, opt],
                                 feed_dict={inputs_: x_train_noisy,
                                            targets_: imgs,
                                            learning_rate: lr})
        # Validation cost on the noisy/clean test pair (no training step).
        batch_cost_test = sess.run(cost,
                                   feed_dict={inputs_: x_test_noisy,
                                              targets_: imgs_test})
    if (e + 1) % display_step == 0:
        print("Epoch: {}/{}...".format(e + 1, epochs),
              "Training loss: {:.4f}".format(batch_cost),
              "Validation loss: {:.4f}".format(batch_cost_test))
    train_loss.append(batch_cost)
    valid_loss.append(batch_cost_test)

# Plot the per-epoch training and validation loss curves.
plt.plot(range(e + 1), train_loss, 'bo', label='Training loss')
plt.plot(range(e + 1), valid_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs ', fontsize=16)
plt.ylabel('Loss', fontsize=16)
plt.legend()
# BUGFIX: the original called plt.figure() between legend() and show(),
# which opened a blank extra figure; removed.
plt.show()
saver.save(sess, 'encode_model')

# Qualitative check: denoise 10 unseen test images.
batch_x = mnist.test.next_batch(10)
imgs = batch_x[0].reshape((-1, 28, 28, 1))
noise_factor = 0.5
x_test_noisy = imgs + noise_factor * np.random.normal(loc=0.0, scale=1.0,
                                                      size=imgs.shape)
x_test_noisy = np.clip(x_test_noisy, 0., 1.)
recon_img = sess.run([decoded], feed_dict={inputs_: x_test_noisy})[0]

plt.figure(figsize=(20, 4))
# BUGFIX: the original titled this figure 'Reconstructed Images' even
# though it shows the originals; the print below already labels it.
print("Original Images")
for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow(imgs[i, ..., 0], cmap='gray')
plt.show()

plt.figure(figsize=(20, 4))
print("Noisy Images")
for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow(x_test_noisy[i, ..., 0], cmap='gray')
plt.show()

plt.figure(figsize=(20, 4))
print("Reconstruction of Noisy Images")
for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow(recon_img[i, ..., 0], cmap='gray')
plt.show()

writer.close()
sess.close()