tensorflow實現CNN識別手寫數字
上一篇使用TensorFlow識別手寫數字,是直接採用了softmax進行多分類,直接將28*28的圖片轉換成為了784維的向量作為輸入,然後通過一個(784,10)的權重,將輸入轉換成一個10維的向量,最後再將對每一個數字預測一個概率,概率最大的數字就是預測的結果。因為,直接將圖片轉成一個784維的向量,丟棄了圖片原有的結構資訊,但是最後對於測試集準確率還是可以達到91%。這一篇,介紹通過CNN來實現手寫數字的識別,準確率可以達到98%。
一、CNN(卷積神經網路 convolutional neural network)
今年,可以說是人工智慧被炒的最火的一年了。以致於python的使用快趕上java,人工智慧能夠這麼火,當然也離不開CNN。CNN的使用範圍也很廣在語音識別、自然語言處理、影象處理都能看見它的身影。剛剛開始聽見卷積的時候,給我一種深不可測的感覺。在網上也看了很多關於卷積的文章,對於卷積的介紹也是非常詳細的,這裡我推薦三篇文章對於卷積的介紹寫的還是非常不錯的(PS是英文的)
下面,簡單的畫一個圖,如何通過卷積來達到目標。當,我們在訓練一個卷積神經網路的時候,需要大量的資料、設定卷積的層數、卷積核大小、池化的方式(最大、平均)、損失函式、設定目標(如:手寫數字的識別,輸出一個10維向量),然後讓卷積神經網路通過不斷的訓練更新引數向我們設定的目標靠近,最後我們可以通過這些引數來預測樣本。至於,卷積的工作方式看上面的推薦的文章,有非常詳細的介紹。可能你會覺得這個東西有點抽象,網上有大牛將卷積的過程視覺化,具體的可以參考它的git專案
二、卷積神經網路的結構圖
上面的流程圖,是整個卷積網路的一個結構圖。我們在使用TensorFlow實現這個結構的時候,其實還是非常簡單的,只需要設定卷積核的大小,這裡設計的是5*5,邊距的填充方式,卷積的個數、啟用函式、池化的方式、輸出類別的個數,在最後我會給出使用TensorFlow實現整個結構的程式碼。下面我會對這個結構提幾個問題並解答:
1、為什麼卷積核的大小要設定成5*5,需要32個卷積?
卷積核的大小其實你可以自己隨便設定,如:3*3、5*5、7*7、9*9等,一般都為奇數,卷積的個數也是自己設定的,32個卷積的意思,代表的是你要提取原圖上32個特徵(每一個卷積提取一種特徵)。
2、為什麼要使用RELU啟用函式?
設定啟用函式的目的是保證結果輸出的非線性化,RELU啟用函式需要大於一個閾值,才會有輸出,和人的神經元結構很像,啟用函式的種類有很多,RELU的變種就有很多,在卷積神經網路中經常使用的啟用函式有RELU和tanh。
3、為什麼通過一個5*5的卷積和池化之後,原圖28*28的影象就變成了14*14?
輸入28*28的影象通過5*5的卷積之後,輸出還是28*28,這和卷積的方式有關,設定步長為1,如果對28*28的影象設定不填充邊距,那麼輸出影象的大小應該是(28-5)/1 + 1,輸出影象應該是24*24,如果我們將原圖的填充邊距設定為2(在原圖的周圍填充兩圈全0),來保證輸入影象和輸出影象的大小一致,這個時候的計算公式(28-5+2*2)/1 + 1,輸出影象的大小還是和原圖保持一致。這樣做的目的,是為了防止輸入影象經過卷積之後過快的衰減,因為有時候我們設計的卷積網路層數可能達到上100層,而填充0並不會對結果有影響。常見池化的方式有兩種,均值和最大值,池化核的大小設定為2*2,代表是從2*2中選出一個值(平均值或者最大值),所以一個28*28的影象再經過2*2的池化之後就變成了14*14,池化的目的是為了減少引數而且還可以很好的保證影象的特徵。
三、實現程式碼
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
#初始化權重函式
def weight_variable(shape):
initial = tf.truncated_normal(shape,stddev=0.1);
return tf.Variable(initial)
#初始化偏置項
def bias_variable(shape):
initial = tf.constant(0.1,shape=shape)
return tf.Variable(initial)
#定義卷積函式
def conv2d(x,w):
return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME')
#定義一個2*2的最大池化層
def max_pool_2_2(x):
return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
if __name__ == "__main__":
#定義輸入變數
x = tf.placeholder("float",shape=[None,784])
#定義輸出變數
y_ = tf.placeholder("float",shape=[None,10])
#初始化權重,第一層卷積,32的意思代表的是輸出32個通道
# 其實,也就是設定32個卷積,每一個卷積都會對影象進行卷積操作
w_conv1 = weight_variable([5,5,1,32])
#初始化偏置項
b_conv1 = bias_variable([32])
#將輸入的x轉成一個4D向量,第2、3維對應圖片的寬高,最後一維代表圖片的顏色通道數
# 輸入的影象為灰度圖,所以通道數為1,如果是RGB圖,通道數為3
# tf.reshape(x,[-1,28,28,1])的意思是將x自動轉換成28*28*1的陣列
# -1的意思是代表不知道x的shape,它會按照後面的設定進行轉換
x_image = tf.reshape(x,[-1,28,28,1])
# 卷積並激活
h_conv1 = tf.nn.relu(conv2d(x_image,w_conv1) + b_conv1)
#池化
h_pool1 = max_pool_2_2(h_conv1)
#第二層卷積
#初始權重
w_conv2 = weight_variable([5,5,32,64])
#初始化偏置項
b_conv2 = bias_variable([64])
#將第一層卷積池化後的結果作為第二層卷積的輸入
h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2) + b_conv2)
#池化
h_pool2 = max_pool_2_2(h_conv2)
# 設定全連線層的權重
w_fc1 = weight_variable([7*7*64,1024])
# 設定全連線層的偏置
b_fc1 = bias_variable([1024])
# 將第二層卷積池化後的結果,轉成一個7*7*64的陣列
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
# 通過全連線之後並激活
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1) + b_fc1)
# 防止過擬合
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1,keep_prob)
#輸出層
w_fc2 = weight_variable([1024,10])
b_fc2 = bias_variable([10])
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2) + b_fc2)
#日誌輸出,每迭代100次輸出一次日誌
#定義交叉熵為損失函式
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))
#最小化交叉熵
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#計算準確率
correct_prediction = tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
sess = tf.Session()
sess.run(tf.initialize_all_variables())
# 下載minist的手寫數字的資料集
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
for i in range(20000):
batch = mnist.train.next_batch(50)
if i % 100 == 0:
train_accuracy = accuracy.eval(session=sess,feed_dict={x:batch[0],y_:batch[1],keep_prob:1.0})
print("step %d,training accuracy %g"%(i,train_accuracy))
train_step.run(session = sess,feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5})
print("test accuracy %g" % accuracy.eval(session=sess,feed_dict={
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
#test accuracy 0.9919