程式人生 > 儲存和載入pb模型

儲存和載入pb模型

將模型儲存為pb

import tensorflow as tf
from tensorflow.python.framework import graph_util

# Directory that receives the frozen graph file.
logdir = 'output/'

# Two small randomly-initialized variables created under the 'conv' scope;
# they are addressed below by the node names 'conv/w' and 'conv/b'.
with tf.variable_scope('conv'):
    w = tf.get_variable('w', [2, 2], tf.float32,
                        initializer=tf.random_normal_initializer)
    b = tf.get_variable('b', [2], tf.float32,
                        initializer=tf.random_normal_initializer)

# The context manager closes the session even if freezing or writing fails.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialize all variables

    # Freeze both variables in ONE call so the file contains a single
    # GraphDef message, rather than several serialized messages written
    # back-to-back into the same file.
    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['conv/w', 'conv/b'])

    # Make sure the target directory exists before opening the file.
    tf.gfile.MakeDirs(logdir)
    with tf.gfile.GFile(logdir + 'expert_graph.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

載入pb模型

import tensorflow as tf
from tensorflow.python.framework import graph_util

# Location of the frozen graph file to load.
logdir = 'output/'
output_graph_path = logdir + 'expert_graph.pb'

with tf.Session() as sess:
    # No variable initializer is run here: nothing has been added to the
    # graph yet at this point, and the file is expected to hold a frozen
    # graph (constants rather than variables) -- TODO confirm against the
    # script that produced it.
    output_graph_def = tf.GraphDef()
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())

    # name="" imports the nodes without a prefix, so the tensors keep the
    # exact names looked up below.
    tf.import_graph_def(output_graph_def, name="")

    w = sess.graph.get_tensor_by_name("conv/w:0")
    print('w:', w.eval())
    b = sess.graph.get_tensor_by_name("conv/b:0")
    print('b:', b.eval())