1. 程式人生 > >TensorFlow學習筆記(5) TensorFlow模型持久化

TensorFlow學習筆記(5) TensorFlow模型持久化

TF提供了一個簡單的API來儲存和還原一個神經網路模型。這個API就是tf.train.Saver類。

下面即為儲存TensorFlow計算圖的方法(saver.save()):

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1], name='v1'))
v2 = tf.Variable(tf.constant(1.0, shape=[1], name='v2'))
result = v1 + v2

#宣告tf.train.Saver類用於儲存模型
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    #將模型儲存到model/model.ckpt檔案
    saver.save(sess, 'E:\PycharmProjects\\tensorflow_learn/model/model.ckpt')
這樣就實現了持久化一個簡單的TF模型的功能,通過saver.save函式將TF模型儲存到model.ckpt中。雖然只指定了一個檔案路徑,但是在該路徑下會出現三個檔案:
model.ckpt.meta 儲存了TF計算圖的結構
model.ckpt 儲存了TF程式中每個變數的取值
checkpoint 儲存了一個目錄下所有的模型檔案列表
在儲存了TF模型後,下面是使用 saver.restore()載入該模型:
import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1], name='v1'))
v2 = tf.Variable(tf.constant(1.0, shape=[1], name='v2'))
result = v1 + v2

#宣告tf.train.Saver
saver = tf.train.Saver()

with tf.Session() as sess:
    #載入已經儲存的模型,並通過已經儲存的模型中的變數來計算
    saver.restore(sess, 'E:\PycharmProjects\\tensorflow_learn/model/model.ckpt')
    print(sess.run(result))

這兩段程式碼幾乎一樣,不同的地方在於載入模型的程式碼沒有初始化變數,而是通過已經儲存的模型載入進來。如果不想重複定義模型的結構,也可以直接將模型的結構加載出來:

import tensorflow as tf
#載入模型結構
saver = tf.train.import_meta_graph('E:\PycharmProjects\\tensorflow_learn/model/model.ckpt.meta')

with tf.Session() as sess:
    saver.restore(sess, 'E:\PycharmProjects\\tensorflow_learn/model/model.ckpt')
    #通過張量的名稱來獲取張量
    print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))
    #[2.]
為了儲存和載入部分變數,在宣告tf.train.Saver類時可以提供一個列表來指定儲存或者載入的變數,如 tf.train.Saver([v1]),這時就只有變數v1會被載入進來。除了可以指定被載入的變數,tf.train.Saver類也支援在儲存或載入時給變數重新命名:
import tensorflow as tf

v1 = tf.Variable(tf.constant(2.0, shape=[1], name='other-v1'))
v2 = tf.Variable(tf.constant(3.0, shape=[1], name='other-v2'))

#這裡如果直接使用tf.train.Saver()來載入模型會報錯

#使用一個字典來重新命名變數就可以載入原來的模型了
#這個字典指定了原名稱為v1的變數現在加到v1變數中
saver = tf.train.Saver({'v1': v1, 'v2': v2})
這種方式方便使用變數的滑動平均值,在載入模型時將影子變數對映到變數自身,那麼在訓練好的模型中就不需要再呼叫函式獲得變臉的滑動平均了:
import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name='v')
for variables in tf.global_variables():
    print(variables.name)
    #v:0

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
    print(variables.name)
    #v:0
    #v/ExponentialMovingAverage:0

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)

    saver.save(sess, 'E:\PycharmProjects\\tensorflow_learn/model1/model.ckpt')
    print(sess.run([v, ema.average(v)]))
    #[10.0, 0.099999905]
import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name='v')
saver = tf.train.Saver({'v/ExponentialMovingAverage': v})

with tf.Session() as sess:
    saver.restore(sess, 'E:\PycharmProjects\\tensorflow_learn/model1/model.ckpt')
    print(sess.run(v))
    #0.099999905
為了方便載入時重新命名滑動變數,tf.train.ExponentialMovingAverage類提供了variables_to_restore來生成重新命名所需要的字典,即{'v/ExponentialMovingAverage': v},因此上面的程式碼也可以改為
saver = tf.train.Saver(ema.variables_to_restore())

在TF中,提供了convert_variables_to_constant函式將計算圖中的變數及其取值通過常量的方式儲存,這樣整個TF計算圖可以統一存放在一個檔案中:

import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1], name='v1'))
v2 = tf.Variable(tf.constant(2.0, shape=[1], name='v2'))
result = v1 + v2

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    #匯出當前計算圖GraphDef部分,只需這一層就可以完成從輸入層到輸出層的計算
    graph_def = tf.get_default_graph().as_graph_def()
    #將匯出的計算圖中的變數轉化為常量
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])

    with tf.gfile.GFile('model/combined_model.pb', 'wb') as f:
        f.write(output_graph_def.SerializeToString())
在儲存之後,當只需要得到計算圖中某個節點的取值時,就會有一個更方便的方法:
import tensorflow as tf
from tensorflow.python.framework import graph_util

with tf.Session() as sess:
    model_filename = 'model/combined_model.pb'
    #讀取檔案並解析成對應的GraphDef Protocol Buffer
    with gfile.FastGfile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        #將儲存的圖載入到當前圖中,並給定返回張量的名稱
        result = tf.import_graph_def(graph_def, return_elements=['add:0'])#add為張量的名稱
        print(sess.run(result))


源自:Tensorflow 實戰Google深度學習框架_鄭澤宇