1. 程式人生 > >[TensorFlow深度學習入門]實戰八·簡便方法實現TensorFlow模型引數儲存與載入(pb方式)

[TensorFlow深度學習入門]實戰八·簡便方法實現TensorFlow模型引數儲存與載入(pb方式)

[TensorFlow深度學習入門]實戰八·簡便方法實現TensorFlow模型引數儲存與載入(pb方式)

在上篇博文中,我們探索了TensorFlow模型引數儲存與載入實現方法採用的是儲存ckpt的方式。這篇博文我們會使用儲存為pd格式檔案來實現。
首先,我會在上篇博文基礎上,實現由ckpt檔案如何轉換為pb檔案,再去探索如何在訓練時直接儲存pb檔案,最後是如何利用pb檔案復現網路與引數完成應用預測功能。

  • ckpt檔案轉換pd檔案

ckpt2pd檔案程式碼:

import tensorflow as tf
pd_dir = "././Saver/test1/pb_dir/MyModel.pb"
with tf.Session() as sess: #載入運算圖 saver = tf.train.import_meta_graph('./Saver/test1/checkpoint_dir/MyModel.meta') #載入引數 saver.restore(sess,tf.train.latest_checkpoint('./Saver/test1/checkpoint_dir')) graph = tf.get_default_graph() out_graph = tf.graph_util.convert_variables_to_constants(
sess,sess.graph_def,["in","out"]) saver_path = tf.train.write_graph(out_graph,".",pd_dir,as_text=False) print("saver path: ",saver_path)

執行結果:

saver path:  ././Saver/test1/pb_dir/MyModel.pb
  • 訓練儲存pd檔案

train檔案程式碼

import tensorflow as tf

pd_dir = "././Saver/test2/pb_dir/MyModel.pb"



def
main(): x = tf.placeholder(dtype=tf.float32,shape=[None,2],name="in") #x = tf.constant([[1,2]],dtype=tf.float32) w1 = tf.get_variable("w1",dtype=tf.float32,initializer=tf.truncated_normal([2, 1], stddev=0.1)) b1 = tf.get_variable("b1",initializer=tf.constant(.1, dtype=tf.float32, shape=[1, 1])) y = tf.add(tf.matmul(x,w1),b1,name="out") with tf.Session() as sess: #獲取計算圖 graph = tf.get_default_graph() #獲取name和ops,這次程式碼並沒有用到 ret = graph.get_operations() r_names = [] #獲取name list for r in ret: r_names.append(r.name) srun = sess.run srun(tf.global_variables_initializer()) print("y: ",srun(y,{x:[[1,2]]})) #存入輸入與輸出介面 out_graph = tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,["in","out"]) saver_path = tf.train.write_graph(out_graph,".",pd_dir,as_text=False) print("saver path: ",saver_path) if __name__ == "__main__": main()

執行結果:

y:  [[0.14729613]]
saver path:  ./././Saver/test2/pb_dir/MyModel.pb
  • pb檔案復現網路與引數

restore檔案程式碼

import tensorflow as tf
from saver1 import pd_dir

with tf.Session() as sess:
    #用上下文管理器開啟pd檔案    
    with open(pd_dir,"rb") as pd_flie:
        #獲取圖
        graph = tf.GraphDef()
        #獲取引數
        graph.ParseFromString(pd_flie.read())
        #引入輸入輸出介面
        ins, outs = tf.import_graph_def(graph,return_elements=["in:0","out:0"])
        #進行預測
        print("y: ",sess.run(outs,{ins:[[1,2]]}))

執行結果:

y:  [[0.14729613]]