1. 程式人生 > TensorFlow -- 訓練MNIST資料集

TensorFlow -- 訓練MNIST資料集

# -*- coding:utf-8 -*-

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import pylab

# Download and extract the MNIST data into MNIST_data/; one_hot=True is
# passed so the labels line up with the 10-wide label placeholder.
mnist = input_data.read_data_sets(train_dir='MNIST_data/', one_hot=True)

# Reset the default graph so the ops below are built on a clean graph
tf.reset_default_graph()

# Placeholders
# Each input row is a flattened 28*28 = 784 pixel image
x = tf.placeholder(tf.float32, [None, 784])
# 10 classes in total
y = tf.placeholder(tf.float32, [None, 10])

# Trainable parameters
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))

# Forward pass
# Keep the raw logits around: the loss below is computed from them directly.
logits = tf.matmul(x, W) + b
# Softmax classifier output (class probabilities), used for prediction
pred = tf.nn.softmax(logits)

# Backward structure
# Loss: mean cross-entropy over the batch.
# Computing -sum(y * log(softmax(z))) by hand yields log(0) = -inf (and NaN
# after the multiply) as soon as a probability underflows to 0, so use the
# fused op, which evaluates the same quantity from the logits in a
# numerically stable way.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))

# Hyper-parameters
learning_rate = 0.01
# Plain gradient-descent optimizer
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# Number of passes over the training set, mini-batch size, and how often
# (in epochs) to print progress
training_epochs = 25
batch_size = 100
display_step = 1

# Saver used to write/restore the variables at model_path
saver = tf.train.Saver()
model_path = 'log/521model.ckpt'

# Launch a session, train the model, evaluate it and save a checkpoint
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Number of mini-batches per epoch. It does not change between epochs,
    # so compute it once instead of on every pass.
    total_batch = mnist.train.num_examples // batch_size

    # Training loop
    for epoch in range(training_epochs):
        avg_cost = 0.0

        # One pass over the training set, batch by batch
        for _ in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run one optimizer step and fetch the loss for this batch
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})

            # Accumulate the mean loss over the epoch
            avg_cost += c / total_batch

        # Report progress every display_step epochs
        if (epoch+1) % display_step == 0:
            print('Epochs:', '%04d' % (epoch+1), 'cost:', '{:.9f}'.format(avg_cost))
    print('Finished!')

    # Evaluate: a prediction is correct when the arg-max of the predicted
    # distribution matches the arg-max of the one-hot label
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Accuracy is the mean of the 0/1 correctness indicator
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # Label spelling fixed ('Accurary' -> 'Accuracy') to match the message
    # printed after the model is restored.
    print('Accuracy:', accuracy.eval(feed_dict={x: mnist.test.images, y: mnist.test.labels}))

    # Save the trained variables
    save_path = saver.save(sess, model_path)
    print('Model saved in file: %s' % save_path)

# Reload the saved model in a fresh session and check it
print('Starting loading model')
with tf.Session() as sess:
    # Initialize the variables, then overwrite them from the checkpoint
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, model_path)

    # Evaluate the restored model on the test split
    test_feed = {x: mnist.test.images, y: mnist.test.labels}
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print('Accuracy:', accuracy.eval(test_feed))

    # Predict two training samples and print the predicted class, the
    # softmax output and the true labels
    output = tf.argmax(pred, 1)
    batch_xs, batch_ys = mnist.train.next_batch(2)
    sample_feed = {x: batch_xs, y: batch_ys}
    outputval, predv = sess.run([output, pred], feed_dict=sample_feed)
    print(outputval, predv, batch_ys)

    # Show each of the two samples as an image with 28 columns
    for sample in batch_xs:
        im = sample.reshape(-1, 28)
        pylab.imshow(im)
        pylab.show()