1. 程式人生 > >第五章 MNIST數字識別問題(二)

第五章 MNIST數字識別問題(二)

AR dom post mod 變量名 elements 之前 spl with

4.1. ckpt文件保存方法

在對模型進行加載時候,需要定義出與原來的計算圖結構完全相同的計算圖,然後才能進行加載,並且不需要對定義出來的計算圖進行初始化操作。
這樣保存下來的模型,會在其文件夾下生成三個文件,分別是:
* .ckpt.meta文件,保存tensorflow模型的計算圖結構。
* .ckpt文件,保存計算圖下所有變量的取值。
* checkpoint文件,保存目錄下所有模型文件列表。

技術分享圖片
import tensorflow as tf
#保存計算兩個變量和的模型
v1 = tf.Variable(tf.random_normal([1], stddev=1, seed=1))
v2 
= tf.Variable(tf.random_normal([1], stddev=1, seed=1)) result = v1 + v2 init_op = tf.global_variables_initializer() saver = tf.train.Saver() with tf.Session() as sess: sess.run(init_op) saver.save(sess, "Saved_model/model.ckpt") #加載保存了兩個變量和的模型 with tf.Session() as sess: saver.restore(sess, "Saved_model/model.ckpt")
print sess.run(result) INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt [-1.6226364] #直接加載持久化的圖。因為之前沒有導出v3,所以這裏會報錯 saver = tf.train.import_meta_graph("Saved_model/model.ckpt.meta") v3 = tf.Variable(tf.random_normal([1], stddev=1, seed=1)) with tf.Session() as sess: saver.restore(sess, "Saved_model/model.ckpt")
print sess.run(v1) print sess.run(v2) print sess.run(v3) INFO:tensorflow:Restoring parameters from Saved_model/model.ckpt [-0.81131822] [-0.81131822] # 變量重命名,這樣可以通過字典將模型保存時的變量名和需要加載的變量聯系起來 v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "other-v1") v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "other-v2") saver = tf.train.Saver({"v1": v1, "v2": v2})
View Code

4.2.1 滑動平均類的保存

技術分享圖片
import tensorflow as tf
#使用滑動平均
v = tf.Variable(0, dtype=tf.float32, name="v")
for variables in tf.global_variables(): print variables.name
    
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables(): print variables.name
v:0
v:0
v/ExponentialMovingAverage:0

#保存滑動平均模型
saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    
    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    # 保存的時候會將v:0  v/ExponentialMovingAverage:0這兩個變量都存下來。
    saver.save(sess, "Saved_model/model2.ckpt")
    print sess.run([v, ema.average(v)])
10.0, 0.099999905]

#加載滑動平均模型
v = tf.Variable(0, dtype=tf.float32, name="v")

# 通過變量重命名將原來變量v的滑動平均值直接賦值給v。
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print sess.run(v)
INFO:tensorflow:Restoring parameters from Saved_model/model2.ckpt
0.0999999
View Code

4.2.2 variables_to_restore函數的使用樣例

import tensorflow as tf
v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
print ema.variables_to_restore()

#等同於saver = tf.train.Saver(ema.variables_to_restore())
saver = tf.train.Saver({"v/ExponentialMovingAverage": v})
with tf.Session() as sess:
    saver.restore(sess, "Saved_model/model2.ckpt")
    print sess.run(v)
{u‘v/ExponentialMovingAverage‘: <tf.Variable ‘v:0‘ shape=() dtype=float32_ref>}

4.3. pb文件保存方法

#pb文件的保存方法
import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name = "v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, [‘add‘])
    with tf.gfile.GFile("Saved_model/combined_model.pb", "wb") as f:
           f.write(output_graph_def.SerializeToString())

INFO:tensorflow:Froze 2 variables.
Converted 2 variables to const ops.
------------------------------------------------------------------------
#加載pb文件
from tensorflow.python.platform import gfile
with tf.Session() as sess:
    model_filename = "Saved_model/combined_model.pb"
   
    with gfile.FastGFile(model_filename, ‘rb‘) as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print sess.run(result)

[array([ 3.], dtype=float32)]

張量的名稱後面有:0,表示是某個計算節點的第一個輸出,而計算節點本身的名稱後是沒有:0的。

第五章 MNIST數字識別問題(二)