from tensorflow.python.framework import graph_util
 
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
            with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
f.write(constant_graph.SerializeToString())

這兩句是重要的程式碼,用來把訓練好的模型儲存為pb檔案。執行完之後就會發現應該的資料夾多出了一個pb檔案。

  1. test
def recognize(jpg_path, pb_file_path):
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()

        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(output_graph_def, name="")

開啟相應的pb檔案。

            img = io.imread(jpg_path)
            img = transform.resize(img, (224, 224, 3))
            img_out_softmax = sess.run(out_softmax, feed_dict={input_x:np.reshape(img, [-1, 224, 224, 3])})

讀取圖片檔案,resize之後放入模型的輸入位置,之後img_out_softmax就是相應輸出的結果。