TensorFlow學習——第三章(一)
阿新 • • 發佈:2018-12-05
MNIST數字識別問題
全連線層實現手寫數字識別
採用了L2正則化、滑動平均模型和指數衰減學習率
訓練結果為:訓練集93%,驗證集95.36%,測試集95.01%
第一部分:前向傳播和網路引數
# -*- coding: utf-8 -*-
"""Forward propagation and network parameters for the MNIST classifier.

Defines the two-layer fully-connected network used by the training and
evaluation scripts.  Weight variables register their L2 regularization
terms into the 'losses' collection so the training script can sum them.
"""
import tensorflow as tf

# Network architecture constants.
INPUT_NODE = 784    # input-layer nodes (28*28 flattened pixels)
OUTPUT_NODE = 10    # output-layer nodes (one per digit class)
LAYER1_NODE = 500   # hidden-layer nodes


def get_weight_variable(shape, regularizer):
    """Create (or reuse) a weight variable and collect its L2 loss.

    Args:
        shape: shape of the weight tensor.
        regularizer: callable mapping a tensor to a regularization loss,
            or None to skip regularization.

    Returns:
        The 'weights' variable in the current variable scope.
    """
    weights = tf.get_variable(
        'weights', shape,
        initializer=tf.truncated_normal_initializer(mean=0, stddev=0.1))
    # Stash the L2 term so train code can tf.add_n(tf.get_collection('losses')).
    if regularizer is not None:   # fixed: identity comparison, not != None
        tf.add_to_collection('losses', regularizer(weights))
    return weights


