1. 程式人生 > >tensorflow學習系列六:mnist從訓練儲存模型再到載入模型測試

tensorflow學習系列六:mnist從訓練儲存模型再到載入模型測試

    通過前面幾個系列的學習對tensorflow有了一個漸漸親切的感覺,本文主要是從tensorflow模型訓練與驗證的模型進行實踐一遍,以至於我們能夠通過tensorflow的訓練有一個整體的概念。下面主要是從訓練到儲存模型,然後載入模型進行預測。

# -*- coding: utf-8 -*-
"""
Created on Mon Jun 11 22:17:52 2018

func:搭建網路圖

@author: kuangyongjian
"""
import tensorflow as tf


#構建圖
class Network(object):
    
    def __init__(self):
        
        self.learning_rate = 0.001
        #機率已經訓練的次數
        self.global_step = tf.Variable(0,trainable = False)
        
        self.x = tf.placeholder(tf.float32,[None,784])
        self.label = tf.placeholder(tf.float32,[None,10])
        
        self.w = tf.Variable(tf.zeros([784,10]))
        self.b = tf.Variable(tf.zeros([10]))
        self.y = tf.nn.softmax(tf.matmul(self.x,self.w) + self.b)
        
        self.loss = -tf.reduce_mean(self.label * tf.log(self.y) + 1e-10)
        
        self.train = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss,global_step = self.global_step)
        
        predict = tf.equal(tf.argmax(self.label,1),tf.argmax(self.y,1))
        self.accuracy = tf.reduce_mean(tf.cast(predict,tf.float32))
    
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 12 09:16:52 2018

func:網路訓練,以及對應的模型儲存

@author: kuangyongjian
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from model import Network

CKPT_DIR = 'ckpt'

class Train(object):
    
    def __init__(self):
        self.net = Network()
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        self.data = input_data.read_data_sets('../data_set',one_hot = True)
        
    def train(self):
        batch_size = 64
        train_step = 10000
        step = 0
        #每隔1000步儲存一次模型
        save_interval = 1000
        
        #tf.train.Saver用於儲存訓練的結果
        #max to keep 用於設定最多儲存多少個模型
        #如果儲存的模型超過這個值,最舊的模型被刪除
        saver = tf.train.Saver(max_to_keep = 10)
        
        ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
        if ckpt and ckpt.get_checkpoint_state(CKPT_DIR):
            saver.restore(self.sess,ckpt.model_checkpoint_path)
            #讀取網路中的global_step的值,即當前已經訓練的次數
            step = self.sess.run(self.net.global_step)
            print('continue from')
            print('  -> Minibatch update : ',step)
            
        while step < train_step:
            x,label = self.data.train.next_batch(batch_size)
            _,loss = self.sess.run([self.net.train,self.net.loss],
                                   feed_dict = {self.net.x: x,self.net.label:label})
            
            step = self.sess.run(self.net.global_step)
            if step % 1000 == 0:
                print('第%6d步,當前loss: %.3f'%(step,loss))
                
            #模型儲存在ckpt資料夾下
            #模型檔名最後會增加global_step的值,比如2000的模型檔名為model-2000
            if step % save_interval == 0:
                saver.save(self.sess,CKPT_DIR + '/model',global_step = step)
    
    def calculate_accuracy(self):
        test_x = self.data.test.images
        test_label = self.data.test.labels
        acc = self.sess.run(self.net.accuracy,feed_dict = {self.net.x:test_x,self.net.label:test_label})
        
        print("準確率: %.3f,共測試了%d張圖片 " % (acc, len(test_label)))
            
                
if __name__ == '__main__':
    model = Train()
    model.train()
    model.calculate_accuracy()
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 12 09:36:55 2018

func:載入模型,進行模型測試

@author: kuangyongjian
"""
import tensorflow as tf
import numpy as np
from PIL import Image
from model import Network

CKPT_DIR = 'ckpt'


class Predict(object):
    
    def __init__(self):
        #清除預設圖的堆疊,並設定全域性圖為預設圖
        #若不進行清楚則在第二次載入的時候報錯,因為相當於重新載入了兩次
        tf.reset_default_graph() 
        self.net = Network()
        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
        
        #載入模型到sess中
        self.restore()
        print('load susess')
    
    def restore(self):
        saver = tf.train.Saver()
        ckpt = tf.train.get_checkpoint_state(CKPT_DIR)
        print(ckpt.model_checkpoint_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(self.sess,ckpt.model_checkpoint_path)
        else:
            raise FileNotFoundError('未儲存模型')
        
    def predict(self,image_path):
        #讀取圖片並灰度化
        img = Image.open(image_path).convert('L')
        flatten_img = np.reshape(img,784)
        x = np.array([1 - flatten_img])
        y = self.sess.run(self.net.y,feed_dict = {self.net.x:x})
        
        print(image_path)
        print(' Predict digit',np.argmax(y[0]))
        
        
if __name__ == '__main__':
    model = Predict()
    model.predict('0.png')
    model.predict('../test_images/1.png')
    model.predict('../test_images/4.png')

注意文中儲存模型和載入模型的方式,特別是在載入模型的時候比較容易出錯。

若有不當之處請指教,謝謝!