程式人生 > CNN實現mnist資料集數字識別

CNN實現mnist資料集數字識別

# -*- coding: utf-8 -*-


# @Time    : 2018/12/14 13:07
# @Author  : WenZhao
# @Email   : [email protected]
# @File    : mnistCnn-1.py
# @Software: PyCharm

"""CNN digit recognition on the MNIST dataset.

Building blocks of the convolutional network defined in this script:
    1. Convolution layer:            conv2d
    2. Non-linearity (activation):   tf.nn.relu / tf.nn.sigmoid / tf.nn.tanh
    3. Pooling layer:                tf.nn.max_pool / tf.nn.avg_pool
    4. Fully connected layer:        w*x + b
"""

import tensorflow as tf

# Download the MNIST dataset into ./data/MNIST_data/ and load it.
# one_hot=True yields label vectors matching the 10-column y_ placeholder below.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("./data/MNIST_data/", one_hot=True)

# Flattened input images: one row of 784 (= 28 * 28) pixel values per example.
x = tf.placeholder(tf.float32, shape=[None, 784])
# Target labels: one row of 10 one-hot class indicators per example.
y_ = tf.placeholder(tf.float32, shape=[None, 10])

# Restore the (batch, height, width, channels) layout the conv layers expect.
x_image = tf.reshape(x, [-1, 28, 28, 1])

def _conv_relu(inputs, num_outputs):
    """Apply a 5x5, stride-1, SAME-padded convolution followed by ReLU.

    tf.contrib.layers.convolution2d performs both the convolution and the
    activation, and creates its own trainable weight/bias variables.
    """
    return tf.contrib.layers.convolution2d(
        inputs,
        num_outputs=num_outputs,
        kernel_size=(5, 5),
        activation_fn=tf.nn.relu,
        stride=(1, 1),
        padding='SAME',
        trainable=True,
    )


def _max_pool_2x2(inputs):
    """2x2 max pooling with stride 2 (SAME padding): halves height and width, rounding up."""
    return tf.nn.max_pool(inputs, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')


# Stage 1: 28x28x1 -> conv -> 28x28x32 -> pool -> 14x14x32
conv2d_1 = _conv_relu(x_image, 32)
pool_1 = _max_pool_2x2(conv2d_1)

# Stage 2: 14x14x32 -> conv -> 14x14x64 -> pool -> 7x7x64
conv2d_2 = _conv_relu(pool_1, 64)
pool_2 = _max_pool_2x2(conv2d_2)

# Flatten the 7x7x64 feature maps into one vector per example.
pool2_flat = tf.reshape(pool_2, [-1, 7 * 7 * 64])

# Fully connected hidden layer with 1024 ReLU units.
fc_1 = tf.contrib.layers.fully_connected(pool2_flat, 1024, activation_fn=tf.nn.relu)

# Dropout: randomly zeroes hidden units during training to reduce overfitting.
# keep_prob is fed per run (0.5 while training, 1.0 when evaluating).
keep_prob = tf.placeholder(tf.float32)
fc1_drop = tf.nn.dropout(fc_1, keep_prob)


# Output layer. The linear class scores (logits) are kept separate from the
# softmax so the loss can be computed directly from logits; fc_2 still holds
# the per-class probabilities used for prediction.
logits = tf.contrib.layers.fully_connected(fc1_drop, 10, activation_fn=None)
fc_2 = tf.nn.softmax(logits)

# Cross-entropy summed over the batch. softmax_cross_entropy_with_logits fuses
# softmax and log, so it stays finite even when a predicted probability
# underflows to 0 (where -sum(y_ * log(softmax)) becomes inf/NaN and poisons
# the gradients). reduce_sum keeps the original loss scaling, so the gradient
# magnitude seen by the 1e-4 learning rate is unchanged.
loss = tf.reduce_sum(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=logits))
train_step = tf.train.GradientDescentOptimizer(0.0001).minimize(loss)

NUM_TRAIN_STEPS = 20000
TRAIN_BATCH_SIZE = 50
LOG_EVERY = 100

sess = tf.Session()
sess.run(tf.global_variables_initializer())
for step in range(NUM_TRAIN_STEPS):
    batch_xs, batch_ys = mnist.train.next_batch(TRAIN_BATCH_SIZE)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 0.5})
    if step % LOG_EVERY == 0:
        # Report the batch loss with dropout disabled so the printed value
        # reflects the model itself rather than a random dropout mask.
        print(sess.run(loss, feed_dict={x: batch_xs, y_: batch_ys, keep_prob: 1.0}))

# Accuracy: an example is correct when the most probable class matches the
# one-hot label's class.
correct_prediction = tf.equal(tf.argmax(fc_2, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# Per-run count of correct predictions, used for batched evaluation below.
num_correct = tf.reduce_sum(tf.cast(correct_prediction, tf.float32))

# Test-set accuracy, evaluated in fixed-size batches. Feeding all test images
# in one run materializes the full-size conv activations for every image at
# once (over 1 GB for the first conv layer alone). Summing correct counts and
# dividing by the total gives exactly the same accuracy as a single run.
EVAL_BATCH_SIZE = 1000
test_images = mnist.test.images
test_labels = mnist.test.labels
total_correct = 0.0
for start in range(0, len(test_images), EVAL_BATCH_SIZE):
    end = start + EVAL_BATCH_SIZE
    total_correct += sess.run(num_correct, feed_dict={
        x: test_images[start:end],
        y_: test_labels[start:end],
        keep_prob: 1.0,
    })
acc = total_correct / len(test_images)

print(acc)

# Release the session's resources now that training and evaluation are done.
sess.close()