
Applying TensorFlow to MNIST: Recognizing Handwritten Digits (OpenCV + TensorFlow + CNN)

References:

1. 《TensorFlow技術解析與實戰》

2. http://blog.csdn.net/sparta_117/article/details/66965760

3. http://blog.csdn.net/HelloZEX/article/details/78537213

4. http://blog.csdn.net/gaohuazhao/article/details/72886450

5. http://blog.csdn.net/skeeee/article/details/16844937

###################################################

I have been learning TF for a while now. 《TensorFlow技術解析與實戰》 covers TF fairly thoroughly, and after consulting many expert blogs I decided to follow along and build a program that recognizes my own handwritten digits. Learning is, after all, improving through imitation. The original handwritten image:


Contents:
Introduction to TensorFlow and MNIST
The CNN algorithm
The training program
Writing a digit and preprocessing it with OpenCV
Feeding the image into the network for recognition

################################################

Introduction to TensorFlow and MNIST:
TensorFlow™ is an open-source software library for numerical computation using data flow graphs. It is a loosely structured "neural network" library: with the modules it provides you can assemble most kinds of neural networks. It runs on CPU or GPU and will use a GPU automatically, with no device-allocation code required. It is used mainly from Python, though an official C++ API exists as well. MNIST is a large dataset of handwritten digits, widely used in machine learning for recognition tasks. It contains 60,000 training images and 10,000 test images, each a 28*28-pixel picture of a handwritten digit. As one of the most common datasets, MNIST is often used to test neural networks and is a fairly basic application.
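As a quick sanity check, the sizes quoted above can be inspected after loading (a minimal TF 1.x sketch of my own, using the same data directory as the training program below; note that read_data_sets splits the 60,000 training images into 55,000 for training and 5,000 for validation):

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_Labels_Images", one_hot=True)
print(mnist.train.images.shape)       # (55000, 784) -- 784 = 28*28 flattened pixels
print(mnist.validation.images.shape)  # (5000, 784)
print(mnist.test.images.shape)        # (10000, 784)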

CNN (convolutional neural network):
The recognition algorithm used here is a convolutional neural network (CNN).


The overall structure is: input - convolution layer - pooling layer - convolution layer - pooling layer - fully connected layer - output.
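Concretely, for the network built in part 1 below, the shapes flow as follows: a 28*28*1 input goes through a 5*5 convolution with 32 filters and a 2*2 max pool to 14*14*32, then through a second 5*5 convolution with 64 filters and another 2*2 max pool to 7*7*64, which is flattened into a 1024-unit fully connected layer and finally a 10-way output.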

Convolution
Convolution can be seen as a feature-extraction step. Without it, the network's input would be the entire raw image, which is much harder to process.


Suppose the green 5*5 matrix in the figure is the original image and the yellow 3*3 matrix is our filter, i.e. the convolution kernel. Convolving the yellow matrix with the part of the green matrix it covers -- multiplying corresponding elements and summing -- yields the feature value for that region, i.e. the convolved feature in the figure.
The yellow matrix is then slid to the right to compute the next convolved feature value; the distance it slides each time is the stride.
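To make the sliding-window computation concrete, here is a minimal NumPy sketch of my own (illustrative only, not how TF implements conv2d): each output value is the elementwise product-and-sum of the kernel with the patch it currently covers.

import numpy as np

def convolve2d(image, kernel, stride=1):
    # valid-padding convolution: slide the kernel across the image with the given stride
    kh, kw = kernel.shape
    out_h = (image.shape[0] - kh) // stride + 1
    out_w = (image.shape[1] - kw) // stride + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            patch = image[i*stride:i*stride+kh, j*stride:j*stride+kw]
            out[i, j] = np.sum(patch * kernel)  # multiply elementwise, then sum
    return out

image = np.arange(25, dtype=float).reshape(5, 5)  # the green 5*5 matrix
kernel = np.ones((3, 3))                          # the yellow 3*3 kernel
print(convolve2d(image, kernel))                  # the 3*3 convolved feature map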


Pooling
Pooling compresses the convolution output, further reducing the number of connections needed in the fully connected stage.


There are two kinds of pooling:
max pooling, which takes the maximum value in the selected region as the sampled value;
average pooling, which takes the mean of the selected region as the sampled value.
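A matching NumPy sketch of 2*2 pooling with stride 2 (again my own illustration, under the same assumptions as the convolution sketch above):

import numpy as np

def pool2x2(feature_map, mode="max"):
    # non-overlapping 2x2 windows, stride 2
    h, w = feature_map.shape
    out = np.zeros((h // 2, w // 2))
    for i in range(0, h - 1, 2):
        for j in range(0, w - 1, 2):
            window = feature_map[i:i+2, j:j+2]
            out[i//2, j//2] = window.max() if mode == "max" else window.mean()
    return out

fm = np.array([[1, 3, 2, 4],
               [5, 7, 6, 8],
               [1, 2, 3, 4],
               [5, 6, 7, 8]], dtype=float)
print(pool2x2(fm, "max"))   # [[7. 8.] [6. 8.]]
print(pool2x2(fm, "mean"))  # [[4.  5. ] [3.5 5.5]]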

#############################################################

1. The training program:

I'll paste the program first; the main body is much the same as the TensorFlow tutorial. Worth noting is the saver part, which stores the trained weights and biases so that the evaluation program can use them again.

# -*- coding:utf-8 -*-
# ==============================================================================
# 20171115
# HelloZEX
# Convolutional neural network for handwritten digit recognition
# Builds and saves the model
# ==============================================================================

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_Labels_Images", one_hot=True)


import tensorflow as tf

sess = tf.InteractiveSession()


x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))


sess.run(tf.global_variables_initializer())

y = tf.matmul(x,W) + b

cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

for _ in range(1000):
  batch = mnist.train.next_batch(100)
  train_step.run(feed_dict={x: batch[0], y_: batch[1]})
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Accuracy of the plain softmax-regression baseline (the 0.9182 in the output below)
print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

def weight_variable(shape):
  initial = tf.truncated_normal(shape, stddev=0.1)
  return tf.Variable(initial)

def bias_variable(shape):
  initial = tf.constant(0.1, shape=shape)
  return tf.Variable(initial)

def conv2d(x, W):
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')

W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
x_image = tf.reshape(x, [-1,28,28,1])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])

y_conv = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

saver = tf.train.Saver()  # defaults to saving all variables

sess.run(tf.global_variables_initializer())
for i in range(20000):
  batch = mnist.train.next_batch(50)
  if i%100 == 0:
    train_accuracy = accuracy.eval(feed_dict={
        x:batch[0], y_: batch[1], keep_prob: 1.0})
    print("step %d, training accuracy %g"%(i, train_accuracy))

  train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
# Save the model parameters; be sure to change this to your own path
saver.save(sess, 'CKPT/model.ckpt')

print("test accuracy %g"%accuracy.eval(feed_dict={
    x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

print("Finish!")

Note where the model is saved!

##############################################

Output:

/usr/bin/python2.7 /home/zhengxinxin/Desktop/PyCharm/Spark/SparkMNIST/SparkMNIST_TF1.py
Extracting MNIST_Labels_Images/train-images-idx3-ubyte.gz
Extracting MNIST_Labels_Images/train-labels-idx1-ubyte.gz
Extracting MNIST_Labels_Images/t10k-images-idx3-ubyte.gz
Extracting MNIST_Labels_Images/t10k-labels-idx1-ubyte.gz
2017-11-15 16:28:43.205071: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 16:28:43.205098: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 16:28:43.205103: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 16:28:43.205106: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 16:28:43.205109: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
0.9182
step 0, training accuracy 0.14
step 100, training accuracy 0.82
step 200, training accuracy 0.94
step 300, training accuracy 0.92
step 400, training accuracy 0.94
step 500, training accuracy 0.92
step 600, training accuracy 0.96
......
step 19600, training accuracy 1
step 19700, training accuracy 1
step 19800, training accuracy 1
step 19900, training accuracy 1
terminate called after throwing an instance of 'std::bad_alloc'
  what():  std::bad_alloc

Process finished with exit code 134 (interrupted by signal 6: SIGABRT)
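The std::bad_alloc abort at the end is an out-of-memory failure; it most likely comes from the final test-accuracy line, which feeds all 10,000 test images through the convolutional network at once. A possible workaround, sketched below under the assumption that the accuracy, x, y_ and keep_prob tensors from the training script are in scope (my own addition, not part of the original program), is to evaluate the test set in batches and average:

# Evaluate test accuracy in batches to avoid allocating activations
# for all 10,000 test images at once
batch_size = 500
n_batches = mnist.test.num_examples // batch_size
total = 0.0
for i in range(n_batches):
    xs, ys = mnist.test.next_batch(batch_size)
    total += accuracy.eval(feed_dict={x: xs, y_: ys, keep_prob: 1.0})
print("test accuracy %g" % (total / n_batches))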
########################################################
Finally, several files are generated in the CKPT folder (for this 2017-era TF 1.x Saver, typically checkpoint, model.ckpt.data-00000-of-00001, model.ckpt.index, and model.ckpt.meta):


#########################################################

2. Preprocessing the handwritten image with OpenCV:
Next we preprocess the handwritten image: shrink it to 28*28 pixels, convert it to grayscale, and binarize it. I use OpenCV for the processing; MATLAB or similar tools would also work.
The preprocessing program is below (adapted from reference 5; drag a selection box with the mouse, and the image inside the box is processed):

#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <stdio.h>

using namespace cv;
using namespace std;

cv::Mat org, dst, img, tmp;
void on_mouse(int event, int x, int y, int flags, void *ustc)// event: mouse event code; x,y: mouse coordinates; flags: drag and keyboard state flags
{
	static Point pre_pt = cv::Point(-1, -1);// starting coordinates
	static Point cur_pt = cv::Point(-1, -1);// current coordinates
	char temp[16];
	if (event == CV_EVENT_LBUTTONDOWN)// left button pressed: record the starting point and mark it with a circle
	{
		org.copyTo(img);// copy the original image into img
		sprintf(temp, "(%d,%d)", x, y);
		pre_pt = Point(x, y);
		putText(img, temp, pre_pt, FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 0, 0, 255), 1, 8);// display the coordinates in the window
		circle(img, pre_pt, 2, Scalar(255, 0, 0, 0), CV_FILLED, CV_AA, 0);// draw the circle
		imshow("img", img);
	}
	else if (event == CV_EVENT_MOUSEMOVE && !(flags & CV_EVENT_FLAG_LBUTTON))// mouse moving with the left button up
	{
		img.copyTo(tmp);// copy img into the temporary image tmp, used for showing the live coordinates
		sprintf(temp, "(%d,%d)", x, y);
		cur_pt = Point(x, y);
		putText(tmp, temp, cur_pt, FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 0, 0, 255));// just shows the live mouse coordinates
		imshow("img", tmp);
	}
	else if (event == CV_EVENT_MOUSEMOVE && (flags & CV_EVENT_FLAG_LBUTTON))// left button held while moving: draw a rectangle on the image
	{
		img.copyTo(tmp);
		sprintf(temp, "(%d,%d)", x, y);
		cur_pt = Point(x, y);
		putText(tmp, temp, cur_pt, FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 0, 0, 255));
		rectangle(tmp, pre_pt, cur_pt, Scalar(0, 255, 0, 0), 1, 8, 0);// live-draw the drag rectangle on the temporary image
		imshow("img", tmp);
	}
	else if (event == CV_EVENT_LBUTTONUP)// left button released: draw the final rectangle on the image
	{
		org.copyTo(img);
		sprintf(temp, "(%d,%d)", x, y);
		cur_pt = Point(x, y);
		putText(img, temp, cur_pt, FONT_HERSHEY_SIMPLEX, 0.5, Scalar(0, 0, 0, 255));
		circle(img, pre_pt, 2, Scalar(255, 0, 0, 0), CV_FILLED, CV_AA, 0);
		rectangle(img, pre_pt, cur_pt, Scalar(0, 255, 0, 0), 1, 8, 0);// draw the rectangle from the start and end points onto img
		imshow("img", img);
		img.copyTo(tmp);
		// crop the region inside the rectangle and store it in dst
		int width = abs(pre_pt.x - cur_pt.x);
		int height = abs(pre_pt.y - cur_pt.y);
		if (width == 0 || height == 0)
		{
			printf("width == 0 || height == 0");
			return;
		}
		dst = org(Rect(min(cur_pt.x, pre_pt.x), min(cur_pt.y, pre_pt.y), width, height));
		cv::resize(dst, dst, Size(28, 28));
		cvtColor(dst, dst, CV_BGR2GRAY);
		threshold(dst, dst, 110, 255, CV_THRESH_BINARY);
		imwrite("T.png", dst);//注意將這裡改為自己的處理結果儲存地址
		namedWindow("dst");
		imshow("dst", dst);
		waitKey(0);
	}
}
int main()
{
	org = imread("7.jpg");//讀取圖片地址
	org.copyTo(img);
	org.copyTo(tmp);
	namedWindow("img");//定義一個img視窗
	setMouseCallback("img", on_mouse, 0);//呼叫回撥函式
	imshow("img", img);
	cv::waitKey(0);
}
Note that you need to adjust the binarization threshold to the conditions of your handwritten photo:
threshold(dst, dst, 110, 255, CV_THRESH_BINARY);
The processed result:

This is the binarized 28*28 image, in the same format as the images in the MNIST dataset; only in this form can the picture be fed into the network for recognition.
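For reference, the same preprocessing can be done in Python with cv2 (an alternative sketch of my own, not the article's program; "7.jpg", the 28*28 size, the 110 threshold, and "T.png" are the values used above):

import cv2

img = cv2.imread("7.jpg")
roi = img  # or slice a region of interest first, e.g. img[y0:y1, x0:x1]
roi = cv2.resize(roi, (28, 28))
gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
# Fixed threshold as in the C++ code; adding cv2.THRESH_OTSU would pick
# the threshold automatically if 110 does not suit your photo.
_, binary = cv2.threshold(gray, 110, 255, cv2.THRESH_BINARY)
cv2.imwrite("T.png", binary)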
##########################################################
3. Feeding the image into the network for recognition:
This is the forward-pass program; the classification produced by the final softmax layer is the recognition result.

# -*- coding:utf-8 -*-
# ==============================================================================
# 20171115
# HelloZEX
# Convolutional neural network for handwritten digit recognition
# Loads the saved model and uses it to recognize a handwritten digit
# If cv2 is missing, try: sudo pip install opencv-python
# ==============================================================================

from PIL import Image, ImageFilter
import tensorflow as tf
import matplotlib.pyplot as plt
#import cv2

def imageprepare():
    """
    This function returns the pixel values.
    The imput is a png file location.
    """
    file_name='Pictures/7.png'#匯入自己的圖片地址
    #in terminal 'mogrify -format png *.jpg' convert jpg to png
    im = Image.open(file_name).convert('L')

    im.save("Pictures/sample.png")
    plt.imshow(im)
    plt.show()
    tv = list(im.getdata()) #get pixel values

    #normalize pixels to 0 and 1. 0 is pure white, 1 is pure black.
    tva = [ (255-x)*1.0/255.0 for x in tv]
    #print(tva)
    return tva

    """
    This function returns the predicted integer.
    The imput is the pixel values from the imageprepare() function.
    """

    # Define the model (same as when creating the model file)
result=imageprepare()
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

def weight_variable(shape):
  initial = tf.truncated_normal(shape, stddev=0.1)
  return tf.Variable(initial)

def bias_variable(shape):
  initial = tf.constant(0.1, shape=shape)
  return tf.Variable(initial)

def conv2d(x, W):
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])

x_image = tf.reshape(x, [-1,28,28,1])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])

h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])

y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)  # softmax probabilities; argmax below gives the digit

# initialize_all_variables is deprecated (hence the warning in the output below);
# tf.global_variables_initializer() is the current replacement
init_op = tf.initialize_all_variables()


"""
Load the model2.ckpt file
file is stored in the same directory as this python script is started
Use the model to predict the integer. Integer is returend as list.

Based on the documentatoin at
https://www.tensorflow.org/versions/master/how_tos/variables/index.html
"""
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init_op)
    saver.restore(sess, "CKPT/model.ckpt")# restore the model parameters saved earlier
    #print ("Model restored.")

    prediction=tf.argmax(y_conv,1)
    predint=prediction.eval(feed_dict={x: [result],keep_prob: 1.0}, session=sess)
    print(h_conv2)  # prints the Tensor object (shape and dtype), as seen in the output below

    print('recognize result:')
    print(predint[0])

While it runs, a Figure 1 window pops up; close it and execution continues.


Output:

/usr/bin/python2.7 /home/zhengxinxin/Desktop/PyCharm/Spark/SparkMNIST/SparkMNIST_TF2.py
WARNING:tensorflow:From /home/zhengxinxin/Desktop/PyCharm/Spark/SparkMNIST/SparkMNIST_TF2.py:85: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
Instructions for updating:
Use `tf.global_variables_initializer` instead.
2017-11-15 19:09:12.008792: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.1 instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 19:09:12.008817: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use SSE4.2 instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 19:09:12.008822: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 19:09:12.008825: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use AVX2 instructions, but these are available on your machine and could speed up CPU computations.
2017-11-15 19:09:12.008829: W tensorflow/core/platform/cpu_feature_guard.cc:45] The TensorFlow library wasn't compiled to use FMA instructions, but these are available on your machine and could speed up CPU computations.
Tensor("Relu_1:0", shape=(?, 14, 14, 64), dtype=float32)
recognize result:
7

Process finished with exit code 0

As you can see, the handwritten digit was recognized correctly. Hooray, hooray!!!!