1. 程式人生 > >TensorFlow 儲存模型為 PB 檔案

TensorFlow 儲存模型為 PB 檔案

通常我們使用 TensorFlow時儲存模型都使用 ckpt 格式的模型檔案,使用類似的語句來儲存模型

tf.train.Saver().save(sess,ckpt_file_path,max_to_keep=4,keep_checkpoint_every_n_hours=2) 

使用如下語句來恢復所有變數資訊

saver.restore(sess,tf.train.latest_checkpoint('./ckpt'))  

但是這種方式有幾個缺點,首先這種模型檔案是依賴 TensorFlow 的,只能在其框架下使用;其次,在恢復模型之前還需要再定義一遍網路結構,然後才能把變數的值恢復到網路中。

谷歌推薦的儲存模型的方式是儲存模型為 PB 檔案,它具有語言獨立性,可獨立執行,封閉的序列化格式,任何語言都可以解析它,它允許其他語言和深度學習框架讀取、繼續訓練和遷移 TensorFlow 的模型。

它的主要使用場景是實現建立模型與使用模型的解耦, 使得前向推導 inference的程式碼統一。

另外的好處是儲存為 PB 檔案時候,模型的變數都會變成固定的,導致模型的大小會大大減小,適合在手機端執行。

具體細節
這種 PB 檔案是表示 MetaGraph 的 protocol buffer格式的檔案,MetaGraph 包括計算圖,資料流,以及相關的變數和輸入輸出signature以及 asserts 指建立計算圖時額外的檔案。

這是我找到第一個MetaGraph的解釋,比較容易懂:

When you are saving your graph, a MetaGraph is created. This is the graph itself, and all the other metadata necessary for computations in this graph, as well as some user info that can be saved and version specification.

主要使用tf.SavedModelBuilder 類來完成這個工作,並且可以把多個計算圖儲存到一個 PB 檔案中,如果有多個MetaGraph,那麼只會保留第一個 MetaGraph 的版本號,並且必須為每個MetaGraph 指定特殊的名稱 tag 用以區分,通常這個名稱 tag 以該計算圖的功能和使用到的裝置命名,比如 serving or training, CPU or GPU。

我們來看看典型的儲存 PB 檔案的程式碼:

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

pb_file_path = os.getcwd()

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    # 這裡的輸出需要加上name屬性
    op = tf.add(xy, b, name='op_to_store')

    sess.run(tf.global_variables_initializer())

    # convert_variables_to_constants 需要指定output_node_names,list(),可以多個
    constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])

    # 測試 OP
    feed_dict = {x: 10, y: 3}
    print(sess.run(op, feed_dict))

    # 寫入序列化的 PB 檔案
    with tf.gfile.FastGFile(pb_file_path+'model.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())

    # 輸出
    # INFO:tensorflow:Froze 1 variables.
    # Converted 1 variables to const ops.
    # 31

載入 PB 模型檔案典型程式碼:

from tensorflow.python.platform import gfile

sess = tf.Session()
with gfile.FastGFile(pb_file_path+'model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='') # 匯入計算圖

# 需要有一個初始化的過程    
sess.run(tf.global_variables_initializer())

# 需要先復原變數
print(sess.run('b:0'))
# 1

# 輸入
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')

op = sess.graph.get_tensor_by_name('op_to_store:0')

ret = sess.run(op,  feed_dict={input_x: 5, input_y: 5})
print(ret)
# 輸出 26

另外儲存為 save model 格式也可以生成模型的 PB 檔案,並且更加簡單。

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

pb_file_path = os.getcwd()

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    # 這裡的輸出需要加上name屬性
    op = tf.add(xy, b, name='op_to_store')

    sess.run(tf.global_variables_initializer())

    # 測試 OP
    feed_dict = {x: 10, y: 3}
    print(sess.run(op, feed_dict))




    # 官網有誤,寫成了 saved_model_builder  
    builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
    # 構造模型儲存的內容,指定要儲存的 session,特定的 tag, 
    # 輸入輸出資訊字典,額外的資訊
    builder.add_meta_graph_and_variables(sess,
                                       ['cpu_server_1'])


# 新增第二個 MetaGraphDef 
#with tf.Session(graph=tf.Graph()) as sess:
#  ...
#  builder.add_meta_graph([tag_constants.SERVING])
#...

builder.save()  # 儲存 PB 模型

儲存好以後到saved_model_dir目錄下,會有一個saved_model.pb檔案以及variables資料夾。顧名思義,variables儲存所有變數,saved_model.pb用於儲存模型結構等資訊。

這種方法對應的匯入模型的方法:

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

pb_file_path = './'

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ['cpu_server_1'], './way-4savemodel') 
    #第二個引數是儲存的變數的tag
    #最後一個引數是存放model的資料夾路徑,這個資料夾中還包含variables
    sess.run(tf.global_variables_initializer())

    input_x = sess.graph.get_tensor_by_name('x:0')
    input_y = sess.graph.get_tensor_by_name('y:0')

    op = sess.graph.get_tensor_by_name('op_to_store:0')

    ret = sess.run(op,  feed_dict={input_x: 5, input_y: 5})
    print(ret)
# 只需要指定要恢復模型的 session,模型的 tag,模型的儲存路徑即可,使用起來更加簡單

這樣和之前的匯入 PB 模型一樣,也是要知道tensor的name。那麼如何可以在不知道tensor name的情況下使用呢,實現徹底的解耦呢? 給add_meta_graph_and_variables方法傳入第三個引數,signature_def_map即可。
轉載:https://zhuanlan.zhihu.com/p/32887066