1. 程式人生 > >基於tensorflow的MNIST手寫數字識別(三)--神經網路篇

基於tensorflow的MNIST手寫數字識別(三)--神經網路篇

想想還是要說點什麼

    抱歉啊,第三篇姍姍來遲,確實是因為我懶,而不是忙什麼的,所以這次再加點料,以表示我的歉意。廢話不多說,我就直接開始講了。

加入神經網路的意義

  •     前面也講到了,使用普通的訓練方法,也可以進行識別,但是識別的精度不夠高,因此我們需要對其進行提升,其實MNIST官方提供了很多的組合方法以及測試精度,並做成了表格供我們選用,谷歌官方為了保證教學的簡單性,所以用了最簡單的卷積神經網路來提升這個的識別精度,原理是通過強化它的特徵(比如輪廓等),其實我也剛學,所以能看懂就說明它確實比較簡單。

    •     我的程式碼都是在0.7版本的tensorflow上實現的,建議看一下前兩篇文章先。

流程和步驟

    其實流程跟前面的差不多,只是在softmax前進行了卷積神經網路的操作,所也就不仔細提出了,這裡只說卷積神經網路的部分。
如第一篇文章所說,我們的卷積神經網路的,過程是卷積->池化->全連線.

# 卷積函式
# convolution
def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
#這裡tensorflow自己帶了conv2d函式做卷積,然而我們自定義了個函式,用於指定步長為1,邊緣處理為直接複製過來



# pooling
def max_pool_2x2(x): return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu=None, name=None)

Computes a 2-D convolution given 4-D input and filter tensors.

Given an input tensor of shape [batch, in_height, in_width, in_channels] and a filter / kernel tensor of shape [filter_height, filter_width, in_channels, out_channels], this op performs the following:

Flattens the filter to a 2-D matrix with shape [filter_height * filter_width * in_channels, output_channels].

Extracts image patches from the the input tensor to form a virtual tensor of shape [batch, out_height, out_width, filter_height * filter_width * in_channels].

For each patch, right-multiplies the filter matrix and the image patch vector.
In detail,

output[b, i, j, k] =
sum_{di, dj, q} input[b, strides[1] * i + di, strides[2] * j + dj, q] *
filter[di, dj, q, k]

Must have strides[0] = strides[3] = 1. For the most common case of the same horizontal and vertices strides, strides = [1, stride, stride, 1].

Args:

input: A Tensor. Must be one of the following types: float32, float64.

filter: A Tensor. Must have the same type as input.

strides: A list of ints. 1-D of length 4. The stride of the sliding window for each dimension of input.

padding: A string from: “SAME”, “VALID”. The type of padding algorithm to use.

use_cudnn_on_gpu: An optional bool. Defaults to True.

name: A name for the operation (optional).

Returns:

A Tensor. Has the same type as input.

tf.nn.max_pool(value, ksize, strides, padding, name=None)

Performs the max pooling on the input.

Args:

value: A 4-D Tensor with shape [batch, height, width, channels] and type float32, float64, qint8, quint8, qint32.

ksize: A list of ints that has length >= 4. The size of the window for each dimension of the input tensor.

strides: A list of ints that has length >= 4. The stride of the sliding window for each dimension of the input tensor.

padding: A string, either ‘VALID’ or ‘SAME’. The padding algorithm.

name: Optional name for the operation.

Returns:

A Tensor with the same type as value. The max pooled output tensor.

初始化權重和偏置值矩陣,值是空的,需要後期訓練。

def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape = shape)
    # print(tf.Variable(initial).eval())
    return tf.Variable(initial)
#這是做了兩次卷積和池化
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
這裡是做了全連線,還用了relu啟用函式(RELU在下面會提到)
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
#為了防止過擬合化,這裡用dropout來關閉一些連線(DROP下面會提到)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

然後得到的結果再跟之前的一樣,使用softmax等方法訓練即可得到引數。

RELU啟用函式

啟用函式有很多種,最常用的是以下三種

Sigmoid

將資料對映到0-1範圍內
#### 公式如下
這裡寫圖片描述

####函式影象如下
函式影象

Tanh

將資料對映到-1-1的範圍內

公式如下

這裡寫圖片描述

函式影象如下
這裡寫圖片描述

RELU

小於0的值就變成0,大於0的等於它本身

函式影象

這裡寫圖片描述

dropout的作用

  • 以前學習數學我們常用到一種方法,叫做待定係數法,就是給定2次函式上的幾個點,然後求得2次函式的引數。

  • 一樣的道理,我們這裡用格式訓練集訓練,最後訓練得到引數,其實就是在求得一個模型(函式),使得它能跟原始資料的曲線進行擬合(說白了,就是假裝原始資料都在我們計算出來的函式上)

  • 但是這樣不行啊,因為我們還需要對未知資料進行預測啊,如果原始的資料點都在(或者大多數都在)函式上了(這就是過擬合),那會被很多訓練資料誤導的,所以其實只要一個大致的趨勢函式就可以了

  • 所以Dropout函式就是用來,減少某些點的全連線(可以理解為把一些點去掉了),來防止過擬合

程式碼