1. 程式人生 > >tensorflow-GPU 一步步搭建網路實現MNIST手寫識別

tensorflow-GPU 一步步搭建網路實現MNIST手寫識別

1. MNIST資料集的匯入

這裡介紹一下MNIST,MNIST是在機器學習領域中的一個經典問題。該問題解決的是把28x28畫素的灰度手寫數字圖片識別為相應的數字,其中數字的範圍從0到9.

首先我們要匯入MNIST資料集,這裡需要用到一個input_data.py檔案,在你安裝tensorflow的examples/tutorials/MNIST目錄下,如果tensorflow的目錄下沒有這個資料夾(一般是你的tensorflow版本不夠新,1.2版本有的),還請自己匯入或者更新一下tensorflow的版本,匯入的方法是在tensorflow的github(https://github.com/tensorflow/tensorflow/tree/master/tensorflow

  )下下載examples資料夾,貼上到tensorflow的根目錄下。更新tensorflow版本的話,請在ubuntu終端下執行pip install --upgrade tensorflow就可以了

好了,我們還是一步步來進行整個過程

首先我們先匯入我們需要用到的模組

import tensorflow as tf 
from tensorflow.examples.tutorials.mnist  import  input_data

然後我們用input_data模組匯入MNIST資料集

mnist = input_data.read_data_sets('MNIST_data',one_hot = True)

上面總共下載了四個壓縮檔案,內容分別如下:
train-images-idx3-ubyte.gz    訓練集圖片 - 55000 張 訓練圖片, 5000 張 驗證圖片
train-labels-idx1-ubyte.gz      訓練集圖片對應的數字標籤
t10k-images-idx3-ubyte.gz   測試集圖片 - 10000 張 圖片
t10k-labels-idx1-ubyte.gz      測試集圖片對應的數字標籤

圖片資料將被解壓成2維的tensor:[image index, pixel index] 其中每一項表示某一圖片中特定畫素的強度值, 範圍從 [0, 255] 到

xs = tf.placeholder(tf.float32,[None,784]) # 784 = 28X28
ys = tf.placeholder(tf.float32,[None,10]) # 10 = (0~9) one_hot

[-0.5, 0.5]。 "image index"代表資料集中圖片的編號, 從0到資料集的上限值。"pixel index"代表該圖片中畫素點得個數, 從0到圖片的畫素上限值。

以train-*開頭的檔案中包括60000個樣本,其中分割出55000個樣本作為訓練集,其餘的5000個樣本作為驗證集。因為所有資料集中28x28畫素的灰度圖片的尺寸為784,所以訓練集輸出的tensor格式為[55000, 784]

執行read_data_sets()函式將會返回一個DataSet例項,其中包含了以下三個資料集。 資料集 目的 data_sets.train 55000 組 圖片和標籤, 用於訓練。 data_sets.validation 5000 組 圖片和標籤, 用於迭代驗證訓練的準確性。 data_sets.test 10000 組 圖片和標籤, 用於最終測試訓練的準確性。

具體的MNIST資料集的解壓和重構我們可以不瞭解,會用這個資料集就可以了。(當然別問我這個東西,這個過程我也不知道,嘿嘿)

這裡說一下上述程式碼中的one_hot,MNIST的標籤資料是"one-hot vectors"。 一個one-hot向量除了某一位的數字是1以外其餘各維度數字都是0。所以在此教程中,數字n將表示成一個只有在第n維度(從0開始)數字為1的10維向量。比如,標籤0將表示成([1,0,0,0,0,0,0,0,0,0,0])。

2.實踐

我們首先定義兩個佔位符,來表示訓練資料及其相應標籤資料,將會在訓練部分進行feed進去

xs = tf.placeholder(tf.float32,[None,784]) # 784 = 28X28
ys = tf.placeholder(tf.float32,[None,10]) # 10 = (0~9) one_hot
現在我們再來定義神經網路的權重和偏差

