如何檢視 TensorFlow SavedModel 格式模型的資訊
在《 ofollow,noindex">Tensorflow SavedModel模型的儲存與載入 》一文中,我們談到SavedModel格式的優點是與語言無關、容易部署和載入。那問題來了,如果別人釋出了一個SavedModel模型,我們該如何去了解這個模型,如何去載入和使用這個模型呢?
理想的狀態是模型釋出者編寫出完備的文件,給出示例程式碼。但在很多情況下,我們只是得到了訓練好的模型,而沒有齊全的文件,這個時候我們能否從模型本身上獲得一些資訊呢?比如模型的輸入輸出、模型的結構等等。
答案是可以的。
檢視模型的Signature簽名
這裡的簽名,並非是為了保證模型不被修改的那種電子簽名。我的理解是類似於程式語言中模組的輸入輸出資訊,比如函式名,輸入引數型別,輸出引數型別等等。我們以《 Tensorflow SavedModel模型的儲存與載入 》裡的模型程式碼為例,從語句:
signature = predict_signature_def(inputs={'myInput': x}, outputs={'myOutput': y})
我們可以看到模型的輸入名為myInput,輸出名為myOutput。如果我們沒有原始碼呢?
Tensorflow提供了一個工具,如果你下載了Tensorflow的原始碼,可以找到這樣一個檔案,./tensorflow/python/tools/saved_model_cli.py,你可以加上-h引數檢視該指令碼的幫助資訊:
usage: saved_model_cli.py [-h] [-v] {show,run,scan} ... saved_model_cli: Command-line interface for SavedModel optional arguments: -h, --helpshow this help message and exit -v, --versionshow program's version number and exit commands: valid commands {show,run,scan}additional help
指定SavedModel模所在的位置,我們就可以顯示SavedModel的模型資訊:
python $TENSORFLOW_DIR/tensorflow/python/tools/saved_model_cli.py show --dir ./model/ --all
結果為:
MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs: signature_def['predict']: The given SavedModel SignatureDef contains the following input(s): inputs['myInput'] tensor_info: dtype: DT_FLOAT shape: (-1, 784) name: myInput:0 The given SavedModel SignatureDef contains the following output(s): outputs['myOutput'] tensor_info: dtype: DT_FLOAT shape: (-1, 10) name: Softmax:0 Method name is: tensorflow/serving/predict
從這裡我們可以清楚的看到模型的輸入/輸出的名稱、資料型別、shape以及方法名稱。有了這些資訊,我們就可以很容易寫出推斷方法。
檢視模型的計算圖
瞭解tensflow的人可能知道TensorBoard是一個非常強大的工具,能夠顯示很多模型資訊,其中包括計算圖。問題是,TensorBoard需要模型訓練時的log,如果這個SavedModel模型是別人訓練好的呢?辦法也不是沒有,我們可以寫一段程式碼,載入這個模型,然後輸出summary info,程式碼如下:
import tensorflow as tf import sys from tensorflow.python.platform import gfile from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.python.util import compat with tf.Session() as sess: model_filename ='./model/saved_model.pb' with gfile.FastGFile(model_filename, 'rb') as f: data = compat.as_bytes(f.read()) sm = saved_model_pb2.SavedModel() sm.ParseFromString(data) if 1 != len(sm.meta_graphs): print('More than one graph found. Not sure which to write') sys.exit(1) g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def) LOGDIR='./logdir' train_writer = tf.summary.FileWriter(LOGDIR) train_writer.add_graph(sess.graph) train_writer.flush() train_writer.close()
程式碼中,將彙總資訊輸出到logdir,接著啟動TensorBoard,加上上面的logdir:
tensorboard --logdir ./logdir
在瀏覽器中輸入地址: http://127.0.0.1:6006/ ,就可以看到如下的計算圖:
小結
按照前面兩種方法,我們可以對Tensorflow SavedModel格式的模型有比較全面的瞭解,即使模型訓練者並沒有給出文件。有了這些模型資訊,相信你寫出使用模型進行推斷更加容易。
本文的完整程式碼請參考:https://github.com/mogoweb/aiexamples/tree/master/tensorflow/saved_model
希望這篇文章對您有幫助,感謝閱讀!