程式人生 > 深度學習之 MNIST 手寫數字識別入門

深度學習之 MNIST 手寫數字識別入門

使用 TensorFlow 框架和 Python，學習實現簡單的神經網路，並進行調參，程式碼如下：
 

#! /usr/bin/python
# -*- coding:utf-8 -*-

"""
a simple mnist classifier

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

import tensorflow as tf

FLAGS = None

# ---- Data -----------------------------------------------------------------
# Read MNIST from data_dir with one-hot encoded labels.
# NOTE(review): the directory name is spelled "minst"; kept unchanged so any
# data already stored there is still found.
data_dir = './minst_data'
mnist = input_data.read_data_sets(data_dir, one_hot=True)

# ---- Network: 784 -> h1_nodes -> h2_nodes -> 10 ---------------------------
h1_nodes = 1400  # units in the first hidden layer
h2_nodes = 1400  # units in the second hidden layer


def _he_normal(fan_in, fan_out):
    """Return a float32 (fan_in, fan_out) tensor drawn from N(0, 2 / fan_in).

    Standard-normal samples are drawn with NumPy, cast to float32 and scaled
    by sqrt(2 / fan_in) (He initialization, suited to ReLU layers).
    """
    samples = np.random.randn(fan_in, fan_out)
    return tf.cast(samples, tf.float32) * np.sqrt(2.0 / fan_in)


# Input batch of flattened 784-pixel images; batch size left open.
x = tf.placeholder(tf.float32, shape=[None, 784])

# First hidden layer parameters.
W1 = tf.Variable(_he_normal(784, h1_nodes))
b1 = tf.Variable(tf.zeros([h1_nodes]))
# Second hidden layer parameters.
W2 = tf.Variable(_he_normal(h1_nodes, h2_nodes))
b2 = tf.Variable(tf.zeros([h2_nodes]))
# Output layer parameters, both initialized to zero.
W3 = tf.Variable(tf.zeros([h2_nodes, 10]))
b3 = tf.Variable(tf.zeros([10]))

# Two ReLU hidden layers.
h1 = tf.nn.relu(tf.matmul(x, W1) + b1)
h2 = tf.nn.relu(tf.matmul(h1, W2) + b2)
# Unnormalized (pre-softmax) logits, one column per digit class.
y = tf.matmul(h2, W3) + b3

# Placeholder for the ground-truth one-hot labels.
y_ = tf.placeholder(tf.float32, [None, 10])

# Mean softmax cross-entropy between the logits y and the labels y_.
# The keyword argument is `labels` (plural); `label` is not accepted.
cross_entroy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

# L2 penalty on the three weight matrices (biases are not regularized).
Lambda = 0.0004  # regularization strength
regularizer = tf.contrib.layers.l2_regularizer(Lambda)
regularization = regularizer(W1) + regularizer(W2) + regularizer(W3)

# Total training objective: data term plus weight penalty.
loss = cross_entroy + regularization

# Plain gradient-descent step on the total loss.
# NOTE(review): a learning rate of 1 is unusually large for plain SGD on a
# network this wide; lower it if the loss diverges to NaN.
learning_rate = 1
training_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# Create the session and initialize every variable defined so far.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

# Mini-batch size and number of mini-batches run per epoch
# (integer division: any remainder of the training set is not counted).
batch_size = 100
batchs = mnist.train.num_examples // batch_size

# Per-example correctness: the predicted class (argmax of the logits) must
# equal the true class (argmax of the one-hot label). This is an element-wise
# comparison, so it needs tf.equal; passing the second argmax to
# tf.reduce_mean would treat it as the reduction axis.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
# Accuracy: fraction of correctly classified examples in the fed data.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Training loop: run `batchs` SGD steps per epoch, then report accuracy on
# the full training and test splits plus the loss on the last mini-batch.
enpoch_num = 20
for epoch in range(enpoch_num):
    for _ in range(batchs):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        sess.run(training_step, feed_dict={x: batch_xs, y_: batch_ys})
    # Total loss (cross-entropy + L2) evaluated on the epoch's last mini-batch.
    loss_out = sess.run(loss, feed_dict={x: batch_xs, y_: batch_ys})
    train_accuracy = sess.run(
        accuracy, feed_dict={x: mnist.train.images, y_: mnist.train.labels})
    # Test accuracy must be measured on the held-out test split.
    test_accuracy = sess.run(
        accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    # Report the evaluated loss value, not the graph tensor `loss`.
    print("epoch" + str(epoch)
          + "--train_accuracy:" + str(train_accuracy)
          + "--test_accuracy:" + str(test_accuracy)
          + "--loss:" + str(loss_out))