1. 程式人生 > >tensorflow-顯示mnist影象並且載入模型識別單張影象(二)

tensorflow-顯示mnist影象並且載入模型識別單張影象(二)

通過上一遍文章,我們能夠得到比較簡單的mnist訓練模型。

在根目錄的save資料夾下有四個檔案,儲存的是訓練模型,檔案具體內容自行查詢資料,我們載入模型時,只需定義出save資料夾下的路徑即可

下面程式碼包含:

一:從測試集中隨機挑選出兩張影象用於顯示並且識別

二:載入訓練模型 

import tensorflow as tf 
#載入mnist庫,從在測試集中挑選要測試的圖片
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
import pylab#用於畫圖,很方便
########################################################################
pylab.mpl.rcParams['font.sans-serif'] = ['SimHei'] # 若不新增,中文無法在圖中顯示
# import matplotlib
# matplotlib.rcParams['axes.unicode_minus']=False # 若不新增,無法在圖中顯示負號
###########################################################################
 

tf.reset_default_graph()#可以清空預設圖裡所有的節點。
#輸入測試資料
x = tf.placeholder(tf.float32, [None, 784]) #測試圖片

#權重和偏置
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))
 
#構建模型
pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax分類
 
#儲存或者開啟模型
saver = tf.train.Saver()
#儲存或者開啟模型的路徑
model_path = "save/model"
###############################################################################
#啟動會議
with tf.Session() as sess:
    #變數初始化
    sess.run(tf.global_variables_initializer())
    #開啟訓練好的模型
    saver.restore(sess, model_path)
    #測試模型
    #從測試集中隨機取2張圖片,圖片賦給batch_xs,對應的標籤賦給batch_ys
    batch_xs,batch_ys = mnist.test.next_batch(2)
    #output為2張圖片通過softmax得到的最大概率對應的標籤
    output = tf.argmax(pred, 1)
    #正式執行,先X輸入2張測試圖片,再output得到2張圖片概率最大對應的標籤並賦給outputval
    #最後pred得到2張圖片再0-9上各自的概率
    outputval, predv = sess.run([output, pred], feed_dict={x: batch_xs})

    print(outputval, predv)
    #######################################################################
    print(batch_xs.shape)
    pylab.subplot(121) 
    im = batch_xs[0]
    im = im.reshape(-1, 28)#把原本在mnist中為一行的資料變成二維的28列矩陣,-1:不用指定具體為多少行
    pylab.title('該圖片中的數字為:'+ str(outputval[0]))
    pylab.imshow(im)
 
    pylab.subplot(122)
    im = batch_xs[1]
    im = im.reshape(-1, 28)
    pylab.title('該圖片中的數字為:' + str(outputval[1]))
    pylab.imshow(im)
    pylab.show()

結果: