
Tensorflow+MNIST+CNN+Model Saving and Loading

The full example below trains a small CNN on MNIST, keeps the best-performing checkpoint on disk with tf.train.Saver, and restores it later for testing.

# coding: utf-8
import tensorflow as tf
import numpy as np
from utils import *
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)  # read the MNIST dataset with one-hot encoded labels
num_steps = 10000  # number of training iterations
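
The script pulls weight_variable, bias_variable, conv2d, and max_pool_2x2 from a utils module that is not shown here. A minimal sketch of what utils presumably contains, assuming the standard helper definitions from the TensorFlow MNIST tutorial (redundant if the real utils.py is on the path):

# Assumed contents of utils.py (not shown in the original post); these are
# the standard helpers from the TensorFlow MNIST tutorial.
def weight_variable(shape):
    # small Gaussian noise breaks symmetry between units
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    # small positive bias avoids dead ReLUs at initialization
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def conv2d(x, W):
    # stride-1 convolution with SAME padding keeps the spatial size unchanged
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    # 2x2 max pooling halves the spatial resolution
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')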
class mnistmodel(object):
    def __init__(self):
        self._build_model()

    def _build_model(self):
        self.images = tf.placeholder(tf.float32, [None, 784])  # placeholder for input images
        self.labels = tf.placeholder(tf.float32, [None, 10])   # placeholder for labels

        with tf.variable_scope('feature_extractor'):  # feature extraction (two convolutional layers)
            # reshape the flat input back into [28, 28, 1] images
            self.processimages = tf.reshape(self.images, [-1, 28, 28, 1])
            # first layer: 32 5x5 convolution kernels and 32 biases,
            # convolution with ReLU activation, then max pooling
            W_conv0 = weight_variable([5, 5, 1, 32])
            b_conv0 = bias_variable([32])
            h_conv0 = tf.nn.relu(conv2d(self.processimages, W_conv0) + b_conv0)
            h_pool0 = max_pool_2x2(h_conv0)
            # second layer, analogous to the first
            W_conv1 = weight_variable([5, 5, 32, 48])
            b_conv1 = bias_variable([48])
            h_conv1 = tf.nn.relu(conv2d(h_pool0, W_conv1) + b_conv1)
            h_pool1 = max_pool_2x2(h_conv1)
            # flatten the second layer's output into a 2-D matrix for the fully connected layers
            self.feature = tf.reshape(h_pool1, [-1, 7 * 7 * 48])

        with tf.variable_scope('label_predictor'):  # label prediction (two fully connected layers)
            # map from 7*7*48 down to 100
            W_fc0 = weight_variable([7 * 7 * 48, 100])
            b_fc0 = bias_variable([100])
            h_fc0 = tf.nn.relu(tf.matmul(self.feature, W_fc0) + b_fc0)
            # map from 100 down to 10 for the classification step
            W_fc1 = weight_variable([100, 10])
            b_fc1 = bias_variable([10])
            logits = tf.matmul(h_fc0, W_fc1) + b_fc1
            # softmax turns the raw logits into relative probabilities
            self.pred = tf.nn.softmax(logits)
            # cross-entropy loss for label prediction
            self.pred_loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=self.labels)
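
The magic number 7 * 7 * 48 in the flatten step follows from the two SAME-padded 2x2 poolings: 28x28 -> 14x14 -> 7x7, with 48 channels after the second convolution. As a quick sanity check (a throwaway sketch, not part of the training script), the static shapes can be printed after instantiating the model in a scratch graph:

# Hypothetical shape check in a scratch graph (not in the original post).
with tf.Graph().as_default():
    m = mnistmodel()
    print(m.processimages.get_shape())  # (?, 28, 28, 1)
    print(m.feature.get_shape())        # (?, 2352), i.e. 7 * 7 * 48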
graph = tf.get_default_graph()
with graph.as_default():
    model = mnistmodel()
    learning_rate = tf.placeholder(tf.float32, [])
    pred_loss = tf.reduce_mean(model.pred_loss)
    # optimize the loss with stochastic gradient descent
    train_op = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(pred_loss)
    # compute label prediction accuracy
    correct_label_pred = tf.equal(tf.argmax(model.labels, 1), tf.argmax(model.pred, 1))
    label_acc = tf.reduce_mean(tf.cast(correct_label_pred, tf.float32))

with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()
    saver = tf.train.Saver(max_to_keep=1)  # Saver object for persisting the trained model
    max_acc = 0
    is_train = True
    # training loop
    if is_train:
        for i in range(num_steps):
            lr = 0.001
            # MNIST's built-in next_batch generates a batch of size 100
            batch = mnist.train.next_batch(100)
            _, p_loss, l_acc = sess.run([train_op, pred_loss, label_acc],
                                        feed_dict={model.images: batch[0],
                                                   model.labels: batch[1],
                                                   learning_rate: lr})
            print('step:{} pred_loss:{} l_acc: {}'.format(i, p_loss, l_acc))
            if i % 100 == 0:
                # measure accuracy on the test set; in the end only the
                # most accurate model seen so far is kept on disk
                test_acc = sess.run(label_acc,
                                    feed_dict={model.images: mnist.test.images,
                                               model.labels: mnist.test.labels})
                print('step: {} test_acc: {}'.format(i, test_acc))
                if test_acc > max_acc:
                    max_acc = test_acc
                    saver.save(sess, './ckpt/mnist.ckpt', global_step=i + 1)
    # otherwise, read the model back from the checkpoint files and test it
    else:
        model_file = tf.train.latest_checkpoint('./ckpt/')
        saver.restore(sess, model_file)
        test_acc = sess.run(label_acc,
                            feed_dict={model.images: mnist.test.images,
                                       model.labels: mnist.test.labels})
        print('test_acc: {}'.format(test_acc))
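
Each saver.save call writes mnist.ckpt-&lt;step&gt;.data-00000-of-00001, .index, and .meta files plus a checkpoint index file; tf.train.latest_checkpoint reads that index to find the newest prefix. Because the checkpoint stores only variable values, restoring in a fresh process means rebuilding the same graph first and then loading the weights into it. A minimal sketch of standalone inference against the checkpoint saved above, assuming mnistmodel and the utils helpers are importable in the new script:

# Hypothetical standalone inference script; assumes mnistmodel (and the utils
# helpers) are importable and ./ckpt/ holds the checkpoint written above.
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

with tf.Graph().as_default():
    model = mnistmodel()  # rebuild the exact same graph structure
    saver = tf.train.Saver()
    with tf.Session() as sess:
        ckpt = tf.train.latest_checkpoint('./ckpt/')  # e.g. './ckpt/mnist.ckpt-9901'
        saver.restore(sess, ckpt)  # restore assigns every variable; no initializer needed
        image = mnist.test.images[:1]  # a single 784-dim test image
        probs = sess.run(model.pred, feed_dict={model.images: image})
        print('predicted digit:', np.argmax(probs))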