
Implementing a Variational Autoencoder

1. A VAE is somewhat similar to a GAN: both can take some input and generate sample data from it. The difference is that a VAE assumes the latent codings follow a normal distribution, whereas a GAN makes no such assumption and is driven entirely by the data, learning its regularities through training.
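Concretely, the VAE's training objective is the sum of a reconstruction loss and a latent loss: a KL divergence that pulls the encoder's Gaussian toward the standard normal prior. Writing \(\gamma = \log\sigma^2\), as the code below does, this KL term has a closed form; a sketch of the standard identity, which is exactly what the latent_loss line in the code computes:

\[
D_{\mathrm{KL}}\!\left(\mathcal{N}(\mu,\sigma^{2})\,\middle\|\,\mathcal{N}(0,1)\right)
= \frac{1}{2}\sum_{i}\left(e^{\gamma_i} + \mu_i^{2} - 1 - \gamma_i\right),
\qquad \gamma_i = \log\sigma_i^{2}.
\]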

The code for the variational autoencoder follows:

import numpy as np
import tensorflow as tf
import tensorflow.contrib as contrib
from tensorflow.contrib.layers import fully_connected
from tensorflow.examples.tutorials.mnist import input_data
import functiontool  # the author's helper module; its plotting function is shown below

# Global hyperparameters
n_inputs = 28 * 28        # MNIST images are 28x28 pixels
n_hidden1 = 500
n_hidden2 = 500
n_hiddenmiddle = 30       # size of the latent coding layer
n_hidden3 = n_hidden2
n_hidden4 = n_hidden1
n_outputs = n_inputs
learning_rate = 0.001
mnist_data = input_data.read_data_sets("MNIST_data/")

# Define the network structure
with contrib.framework.arg_scope([fully_connected], activation_fn=tf.nn.elu,
                                 weights_initializer=contrib.layers.variance_scaling_initializer()):
    X = tf.placeholder(dtype=tf.float32, shape=[None, n_inputs])
    hidden1 = fully_connected(X, n_hidden1)
    hidden2 = fully_connected(hidden1, n_hidden2)
    # Encoder outputs: the mean and gamma = log(sigma^2) of the latent Gaussian
    hidden_middle_mean = fully_connected(hidden2, n_hiddenmiddle, activation_fn=None)
    hidden_middle_gamma = fully_connected(hidden2, n_hiddenmiddle, activation_fn=None)
    hidden_middle_sigma = tf.exp(0.5 * hidden_middle_gamma)
    # Reparameterization trick: z = mean + sigma * noise, so gradients flow
    # through mean and sigma while the randomness stays in `noise`
    noise = tf.random_normal(tf.shape(hidden_middle_sigma))
    hidden_middle = hidden_middle_mean + hidden_middle_sigma * noise
    hidden3 = fully_connected(hidden_middle, n_hidden3)
    hidden4 = fully_connected(hidden3, n_hidden4)
    logits = fully_connected(hidden4, n_outputs, activation_fn=None)
    outputs = tf.sigmoid(logits)
# Define the loss function: reconstruction loss + latent (KL) loss
reconstruction_loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=X, logits=logits))
# Closed-form KL(N(mean, sigma^2) || N(0, 1)) with gamma = log(sigma^2)
latent_loss = 0.5 * tf.reduce_sum(tf.exp(hidden_middle_gamma) + tf.square(hidden_middle_mean) - 1 - hidden_middle_gamma)
sum_loss = reconstruction_loss + latent_loss
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_optimizer = optimizer.minimize(sum_loss)
init = tf.global_variables_initializer()
saver = tf.train.Saver()
# Train the network
n_epochs = 60
batch_size = 150
with tf.Session() as session:
    init.run()
    for epoch in range(n_epochs):
        n_batches = mnist_data.train.num_examples // batch_size
        for iteration in range(n_batches):
            print("\r{}%".format(100 * iteration // n_batches), end="")
            X_train, y_train = mnist_data.train.next_batch(batch_size)
            session.run(train_optimizer, feed_dict={X: X_train})
        loss_val = sum_loss.eval(feed_dict={X: X_train})
        print("\rTrain loss: {}".format(loss_val))
        saver.save(session, "weight/VaAuto.cpkt")
    # Generate new digits: sample codings from N(0, 1) and feed them to the decoder
    test_rng = np.random.normal(size=(10, n_hiddenmiddle))
    out_val = outputs.eval(feed_dict={hidden_middle: test_rng})
    functiontool.show_reconstructed_digits_old(out_val)
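Because the weights are saved to "weight/VaAuto.cpkt" after each epoch, new digits can also be generated later without retraining. A minimal sketch, assuming the graph defined above has already been constructed in the current process:

with tf.Session() as session:
    saver.restore(session, "weight/VaAuto.cpkt")  # reload the trained weights
    # Sample codings from the N(0, 1) prior and push them through the decoder
    codings = np.random.normal(size=(10, n_hiddenmiddle))
    generated = outputs.eval(feed_dict={hidden_middle: codings})
    functiontool.show_reconstructed_digits_old(generated)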

The plotting function is:


import matplotlib.pyplot as plt

def show_reconstructed_digits_old(outputs):
    n_images = outputs.shape[0]          # number of digits to display
    plt.figure(figsize=(8, 50))
    for i in range(n_images):
        plt.subplot(n_images, 1, i + 1)  # one image per row
        plot_image(outputs[i])
    plt.show()
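The plot_image helper is not shown in the original post; a minimal sketch, assuming each row of outputs is a flattened 28x28 grayscale MNIST image:

def plot_image(image, shape=(28, 28)):
    # Reshape the flat pixel vector into a 2-D image and draw it in grayscale
    plt.imshow(image.reshape(shape), cmap="Greys", interpolation="nearest")
    plt.axis("off")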

The training results are: