1. 程式人生 > >基於CNN 的 TensorFlow Mnist 資料集實現 (另附識別單幅圖片的源程式)

基於CNN 的 TensorFlow Mnist 資料集實現 (另附識別單幅圖片的源程式)

import tensorflow as tf
import numpy as np
import mnist_inference
import mnist_train_cnn
import cv2 
import matplotlib.pyplot as plt

'''如果自己手寫的圖片是白底黑字的話,可以通過該函式將圖片灰度值反轉'''
#def reversePic(src):  
#        # 影象反轉    
#    for i in range(src.shape[0]):  
#        for j in range(src.shape[1]):  
#            src[i,j] = 255 - src[i,j]  
#    return src   

def main():
    #識別一張圖片
    sess = tf.InteractiveSession()  #定義會話    
    test_dir="E:\\TensorFlow\\Project_TF\\mnist_lenet_5\\data\\test\\3.jpg"
    x = tf.placeholder(tf.float32, (1,   #因為要識別的圖片只有一張,所以對應的batch_size為1
            mnist_inference.IMAGE_SIZE,             # 第二維和第三維表示圖片的尺寸
            mnist_inference.IMAGE_SIZE,
            mnist_inference.NUM_CHANNELS),          # 第四維表示圖片的深度,對於RBG格式的圖片,深度為5
                       name='x-input')
    y=mnist_inference.ff(x,False,None)
    y_result=tf.nn.softmax(y) #輸出層,使用softmax進行多分類  
    
    im = cv2.imread( test_dir,cv2.IMREAD_GRAYSCALE)   
        
#    im =reversePic(im)  #影象灰度值反轉,如果有需要的話
    plt.matshow(im)
    im = cv2.resize(im,(28,28),interpolation=cv2.INTER_CUBIC) #圖片預處理,統一成28*28
    x_img = np.reshape(im , (1,28,28,1))  #將圖片變成四維的
    
    '''重新載入模型儲存的引數'''
    variable_averages=tf.train.ExponentialMovingAverage(mnist_train_cnn.MOVING_AVERAGE_DECAY)
    variable_to_restore=variable_averages.variables_to_restore()
    saver=tf.train.Saver(variable_to_restore)
    
    ckpt=tf.train.get_checkpoint_state(mnist_train_cnn.MODEL_SAVE_PATH)
    saver.restore(sess,ckpt.model_checkpoint_path)
    
    output = sess.run( y_result , feed_dict = {x:x_img})
    print ('the y_con :   ', '\n',output ) #輸出對應每個數的概率
    print ('the predict is : ', np.argmax(output)) #輸出最大概率所對應的標籤
    #關閉會話  
    sess.close()  
if (__name__ == '__main__'):  
    main()