TensorFlow模型檔案儲存和讀取
阿新 • • 發佈:2018-10-31
一、模型檔案的儲存
在訓練一個TensorFlow模型之後,我們可以將訓練好的模型儲存成檔案,這樣可以方便下一次對新的資料進行預測的時候直接載入訓練好的模型即可獲得結果,下面通過TensorFlow提供的tf.train.Saver函式,將一個模型儲存成檔案,一般習慣性的將TensorFlow的模型檔案命名為*.ckpt檔案。
[python] view plain copy
- <span style="font-size:14px;">import tensorflow as tf
- if __name__ == "__main__":
- #定義兩個變數
- a = tf.Variable(tf.constant(1.0,shape=[1],name="a"))
- b = tf.Variable(tf.constant(2.0,shape=[1],name=
- c = a + b
- init = tf.initialize_all_variables()
- sess = tf.Session()
- sess.run(init)
- #宣告一個儲存
- saver = tf.train.Saver()
- saver.save(sess,"./model.ckpt")</span>
二、模型檔案的讀取
TensorFlow對於模型檔案的讀取方式也提供了幾種方法,根據讀取不同的檔案來獲取不同的資訊。
1、載入model.ckpt檔案來初始化變數
[python] view plain copy- <span style="font-size:14px;"> a = tf.Variable(tf.constant(3.0,shape=[1],name="a"))
- b = tf.Variable(tf.constant(4.0,shape=[1],name="b"))
- c = a + b
- saver = tf.train.Saver()
- sess = tf.Session()
- saver.restore(sess,"model.ckpt")
- print(sess.run(c))
- #[ 3.]</span>
2、載入持久化圖獲取全部變數
[python] view plain copy- <span style="font-size:14px;"> saver = tf.train.import_meta_graph("model.ckpt.meta")
- sess = tf.Session()
- saver.restore(sess,"model.ckpt")
- print(sess.run(tf.get_default_graph().get_tensor_by_name("a:0")))
- #[ 1.]
- print(sess.run(tf.get_default_graph().get_tensor_by_name("b:0")))
- #[ 2.]
- print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
- #[ 3.]</span>
3、載入指定列表變數
[python] view plain copy
- <span style="font-size:14px;"> a = tf.Variable(tf.constant(3.0,shape=[1],name="a"))
- b = tf.Variable(tf.constant(4.0,shape=[1],name="b"))
- c = a + b
- saver = tf.train.Saver([a,b])
- sess = tf.Session()
- saver.restore(sess,"model.ckpt")
- print(sess.run(a))
- #[ 1.]
- print(sess.run(b))
- #[ 2.]</span>
[[Node: _retval_Variable_1_0_0 = _Retval[T=DT_FLOAT, index=0, _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_1)]],使用一個沒有初始化的變數。
4、載入變數名的重新命名
tensorfow提供了一種方法可以修改載入模型中的變數名,通過tf.train.Saver(),帶參的形式來修改變數名稱。
[python] view plain copy
- <span style="font-size:14px;"> #重新定義兩個變數v1和v2
- v1 = tf.Variable(tf.constant(3.,shape=[1]),name="v1")
- v2 = tf.Variable(tf.constant(4.,shape=[1]),name="v2")
- #將模型中的變數名a重新命名為v1,將模型中的變數名b重新命名為v2
- save = tf.train.Saver({"a":v1,"b":v2})
- sess = tf.Session()
- save.restore(sess,"model.ckpt")
- print(sess.run(v1))
- print(sess.run(v2))</span>