1. 程式人生 > >tensorflow使用softmax regression算法實現手寫識別

tensorflow使用softmax regression算法實現手寫識別

tutorial nbsp 書籍 數據集 feed 每一個 cast amp 類別

最近在學習黃文堅的TensorFlow書籍,希望對學習做一個總結。

softmax regression算法原理:當我們對一張圖片進行預測時,會計算每一個數字的可能性,如3的概率是3%,5的概率是6%,1的概率是80%,則返回1.

TensorFlow版本:0.8.0

# 導入手寫識別數據,TensorFlow提供了手寫識別庫
from
tensorflow.examples.tutorials.mnist import input_data
# 讀取手寫識別數據 mnist
= input_data.read_data_sets("MNIST_data/", one_hot=True)
# 訓練集數據的維度是(55000,784),訓練集標簽的維度是(55000,10)
# 測試集數據的維度是(10000,784),測試集標簽的維度是(10000,10)
# 驗證集數據的維度是(5000,784),驗證集標簽的維度是(5000,10)
# 為什麽訓練數據的維度是784?因為tensorflow提供的數據集的圖片像素是28*28=784
# 為什麽標簽的維度是10,標簽做了處理,每個預期結果變成了只包含0和1的10維數據。
# 例如標簽5就表示為[0,0,0,0,0,1,0,0,0,0],這種方法叫one-hot編碼
print(mnist.train.images.shape,mnist.train.labels.shape) print(mnist.test.images.shape,mnist.test.labels.shape) print(mnist.validation.images.shape,mnist.validation.labels.shape)

# 導入TensorFlow庫
import tensorflow as tf
# 將session註冊為默認的session,運算都在session裏跑。placeholder為輸入數據的地方
# placeholder的第一個參數表示數據類型,第2個參數表示數據的維度,None表示任意長度的數據 sess
=tf.InteractiveSession() x = tf.placeholder(tf.float32,[None,784])
# Variable用於存儲參數,它是持久化的,可以長期存在,每次叠代都會更新 # 數據的維度是784,類別的維度經過one-hot編碼後變成了10維,所以W的參數為[784,10]
# b為[10]維,W和b全部初始化為0,簡單模型的初始值不重要
W
= tf.Variable(tf.zeros([784,10])) b=tf.Variable(tf.zeros([10])) # softmax函數用於定義softmax regression算法
# matmul用於向量乘法
y
=tf.nn.softmax(tf.matmul(x,W)+b)
# 求損失函數cross-entropy,先定義一個placeholder,輸入的真實label
# cross_entropy定義了損失函數的計算方法,通過reduce_sum求熵的和,reduce_mean求每個batch的熵的平均值 y_
=tf.placeholder(tf.float32,[None,10]) cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y),reduction_indices=[1]))
# 定義一個優化器,GradientDescentOptimizer為優化器,學習率為0.5,優化目標設定為cross_entropy train_step
= tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# 全局參數初始化並執行run tf.initialize_all_variables().run()
# 每次取100個樣本,並feed給placeholder,執行1000次,train_step對數據進行訓練
for i in range(1000): batch_xs,batch_ys = mnist.train.next_batch(100) train_step.run({x:batch_xs,y_:batch_ys})
# 求出概率最大的數字,判斷是否與實際標簽相符合,y是預測數據,y_是實際數據 correct_prediction
= tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
# 求計算精度 accuracy
=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels})

總的來說,TensorFlow感覺還是比較簡單的,也許這只是個最簡單的模型吧。
涉及的概念也只有session,variable,placeholder,GradientDescentOptimizer。
梯度下降等復雜的方法都進行了封裝,用python不到30行的代碼就實現了手寫識別,雖然識別正確率只有92%左右。


tensorflow使用softmax regression算法實現手寫識別