1. 程式人生 > >MNIST手寫體識別--tensorflow

MNIST手寫體識別--tensorflow


 

MNIST手寫體識別--tensorflow

對於tensorflow給出的幾個版本的手寫體識別的程式碼進行分析。其中tensorflow的mnist程式碼在https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/tutorials/mnist

1:softmax版本

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A very simple MNIST classifier.

See extensive documentation at
https://www.tensorflow.org/get_started/mnist/beginners
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

FLAGS = None


def main(_):
 # Import data
 mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

 # Create the model
 x = tf.placeholder(tf.float32, [None, 784])
 W = tf.Variable(tf.zeros([784, 10]))
 b = tf.Variable(tf.zeros([10]))
 y = tf.matmul(x, W) + b

 # Define loss and optimizer
 y_ = tf.placeholder(tf.float32, [None, 10])

 # The raw formulation of cross-entropy,
 #
 # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
 # reduction_indices=[1]))
 #
 # can be numerically unstable.
 #
 # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
 # outputs of 'y', and then average across the batch.
 cross_entropy = tf.reduce_mean(
 tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

 sess = tf.InteractiveSession()
 tf.global_variables_initializer().run()
 # Train
 for _ in range(1000):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

 # Test trained model
 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 print(sess.run(accuracy, feed_dict={x: mnist.test.images,
 y_: mnist.test.labels}))

if __name__ == '__main__':
 parser = argparse.ArgumentParser()
 parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
 help='Directory for storing input data')
 FLAGS, unparsed = parser.parse_known_args()
 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

上面就是完整的程式碼,下面是分析。
import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
這裡引入的是input_data然後得到資料集,但是其實這裡input_data只是一個為了更加形象化代表得到資料集的模組,input_data裡面從mnist.py裡面匯入了read_data_sets方法,而mnist.py包含了獲取手寫體的所有方法。首先,mnist.py會檢測是否已經下載好了資料集,根據上面的程式碼,可以看到預設檢測的地址是/tmp/tensorflow/mnist/input_data,其實這裡如果不是linux型別的機器,那麼就是該驅動盤,比如說程式碼在e盤下面,那麼檔案及地址就是e:/tmp/tensorflow/mnist/input_data。如果下載好了,那麼就會解壓資料集,轉化為image-[index, y, x, depth],label-[index]的形式,並且進行one-hot操作。每個檔案都存在一定的格式,應該是 2051 num_image rows cols data的順序,num_image是圖片大小,然後行列個數。最後得到的就是train(60000,784),test(10000,784)資料集,數值基於0到1之間。這裡實際上定義的就是一個二維陣列,有60000行,784列。每個都包含images和labels.通過mnist.train.images和mnist.train.labels等來進行索取。labels使用one-hot進行處理的,所有如果手寫體分類存在10類的話,labels就是一個10列的陣列。
# Create the model
 x = tf.placeholder(tf.float32, [None, 784])
 W = tf.Variable(tf.zeros([784, 10]))
 b = tf.Variable(tf.zeros([10]))
 y = tf.matmul(x, W) + b
然後定義輸入,輸出以及引數W,b 可以看到,W*x數組裡面每一行其實就是對於每個樣本的預測,x1w1的值代表結果是0的概率,當然僅僅進行到這裡值不是概率,需要經過處理之後才是概率。 定義好模型之後,需要設定目標函式來進行優化。這裡使用的是一種稱為交叉熵的量進行優化。交叉熵是資訊理論裡面的知識,主要用於度量兩個概率分佈間的差異性資訊。
cross_entropy = tf.reduce_mean(
 tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
 train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
定義好目標函式,使用梯度下降方法優化。定義好優化方法之後,只需要不斷地優化直到收斂就可以停止迭代。因此,現在就可以運行了。
 sess = tf.InteractiveSession()
 tf.global_variables_initializer().run()
定義的變數必須要使用顯示的初始化函式。
 for _ in range(1000):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
主要的優化過程在這裡,外層迴圈代表的是梯度下降過程,每一次迭代相當於一步,使用的100個樣本點的作用是小批梯度,因為批梯度演算法每次更新使用的是全部的樣本點,這樣導致計算複雜度很高。 所以這裡使用的是小樣本來進行每次的迭代。需要說明的一點是y_其實就是標籤。後面的程式碼用於評估模型,這裡不做論述。還有一點,就是在定義x的時候使用的是[None,784],原因是後面再迭代的時候需要傳入x,而使用None的話,傳入任何大小都行。    

檢視原文: http://www.hahaszj.top/uncategorized/mnist%e6%89%8b%e5%86%99%e4%bd%93%e8%af%86%e5%88%ab-tensorflow/186