[TensorFlow深度學習入門]實戰八·簡便方法實現TensorFlow模型引數儲存與載入(pb方式)
阿新 • • 發佈:2018-12-22
[TensorFlow深度學習入門]實戰八·簡便方法實現TensorFlow模型引數儲存與載入(pb方式)
在上篇博文中,我們探索了TensorFlow模型引數儲存與載入實現方法採用的是儲存ckpt的方式。這篇博文我們會使用儲存為pd格式檔案來實現。
首先,我會在上篇博文基礎上,實現由ckpt檔案如何轉換為pb檔案,再去探索如何在訓練時直接儲存pb檔案,最後是如何利用pb檔案復現網路與引數完成應用預測功能。
- ckpt檔案轉換pd檔案
ckpt2pd檔案程式碼:
import tensorflow as tf
pd_dir = "././Saver/test1/pb_dir/MyModel.pb"
with tf.Session() as sess:
#載入運算圖
saver = tf.train.import_meta_graph('./Saver/test1/checkpoint_dir/MyModel.meta')
#載入引數
saver.restore(sess,tf.train.latest_checkpoint('./Saver/test1/checkpoint_dir'))
graph = tf.get_default_graph()
out_graph = tf.graph_util.convert_variables_to_constants( sess,sess.graph_def,["in","out"])
saver_path = tf.train.write_graph(out_graph,".",pd_dir,as_text=False)
print("saver path: ",saver_path)
執行結果:
saver path: ././Saver/test1/pb_dir/MyModel.pb
- 訓練儲存pd檔案
train檔案程式碼
import tensorflow as tf
pd_dir = "././Saver/test2/pb_dir/MyModel.pb"
def main():
x = tf.placeholder(dtype=tf.float32,shape=[None,2],name="in")
#x = tf.constant([[1,2]],dtype=tf.float32)
w1 = tf.get_variable("w1",dtype=tf.float32,initializer=tf.truncated_normal([2, 1], stddev=0.1))
b1 = tf.get_variable("b1",initializer=tf.constant(.1, dtype=tf.float32, shape=[1, 1]))
y = tf.add(tf.matmul(x,w1),b1,name="out")
with tf.Session() as sess:
#獲取計算圖
graph = tf.get_default_graph()
#獲取name和ops,這次程式碼並沒有用到
ret = graph.get_operations()
r_names = []
#獲取name list
for r in ret:
r_names.append(r.name)
srun = sess.run
srun(tf.global_variables_initializer())
print("y: ",srun(y,{x:[[1,2]]}))
#存入輸入與輸出介面
out_graph = tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,["in","out"])
saver_path = tf.train.write_graph(out_graph,".",pd_dir,as_text=False)
print("saver path: ",saver_path)
if __name__ == "__main__":
main()
執行結果:
y: [[0.14729613]]
saver path: ./././Saver/test2/pb_dir/MyModel.pb
- pb檔案復現網路與引數
restore檔案程式碼
import tensorflow as tf
from saver1 import pd_dir
with tf.Session() as sess:
#用上下文管理器開啟pd檔案
with open(pd_dir,"rb") as pd_flie:
#獲取圖
graph = tf.GraphDef()
#獲取引數
graph.ParseFromString(pd_flie.read())
#引入輸入輸出介面
ins, outs = tf.import_graph_def(graph,return_elements=["in:0","out:0"])
#進行預測
print("y: ",sess.run(outs,{ins:[[1,2]]}))
執行結果:
y: [[0.14729613]]