Weights = tf.Variable(tf.random_normal([784,10]))
biases = tf.Variable(tf.zeros([1,10])+0.2)
先說一下,這個神經網路是輸入直接對映到輸出,沒有隱藏層,輸入是每張影象28X28的畫素,也就是784,輸出是10個長度的向量,也就是10,所以權重是[784,10],偏差是[1,10].

y_pre = tf.nn.softmax(tf.matmul(xs,Weights)+biases) 
我們知道雖然最後的輸出結果是10個長度的向量,但他們的值可能不太直觀,打個比方,比如都是0.015之類的數,僅僅是打比方哈

為了顯示輸出結果對每個數的相應概率,我們加了一個softmax函式,它的原理很簡單,拿10個單位的向量[x0,x1,...,x9]為例,如果想知道數字0的概率是多少,用exp(x0)/(exp(x0)+exp(x1)+...+exp(x9)),其他數字的概率類似推導,你也可以參考我放在部落格上的圖片,很直觀。

cross_entropy =tf.reduce_mean( -tf.reduce_sum(ys*tf.log(y_pre),reduction_indices=[1]))#compute cross_entropy

這次的損失表示形式跟之前都不太一樣哈,這次是計算交叉熵,交叉熵是用來衡量我們的預測用於描述真相的有效性。我們可以想一想,以一張圖片為例,y_pre和ys都是一個10個長度的向量,不同的是y_pre每個序號對應的值不為0,而ys是one_hot向量,只有一個為1,其餘全為0,那麼按照上述公式,只有1對應序號i(假如是i)的log(y_pre(i))保留下來了,而且y_pre(i)越大(也就是概率越大),log(y_pre(i))越小(注意計算交叉熵前面有負號的),反之越大,符合我們對損失的概念。

我試過用官方教程的交叉熵公式,列印交叉熵時出現nan,溢位了,建議用這個好一些

train = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
上面是用優化器最小化交叉熵,這裡學習率的選取很重要,官方的0.01太小,收斂得慢,還沒達到訓練損失最小就停止了,結果就是測試集誤差較大,推薦選大點,0.5左右差不多了,再大反而會發散了。

init = tf.global_variables_initializer()
上面是生出初始化init

sess  = tf.Session()
建立一個會話

sess.run(init)
初始化變數

for i in xrange(1000):
    batch_xs,batch_ys = mnist.train.next_batch(100)
    sess.run(train,feed_dict={xs:batch_xs,ys:batch_ys})
    if i %50==0:
        print sess.run(cross_entropy,feed_dict={xs:batch_xs,ys:batch_ys})

上面是程式訓練過程,這裡說一下xrange和range的區別,它們兩個的用法基本相同,但返回的型別不同,xrange返回的是生成器,range返回的是列表,所有xrange更節省記憶體,推薦用xrange,python3當中已經沒有xrange了,只有range,但它的功能和python2當中的xrange一樣

下面我們來計算計算精度

correct_prediction = tf.equal(tf.argmax(ys,1), tf.argmax(y_pre,1))

tf.argmax 是一個非常有用的函式,它能給出某個tensor物件在某一維上的其資料最大值所在的索引值。tf.argmax(y_pre,1)返回的是模型對於任一輸入x預測到的標籤值,而 tf.argmax(ys,1) 代表正確的標籤,我們可以用 tf.equal 來檢測我們的預測是否真實標籤匹配,這行程式碼返回的是匹配的布林值,成功1,失敗0

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

tf.cast將布林型別的correct_prediction轉化成float型,然後取平均得到精確度

print sess.run(accuracy, feed_dict={xs: mnist.test.images, ys: mnist.test.labels})

精確度87.79%,官方說的91%我是沒達到過,我訓練最高不超過89%。

3.結尾

希望這篇部落格能對你的學習有所幫助,謝謝觀看!同時,有興趣的朋友可以多改改引數試試不同的結果,比如學習率,batch_size等等,這對你的理解也是有幫助的!

下一篇筆記將寫用cnn去分類MNIST資料集,敬請期待!

連結: https://pan.baidu.com/s/1c2IUlH2 密碼: cn9r