1. 程式人生 > >Tensorflow之MNIST手寫數字識別:分類問題(2)

Tensorflow之MNIST手寫數字識別:分類問題(2)

整體程式碼:

#資料讀取
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)

#定義待輸入資料的佔位符
#mnist中每張照片共有28*28=784個畫素點
x = tf.placeholder(tf.float32,[None,784],name="
X") #0-9一共10個數字=>10個類別 y = tf.placeholder(tf.float32,[None,10],name="Y") #定義模型變數 #以正態分佈的隨機數初始化權重W,以常數0初始化偏置b #在神經網路中,權值W的初始值通常設為正態分佈的隨機數,偏置項b的初始值通常也設定為正態分佈的隨機數或常數。 W = tf.Variable(tf.random_normal([784,10],name="W")) b = tf.Variable(tf.zeros([10]),name="b") #用單個神經元構建神經網路 forward=tf.matmul(x,W) + b #
前向計算 #結果分類 #當我們處理多分類任務的時候,通常需要使用Softmax Regression模型。Softmax會對每一類別估算出一個概率。 #工作原理:將判定為某一類的特徵相加,然後將這些特徵轉化為判定是這一類的概率 pred = tf.nn.softmax(forward) #Softmax分類 #設定訓練引數 train_epochs = 120 #訓練輪數 batch_size = 120 #單次訓練樣本數(批次大小) total_batch = int(mnist.train.num_examples/batch_size) #
一輪訓練有多少批次 display_step = 1 #顯示粒度 learning_rate = 0.01 #學習率 #概率估算值需要將預測輸出值控制在[0,1]區間內。二元分類問題的目標是正確預測兩個可能標籤中的一個 #邏輯迴歸可以用於處理這類問題。二元邏輯迴歸的損失函式一般採用對數損失函式 #多元分類:邏輯迴歸可生成介於0到1.0之間的小數。Softmax將這一想法延伸到多類別領域。 #在多類別問題中,Softmax會為每個類別分配一個用小數表示的概率。這些用小數表示的概率相加之和必須是1.0 #交叉熵損失函式:交叉熵是一個資訊理論的概念,它原來是用來估算平均編碼長度的。 #交叉熵刻畫的是兩個概率分佈之間的距離,p代表正確答案,q代表的預測值,交叉熵越小,兩個概率的分佈越接近 #定義損失函式 loss_function = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) #交叉熵 #選擇優化器 optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss_function) #梯度下降優化器 #定義準確率 # 檢查預測類別tf.argmax(pred,1)與實際類別tf.argmax(y,1)的匹配情況 #argmax()將陣列中最大值的下標取出來 correct_prediction = tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) #準確率,將布林值轉化為浮點數,並計算平均值 tf.cast()將布林值投射成浮點數 accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) #宣告會話,初始化變數 sess = tf.Session() init = tf.global_variables_initializer() #變數初始化 sess.run(init) #訓練模型 for epoch in range(train_epochs): for batch in range(total_batch): xs,ys = mnist.train.next_batch(batch_size) #讀取批次資料 sess.run(optimizer,feed_dict={x:xs,y:ys}) #執行批次訓練 #total_batch個批次訓練完成後,使用驗證資料計算誤差與準確率,驗證集沒有分批 loss,acc = sess.run([loss_function,accuracy],feed_dict={x:mnist.validation.images,y:mnist.validation.labels}) #列印訓練過程中的詳細資訊 if (epoch+1) % display_step == 0: print("Train Epoch:",'%02d'%(epoch+1),"Loss=","{:.9f}".format(loss),"Accuracy=","{:.4f}".format(acc)) print("Train Finished!") #評估模型 #完成訓練後,在測試集上評估模型的準確率 accu_test = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels}) print("Test Accuracy:",accu_test) #完成訓練後,在驗證集上評估模型的準確率 accu_validation = sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels}) print("Test Accuracy:",accu_validation) #完成訓練後,在訓練集上評估模型的準確率 accu_train = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels}) print("Test Accuracy:",accu_train) #應用模型 #在建立模型並進行訓練後,若認為準確率可以接受,則可以使用此模型進行預測 #由於pred預測結果是one_hot編碼格式,所以需要轉換成0~9數字 prediction_result = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images}) #檢視預測結果中的前10項 prediction_result[0:10] #定義視覺化函式 def plot_images_labels_prediction(images,labels,prediction,index,num=10): #引數: 圖形列表,標籤列表,預測值列表,從第index個開始顯示,預設一次顯示10幅 fig = plt.gcf() #獲取當前圖表,Get Current Figure fig.set_size_inches(10,12) #1英寸等於2.45cm if num > 25 : #最多顯示25個子圖 num = 25 for i in range(0,num): ax = plt.subplot(5,5,i+1) #獲取當前要處理的子圖 ax.imshow(np.reshape(images[index],(28,28)), cmap = 'binary') #顯示第index個影象 title = "labels="+str(np.argmax(labels[index])) #構建該圖上要顯示的title資訊 if len(prediction)>0: title += ",predict="+str(prediction[index]) ax.set_title(title,fontsize=10) #顯示圖上的title資訊 ax.set_xticks([]) #不顯示座標軸 ax.set_yticks([]) index += 1 plt.show() #視覺化預測結果 plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,10,10)