如何使用訓練好的tensorflow
阿新 • • 發佈:2018-12-26
from tensorflow.python.framework import graph_util
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
f.write(constant_graph.SerializeToString())
這兩句是重要的程式碼,用來把訓練好的模型儲存為pb檔案。執行完之後就會發現應該的資料夾多出了一個pb檔案。
- test
def recognize(jpg_path, pb_file_path):
with tf.Graph().as_default():
output_graph_def = tf.GraphDef()
with open(pb_file_path, "rb") as f:
output_graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(output_graph_def, name="")
開啟相應的pb檔案。
img = io.imread(jpg_path) img = transform.resize(img, (224, 224, 3)) img_out_softmax = sess.run(out_softmax, feed_dict={input_x:np.reshape(img, [-1, 224, 224, 3])})
讀取圖片檔案,resize之後放入模型的輸入位置,之後img_out_softmax就是相應輸出的結果。