程式人生 > [TensorFlow深度學習入門]實戰十二·使用DNN網路實現自動編碼器

[TensorFlow深度學習入門]實戰十二·使用DNN網路實現自動編碼器

  • 測試程式碼
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

fashion_mnist = keras.datasets.fashion_mnist

# This script uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session,
# tf.variable_scope, tf.train.AdamOptimizer). Under TensorFlow 2.x these are
# only reachable through tf.compat.v1 with eager execution disabled.

# Hyper-parameters: same values as the original run whose log follows.
BATCH_SIZE = 100
TRAIN_STEPS = 3001
LOG_EVERY = 100
LEARNING_RATE = 0.05
# NOTE(review): N(0, 1) weights feeding 784 inputs into a sigmoid layer likely
# saturate many units; a Glorot/Xavier initializer plus a smaller Adam learning
# rate (e.g. 1e-3) is worth trying if the loss plateaus. Kept unchanged here.
INIT_STDDEV = 1.0
BIAS_INIT = 0.01
IMG_H, IMG_W = 28, 28
IMG_PIXELS = IMG_H * IMG_W
HIDDEN_UNITS = 128
CODE_UNITS = 16


def _sigmoid_dense(inputs, in_dim, out_dim, w_name, b_name):
    """Create a fully connected sigmoid layer in the current variable scope.

    Args:
        inputs: 2-D float32 tensor of shape [batch, in_dim].
        in_dim: number of input features.
        out_dim: number of output units.
        w_name: variable name for the [in_dim, out_dim] weight matrix.
        b_name: variable name for the [1, out_dim] bias row.

    Returns:
        Tensor ``sigmoid(inputs @ w + b)`` of shape [batch, out_dim].
    """
    w = tf.get_variable(
        w_name,
        initializer=tf.random_normal([in_dim, out_dim], stddev=INIT_STDDEV),
    )
    b = tf.get_variable(b_name, initializer=tf.zeros([1, out_dim]) + BIAS_INIT)
    return tf.nn.sigmoid(tf.matmul(inputs, w) + b)


def _show_image(image):
    """Display one 2-D image in its own figure with a colour bar and no grid."""
    plt.figure()
    plt.imshow(image)
    plt.colorbar()
    plt.grid(False)
    plt.show()


# --- Data -------------------------------------------------------------------
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Scale pixel values into [0, 1] so they are comparable with the decoder's
# sigmoid output (assumes 8-bit 0..255 pixels, as the /255.0 implies).
train_images = train_images / 255.0
test_images = test_images / 255.0
print(train_images.shape)

# --- Graph ------------------------------------------------------------------
x = tf.placeholder(tf.float32, [None, IMG_H, IMG_W])   # network input
y_ = tf.placeholder(tf.float32, [None, IMG_H, IMG_W])  # reconstruction target
x_flat = tf.reshape(x, shape=[-1, IMG_PIXELS])
target_flat = tf.reshape(y_, shape=[-1, IMG_PIXELS])

# Encoder: 784 -> 128 -> 16 code units.
with tf.variable_scope("encoder"):
    l1 = _sigmoid_dense(x_flat, IMG_PIXELS, HIDDEN_UNITS, "w1", "b1")
    l2 = _sigmoid_dense(l1, HIDDEN_UNITS, CODE_UNITS, "w2", "b2")

# Decoder: 16 -> 128 -> 784, mirroring the encoder.
with tf.variable_scope("decoder"):
    l3 = _sigmoid_dense(l2, CODE_UNITS, HIDDEN_UNITS, "w3", "b3")
    y = _sigmoid_dense(l3, HIDDEN_UNITS, IMG_PIXELS, "w4", "b4")

# Mean squared reconstruction error over all pixels in the batch.
loss = tf.reduce_mean(tf.square(y - target_flat))
opt = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss)

# --- Training ---------------------------------------------------------------
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    num_train = len(train_images)
    for step in range(TRAIN_STEPS):
        # Walk through the training set in consecutive mini-batches,
        # wrapping around to the start once the end is reached.
        start = (step * BATCH_SIZE) % num_train
        batch = train_images[start:start + BATCH_SIZE]
        # Plain autoencoder: the target is the input itself.
        loss_val, _ = sess.run([loss, opt], {x: batch, y_: batch})
        if step % LOG_EVERY == 0:
            print(step, loss_val)

    # Reconstruct the first test image while the trained session is open.
    y_out = sess.run(y, {x: test_images[:1]})

# y_out is already a NumPy array of shape [1, IMG_PIXELS]; reshape it directly
# instead of adding a tf.reshape op to the graph and evaluating it.
y_show = y_out.reshape(-1, IMG_H, IMG_W)
print(type(y_show))
print(y_show.shape)
print(test_images[0].shape)

# --- Display ----------------------------------------------------------------
_show_image(y_show[0])        # reconstruction
_show_image(test_images[0])   # original test image
  • 輸出結果
0 0.3734877
100 0.07903126
200 0.069011234
300 0.06310683
400 0.05900423
500 0.05826289
600 0.057095125
700 0.05598061
800 0.053362366
900 0.049294725
1000 0.05573253
1100 0.049589165
1200 0.052580144
1300 0.050453667
1400 0.052007377
1500 0.04682982
1600 0.052490402
1700 0.047567524
1800 0.05900064
1900 0.051574014
2000 0.049422733
2100 0.046142038
2200 0.05289932
2300 0.049340945
2400 0.051445786
2500 0.0516909
2600 0.049265135
2700 0.04827474
2800 0.04997604
2900 0.04606307
3000 0.050183807