1. 程式人生 > >tensorflow-保存與讀取使用模型

tensorflow-保存與讀取使用模型

手寫 odin nis 圖像 info one rom lac Coding

1、MNIST是深度學習的經典入門demo,他是由6萬張訓練圖片和1萬張測試圖片構成的,每張圖片都是2828大小(如下圖),而且都是黑白色構成(這裏的黑色是一個0-1的浮點數,黑色越深表示數值越靠近1),這些圖片是采集的不同的人手寫從0到9的數字。
下面先訓練識別數字模型
再保存模型
最後,讀取保存的模型,對數字圖片進行識別。

2、保存模型

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Feb  3 20:28:26 2019

@author: myhaspl
"""
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)

import tensorflow as tf
import os

x=tf.placeholder(tf.float32,[None,784])

w=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))

y=tf.nn.softmax(tf.matmul(x,w)+b)
y_=tf.placeholder(tf.float32,[None,10])
cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))

train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver()
for i in range(1000):
    sampleX,sampleY=mnist.train.next_batch(100)
    sess.run(train_step,feed_dict={x:sampleX,y_:sampleY})

print("訓練完成")
print("保存生成模型...")
model_dir="mnist_model"
model_name="ml1"
if not os.path.exists(model_dir):
    os.mkdir(model_dir)

saver.save(sess,os.path.join(model_dir,model_name))
print("保存生成模型成功")  
訓練完成
保存生成模型...
保存生成模型成功
[root@VM03centos learn]# ls mnist_model 
checkpoint  ml1.data-00000-of-00001  ml1.index  ml1.meta
[root@VM03centos learn]# ls MNISTdata 
t10k-images-idx3-ubyte.gz  t10k-labels-idx1-ubyte.gz  train-images-idx3-ubyte.gz  train-labels-idx1-ubyte.gz
[root@VM03centos learn]# 

讀取數字識別模型,對某個數字圖像進行識別

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Feb  3 20:28:26 2019

@author: myhaspl
"""
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("MNIST_data/",one_hot=True)

import tensorflow as tf

x=tf.placeholder(tf.float32,[None,784])

w=tf.Variable(tf.zeros([784,10]))
b=tf.Variable(tf.zeros([10]))

y=tf.nn.softmax(tf.matmul(x,w)+b)
y_=tf.placeholder(tf.float32,[None,10])
cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))

train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

init=tf.global_variables_initializer()
sess=tf.Session()
sess.run(init)
saver=tf.train.Saver()

print("讀取模型...")
saver.restore(sess,"mnist_model/ml1")
print("讀取模型完成")
print("根據模型進行計算...")
img=mnist.test.images[5]
result=sess.run(y,feed_dict={x:img.reshape(1,784)})
print("預測輸出結果:{}".format(result))
print("預測結果:{}".format(result.argmax()))
print("實際結果:{}".format(mnist.test.labels[5].argmax()))
讀取模型...
INFO:tensorflow:Restoring parameters from mnist_model/ml1
讀取模型完成
根據模型進行計算...
預測輸出結果:[[1.8999807e-06 9.8351490e-01 3.0815993e-03 4.3848301e-03 4.1427880e-05
  1.6864968e-04 7.6594086e-05 4.5587993e-03 3.2991443e-03 8.7222963e-04]]
預測結果:1
實際結果:1

tensorflow-保存與讀取使用模型