深度學習之 MNIST 手寫數字識別入門
阿新 • 發佈:2019-02-11
使用 TensorFlow 框架和 Python 實現一個簡單的全連接神經網路(兩個隱藏層),並進行調參。程式碼如下:
#! /usr/bin/python
# -*- coding:utf-8 -*-
"""A simple MNIST classifier: a two-hidden-layer ReLU MLP trained with SGD.

Written against the TensorFlow 1.x graph API (placeholders, sessions,
tf.contrib and the tutorial MNIST reader do not exist in TensorFlow 2.x).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

FLAGS = None

# Load the data set; labels are one-hot encoded (one_hot=True).
# NOTE(review): 'minst' looks like a typo of 'mnist'; left unchanged so an
# existing download directory is still found.
data_dir = './minst_data'
mnist = input_data.read_data_sets(data_dir, one_hot=True)

# Number of units in each hidden layer.
h1_nodes = 1400
h2_nodes = 1400

# Input placeholder: each row is a flattened image with 784 features.
x = tf.placeholder(tf.float32, [None, 784])

# Layer 1 weights/biases. Weights are standard-normal samples scaled by
# sqrt(2 / fan_in) (He initialization, suited to ReLU activations).
W1 = tf.Variable(
    tf.cast(np.random.randn(784, h1_nodes), tf.float32) * np.sqrt(2.0 / 784))
b1 = tf.Variable(tf.zeros([h1_nodes]))

# Layer 2 weights/biases, same initialization scheme.
W2 = tf.Variable(
    tf.cast(np.random.randn(h1_nodes, h2_nodes), tf.float32)
    * np.sqrt(2.0 / h1_nodes))
b2 = tf.Variable(tf.zeros([h2_nodes]))

# Output layer weights/biases for the 10 classes, zero-initialized.
W3 = tf.Variable(tf.zeros([h2_nodes, 10]))
b3 = tf.Variable(tf.zeros([10]))

# Hidden-layer activations.
h1 = tf.nn.relu(tf.matmul(x, W1) + b1)
h2 = tf.nn.relu(tf.matmul(h1, W2) + b2)

# Un-normalized logits; the softmax is applied inside the loss op.
y = tf.matmul(h2, W3) + b3

# Ground-truth placeholder (one-hot labels).
y_ = tf.placeholder(tf.float32, [None, 10])

# Mean softmax cross-entropy over the batch. The keyword argument is
# `labels` (plural); `label=` is rejected with a TypeError.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))

# L2 regularization on all three weight matrices (biases are excluded).
Lambda = 0.0004
regularizer = tf.contrib.layers.l2_regularizer(Lambda)
regularization = regularizer(W1) + regularizer(W2) + regularizer(W3)

# Total objective.
loss = cross_entropy + regularization

# Plain gradient-descent training step.
# NOTE(review): learning_rate = 1 is large for plain SGD; lower it if the
# loss diverges.
learning_rate = 1
training_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

# Accuracy: fraction of samples whose arg-max prediction equals the arg-max
# of the one-hot label. tf.equal compares element-wise; tf.reduce_mean would
# treat the second argmax tensor as an axis argument.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Built after every variable exists so all of them get initialized.
init_op = tf.global_variables_initializer()

batch_size = 100
# Number of mini-batches per epoch.
batchs = mnist.train.num_examples // batch_size
epoch_num = 20

# The `with` block closes the session (releasing its resources) when
# training finishes or an exception escapes.
with tf.Session() as sess:
    sess.run(init_op)
    for epoch in range(epoch_num):
        for _ in range(batchs):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(training_step, feed_dict={x: batch_xs, y_: batch_ys})

        # Loss on the epoch's last mini-batch: a cheap progress signal.
        loss_out = sess.run(loss, feed_dict={x: batch_xs, y_: batch_ys})
        train_accuracy = sess.run(
            accuracy,
            feed_dict={x: mnist.train.images, y_: mnist.train.labels})
        # Generalization is measured on the held-out test split.
        test_accuracy = sess.run(
            accuracy,
            feed_dict={x: mnist.test.images, y_: mnist.test.labels})
        # Report the evaluated loss value, not the symbolic `loss` tensor.
        print("epoch" + str(epoch) + "--train_accuracy:" + str(train_accuracy)
              + "--test_accuracy:" + str(test_accuracy)
              + "--loss:" + str(loss_out))