
TensorFlow Programming Basics: the MNIST Handwritten-Digit Recognition Experiment, plus an Understanding of cross_entropy


It has been a long time since I sat down and wrote anything. Lately I seem to have slipped back into my high-school state: sleeping badly, unable to focus on studying, a vicious cycle. I have finally adjusted a little. Listening to instrumental music just now, I suddenly felt wistful, and those happy university days slowly came back to me all at once. I wonder how my old university friends are doing. 剛剛, who is sitting the graduate entrance exam: I hope he gets in. 菜頭, who is interning: I hope he escapes the grind soon. 小瑞哥: I hope his results come through soon. 範爺 and 熊健: have a happy time in grad school! 天哥 and 也哥: get that marriage certificate soon! Those happy days we cannot go back to. I hope you are all doing well!

I have recently started working with TensorFlow, probably because it is the framework used in the papers I am reading. Personally I still find PyTorch more convenient, but after carefully reading through the simplest MNIST handwritten-digit recognition program, the two feel much the same; the key is to understand TensorFlow's way of thinking. At the end I will write down my own take on cross entropy. Enough rambling, let's look at the code!
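Before the code: the labels below are loaded with `one_hot=True`, which turns each digit label 0–9 into a 10-dimensional indicator vector. A minimal sketch of what that encoding does (the helper name `one_hot` is mine for illustration, not part of the tutorial code):

```python
import numpy as np

def one_hot(label, num_classes=10):
    # Indicator vector: 1.0 at the label's index, 0.0 elsewhere
    vec = np.zeros(num_classes, dtype=np.float32)
    vec[label] = 1.0
    return vec

print(one_hot(3))  # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
```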

```python
# -*- coding: utf-8 -*-
"""
Created on Sun Nov 11 16:14:38 2018

@author: Yang
"""

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import pylab  # plotting module

# Load the dataset with one-hot encoded labels
mnist = input_data.read_data_sets("/MNIST_data", one_hot=True)

# Reset the default graph. A graph represents a computation and holds many
# Variables and ops; without a reset, re-running this script can fail because
# the same variables get defined twice.
tf.reset_default_graph()

x = tf.placeholder(tf.float32, [None, 784])  # placeholder, fed via feed_dict
y = tf.placeholder(tf.float32, [None, 10])   # placeholder, fed via feed_dict

# Parameters to learn are defined as Variables so they can be updated
W = tf.Variable(tf.random_normal([784, 10]))
b = tf.Variable(tf.zeros([10]))

# Construct the model
pred = tf.nn.softmax(tf.matmul(x, W) + b)

# Cross-entropy cost
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices=1))

learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# Set the training parameters
training_epoch = 25
batch_size = 100
display_step = 1
saver = tf.train.Saver()
model_path = "log/kerwinsmodel.ckpt"

# Start the session
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(training_epoch):
        avg_cost = 0
        total_batch = int(mnist.train.num_examples / batch_size)
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            _, c = sess.run([optimizer, cost],
                            feed_dict={x: batch_xs, y: batch_ys})
            avg_cost += c / total_batch
        if (epoch + 1) % display_step == 0:
            print("Epoch:", "%04d" % (epoch + 1),
                  "cost=", "{:.9f}".format(avg_cost))

    print("Finish!")

    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({x: mnist.test.images,
                                      y: mnist.test.labels}))

    save_path = saver.save(sess, model_path)
    print("Model saved in file: %s" % save_path)


# Restore the saved model
print("Starting 2nd session...")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, model_path)

    # Test the model (tf.arg_max is deprecated; use tf.argmax)
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Compute the accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({x: mnist.test.images,
                                      y: mnist.test.labels}))

    # Predict two samples and display them
    output = tf.argmax(pred, 1)
    batch_xs, batch_ys = mnist.train.next_batch(2)
    outputval, predv = sess.run([output, pred], feed_dict={x: batch_xs})
    print(outputval, predv, batch_ys)

    im = batch_xs[0].reshape(-1, 28)
    pylab.imshow(im)
    pylab.show()

    im = batch_xs[1].reshape(-1, 28)
    pylab.imshow(im)
    pylab.show()
```
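As promised, a short note on the cross-entropy cost used above. With one-hot labels `y` and softmax outputs `pred`, the cost line computes the mean over the batch of `-sum(y * log(pred))`, which collapses to minus the log of the probability the model assigns to the true class: the more confident and correct the prediction, the smaller the loss. A minimal NumPy sketch (the label and prediction values are made up for illustration):

```python
import numpy as np

# One training example: true class is index 1 (one-hot), and the model's
# softmax output puts probability 0.7 on that class.
y = np.array([[0.0, 1.0, 0.0]])
pred = np.array([[0.1, 0.7, 0.2]])

# Same formula as the TensorFlow cost above:
# mean over the batch of -sum(y * log(pred))
cross_entropy = np.mean(-np.sum(y * np.log(pred), axis=1))
print(cross_entropy)  # equals -log(0.7), about 0.357
```

Because `y` is one-hot, every term except the true class vanishes, so driving the loss toward zero means driving the predicted probability of the correct digit toward one.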

