tensorflow實戰入門題目--手寫數字的識別
阿新 • • 發佈:2018-12-02
tensorflow實戰入門題目–手寫數字的識別
這是我的第一篇部落格,有很多寫不好的地方,還請大家多多批評指正。
手寫體的示範:
匯入資料,由於mnist手寫數字在
from tensorflow.examples.tutorials.mnist import input_data #載入Mnist,由於比較出名,所以在tensorflow裡面已有這個資料集
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
輸出訓練集、測試集、驗證集的維度
print(mnist.train. images.shape,mnist.train.labels.shape) #訓練集的形狀,維度
print(mnist.test.images.shape,mnist.test.labels.shape) #測試集的形狀、維度
print(mnist.validation.images.shape,mnist.validation.labels.shape) #驗證集的形狀、維度
計算圖:
引入tensorflow框架,並建立一個會話
imp1ort tensorflow as tf #引入tensorflow的框架
sess = tf.InteractiveSession ()#建立一個會話
輸入x的資料,給權值w和basic賦初始值(此處賦初始值為零)
x = tf.placeholder(tf.float32,[None,784])#x的輸入資料,定義型別為float32,None表示輸入的行數不確定,佔位符
w = tf.Variable(tf.zeros([784,10])) #初始化權值w的值 784*10的全為0的向量
b = tf.Variable(tf.zeros([10])) #初始化bias的值
使用Softmax Regression函式
y = tf.nn.softmax(tf.matmul(x,w) + b) #用softmax Regression 輸出這張照片對應數字的百分比, 維度(None,10),且百分比的和為一
# Softmax Regression經常用於分類問題的解決
輸入訓練集資料的真正的測試結果
y_ = tf.placeholder("float",[None,10]) #輸入訓練集真實的結果(q代表的數字)
使用coress_entropy定義損失函式
y_ = tf.placeholder("float",[None,10]) #輸入訓練集真實的數字
#計算交叉熵,定義損失函式 tf.reduce_mean對於每個測試資料求均值,tf.reduce_sum進行求和
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))
優化演算法–此處選擇隨機梯度下降的演算法SGD
#優化演算法,這裡採用梯度下降,反向傳播,學習率為0.5,優化目標為cross_entropy
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
迭代訓練
tf.global_variables_initializer().run() #初始化全域性所有的引數#迭代訓練,每次隨機選擇100個訓練資料
for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(100)#隨機抽取100組資料
train_step.run({x:batch_xs,y_:batch_ys }) #此處進行封裝
計算準確率
correct_predicition = tf.equal(tf.argmax(y,1),tf.argmax(y_,1)) #輸出結果為bool型·
accuracy = tf.reduce_mean(tf.cast(correct_predicition,tf.float32)) #將bool型轉換為float32,並求平均數
print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels})) #計算準確率,eval執行整個Session()
tensorflow,我們定義的各個公式其實是Computation Graph,在執行這行程式碼的時候,計算還沒有實際發生,只有等呼叫run的方法,並feed資料的時計算才真正的執行。比如cross、entropy、train、step、accuracy等都是計算圖(tensorflow會自動建立)中的節點,而不是資料結果,我可以通過run方法執行這些節點或者說運算操作來獲取結果。