def inference(input_tensor, regularizer, avg_class, reuse):
    """Compute logits for a batch of flattened MNIST images.

    Args:
        input_tensor: float tensor of shape [None, INPUT_NODE].
        regularizer: L2 regularizer passed to get_weight_variable, or None.
        avg_class: tf.train.ExponentialMovingAverage whose shadow values
            should be used for the forward pass, or None for raw weights.
        reuse: variable-scope reuse flag (False on first build,
            tf.AUTO_REUSE/True on subsequent builds).

    Returns:
        Logits tensor of shape [None, OUTPUT_NODE] (softmax is applied
        later, inside the loss).
    """
    # Layer 1: fully connected + ReLU.
    with tf.variable_scope('layer1', reuse=reuse):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable(
            'biases', [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
        if avg_class is None:   # fixed: identity comparison, not == None
            layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
        else:
            # Evaluate with the moving-average (shadow) parameter values.
            layer1 = tf.nn.relu(
                tf.matmul(input_tensor, avg_class.average(weights))
                + avg_class.average(biases))

    # Layer 2: fully connected, raw logits.
    with tf.variable_scope('layer2', reuse=reuse):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable(
            'biases', [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
        if avg_class is None:
            layer2 = tf.matmul(layer1, weights) + biases
        else:
            layer2 = (tf.matmul(layer1, avg_class.average(weights))
                      + avg_class.average(biases))
    return layer2
第二部分:訓練,包括訓練集和驗證集
# -*- coding: utf-8 -*-
"""Training script for the MNIST classifier (training + validation curves).

Uses L2 regularization, an exponential moving average over trainable
variables, and an exponentially decaying learning rate.
"""
import os

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_inference

# Hyper-parameters.
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARAZTION_RATE = 0.0001   # L2 regularization strength (name kept for compat)
TRAINING_STEP = 200
MOVING_AVERAGE_DECAY = 0.99

# Checkpoint location.
MODEL_SAVE_PATH = './'
MODEL_NAME = 'model.ckpt'

# Per-step curves, filled by train() and plotted at the end.
train_acc, valid_acc = [], []
train_loss, valid_loss = [], []
epochs = []


def train(mnist):
    """Build the graph, run TRAINING_STEP steps, plot curves and save.

    Args:
        mnist: the DataSets object returned by input_data.read_data_sets.
    """
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')

    regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
    y = mnist_inference.inference(x, regularizer=regularizer, avg_class=None, reuse=False)

    global_step = tf.Variable(0, trainable=False)

    # Maintain shadow (moving-average) copies of all trainable variables.
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())

    # Labels are one-hot, so argmax converts them for the sparse loss.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # Total loss = data loss + L2 terms collected by mnist_inference.
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

    # Decay the learning rate once per epoch (num_examples / BATCH_SIZE steps).
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step,
        mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # One op that applies the gradient step AND the moving-average update.
    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name='train')

    # Forward pass using the shadow values, for evaluation only.
    average_y = mnist_inference.inference(
        x, regularizer=regularizer, avg_class=variable_averages, reuse=tf.AUTO_REUSE)
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    # tf.cast turns the boolean matches into floats so we can average them.
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Persistence helper.
    saver = tf.train.Saver()
    # Validation data (fed whole, not batched).
    validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}

    with tf.Session() as sess:
        # BUG FIX: tf.initialize_all_variables() is deprecated/removed;
        # use the supported global initializer.
        tf.global_variables_initializer().run()
        for i in range(TRAINING_STEP):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, tra_loss, step = sess.run(
                [train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            # BUG FIX: the original ran sess.run([loss], ...) and appended
            # the resulting 1-element LIST to valid_loss, corrupting the
            # loss plot; run the tensor directly to get a scalar.
            val_loss = sess.run(loss, feed_dict=validate_feed)

            epochs.append(step)
            train_acc.append(sess.run(accuracy, feed_dict={x: xs, y_: ys}))
            train_loss.append(tra_loss)
            valid_acc.append(sess.run(accuracy, feed_dict=validate_feed))
            valid_loss.append(val_loss)

            # Progress report every 10 steps (original comment said 100,
            # but the condition has always been % 10).
            if (i + 1) % 10 == 0:
                print('<==%d==>,loss on training batch is %g.' % (i + 1, tra_loss))
                print(train_acc[-1])
                print(valid_acc[-1])

        # Plot loss (left) and accuracy (right) curves.
        plt.figure(1)
        plt.grid(True)
        plt.subplot(1, 2, 1)
        plt.plot(epochs, train_loss, color='red', label='train')
        plt.plot(epochs, valid_loss, color='blue', label='valid')
        plt.legend()
        plt.xlabel('Epochs', fontsize=15)
        plt.ylabel('Y', fontsize=15)
        plt.title('Loss', fontsize=15)
        plt.subplot(1, 2, 2)
        plt.plot(epochs, train_acc, color='red', label='train')
        plt.plot(epochs, valid_acc, color='blue', label='valid')
        plt.legend()
        plt.xlabel('Epochs', fontsize=15)
        plt.ylabel('Y', fontsize=15)
        plt.title('Acc', fontsize=15)
        plt.show()

        saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME))


def main(argv=None):
    mnist = input_data.read_data_sets(
        'E:/User-Duanduan/python/Deep-Learning/tensorflow/data/MNIST_data/',
        one_hot=True)
    train(mnist)


if __name__ == '__main__':
    tf.app.run()
第三部分:測試集
# -*- coding: utf-8 -*-
"""Evaluate the trained MNIST model on the test set.

Restores the moving-average (shadow) parameter values saved by the
training script, reports test accuracy, and shows one sample prediction.
"""
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_inference
import mnist_train


def evaluate(mnist):
    """Restore the checkpoint and report accuracy on mnist.test.

    Args:
        mnist: the DataSets object returned by input_data.read_data_sets.
    """
    with tf.Graph().as_default() as g:
        # Input placeholders matching the training graph.
        x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
        test_feed = {x: mnist.test.images, y_: mnist.test.labels}

        # One fixed sample for a qualitative spot check.
        show_image = mnist.test.images[1000]
        label = mnist.test.labels[1000]
        one_image = {x: [show_image], y_: [label]}
        # True class = index of the 1 in the one-hot label.
        result_image = label.tolist().index(max(label.tolist()))

        # Build the forward pass on raw variables; the Saver below maps the
        # moving-average shadow values onto them when restoring.
        y = mnist_inference.inference(x, None, None, reuse=False)
        one_result = tf.argmax(y, 1)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        variable_averages = tf.train.ExponentialMovingAverage(
            mnist_train.MOVING_AVERAGE_DECAY)
        variable_to_restore = variable_averages.variables_to_restore()
        saver = tf.train.Saver(variable_to_restore)

        with tf.Session() as sess:
            # Load the checkpoint written by the training script.
            saver.restore(sess, './model.ckpt')
            accuracy_score = sess.run(accuracy, feed_dict=test_feed)
            print('Test accuracy is %g%%' % (accuracy_score * 100))

            result = sess.run(one_result, feed_dict=one_image)
            # BUG FIX: one_result evaluates to a length-1 array; %g needs a
            # scalar, so index element 0.  Also fixed the 'predtion' typo
            # in the message.
            print('Actual:%g,prediction:%g' % (result_image, result[0]))

            # Display the sample image (eval() uses this session as default).
            show_image = tf.reshape(show_image, [28, 28])
            plt.figure('Show')
            plt.imshow(show_image.eval())
            plt.show()


def main(argv=None):
    mnist = input_data.read_data_sets(
        'E:/User-Duanduan/python/Deep-Learning/tensorflow/data/MNIST_data/',
        one_hot=True)
    evaluate(mnist)


if __name__ == '__main__':
    tf.app.run()