1. 程式人生 > >TensorFlow Low-Level-APIs Save and Restore學習筆記

TensorFlow Low-Level-APIs Save and Restore學習筆記

Save and Restore

本文翻譯自tensorflow官方網站的教程,只作為個人學習筆記,請勿用作商業用途。 tf.train.Saver類提供了儲存和提取模型的方法。tf.saved_model.simple_save函式也是一種簡單的方法來儲存模型。高階APIEstimaztors會自動的儲存和提取在model_dir中的模型。

Save and restore variables

TensorFlow Variables are the best way to represent shared, persistent state manipulated by your program.tf.train.Saver構造器為計算圖中的所有節點或者是指定的節點序列新增save和resotre的ops。然後Saver類去執行這些ops,並指定儲存和恢復的路徑來實現讀寫。

Saver類能從儲存的模型中讀取出計算圖中定義的Variable。如果你載入了一個模型但是不知道怎麼利用它來構建計算圖,可以參看Overview of saving and restoring models。TensorFlow通過二進位制的檔案來儲存Variable,並在其中做了Variable的name和值的對映。

Save variables

使用tf.train.Saver()來構建一個Saver以便控制整個模型,下例是一個使用Saver來儲存Variable的例子:

# Create some variables
v1 = tf.get_variable("v1", shape=[
3], initializer=tf.zeros_initializer) v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer) inc_v1 = v1.assign(v1+1) dec_v2 = v2.assign(v2-1) # Add an op to initialize the variables init_op = tf.global_variables_initializer() # Add ops to save and restore all the variables saver =
tf.train.Saver() # Later, launh the model, initialize the variables, do some work, and save variables to disk with tf.Session() as sess: sess.run(init_op) # Do some work with the model inc_v1.op.run() dec_v2.op.run() save_path = saver.save(sess, "/temp/model.ckpt") print("Model saved in path: %s" % save_path)

Restore variables

tf.train.Saver物件不僅儲存variable,還能將variable從檔案中讀取variable。注意當從檔案中讀取variable之前你不應該初始化他們,下面是一個例子:

tf.rest_default_graph()

# Create some variables
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and do some work woth the model
with tf.Session() as sess:
    # Restore variables from disk
    saver.restore(sess, "/tmp/model.ckpt")
    print("Model resotred")
    # Check the values of the variables
    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())

Choose variables to save and restore

如果在構造tf.train.Saver()的時候沒有傳入任何引數,saver將會管理整個計算圖中的variable,每個variable儲存的時候使用的name就是構造variable時候的name。有時候給儲存的variable重新賦予name是有用的,比如在之前的計算圖中某個variable的name是weights,而你希望儲存的name是params。有時候也許只需要儲存計算圖中的一部分variable。可以通過向tf.train.Saver()傳入如下引數來實現這樣的目的:

  • 一個variable的list(他們儲存的name就是構造他們時的name)
  • 一個字典dict,key是儲存時使用的name,value是需要儲存的variable

使用之前的例子:

tf.rest_default_graph()
# Create some variables
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)

# Add ops to save and restore only 'v2' using the name "v2"
saver = tf.train.Saver({"v2":v2})

# Use the saver object normally after that
with tf.Session() as sess:
    # Initialize v1 since the saver will not
    v1.initializer.run()
    saver.restore(sess, "/tmp/model.ckpt")

    print("v1 : %s" % v1.eval())
    print("v2 : %s" % v2.eval())

注意事項:

  • 如果需要對計算圖中不同的部分進行分別儲存可以構建多個Saver物件,同一個variable也可以通過多個Saver來儲存。
  • 如果只想在session開始的時候讀取一部分的variable,除了這些讀取的variable其他的variable需要被初始化。
  • 審查checkpoint中的variable,需要使用inspect_checkpoint庫,print_tensor_in_checkpoint_file函式
  • 預設情況下Saver使用tf.Variable.name屬性來儲存每個variable,在建立Saver物件的時候也可以對checkpoint檔案中每個variable賦予name。

Inspect variables in a checkpoint

使用inspect_checkpoint庫來審查variable,仍然使用之前的例子:

# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp

# print all tensor in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensor=True)

# tensor_name: v1
# [ 1. 1. 1.]
# tensor_name: v2
# [ -1. -1. -1. -1. -1.]

# print only tensor v1 in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v1', all_tensor=False)

# tensor_name: v1
# [ 1. 1. 1.]

# print only tensor v2 in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v2', all_tensor=False)

# tensor_name: v2
# [ -1. -1. -1. -1. -1.]

Save and restore models

使用SaveModel可以儲存和載入模型包括variable、graph和graph的metadata。這種方法儲存模型更加自然,可以通過多個介面和SaveModel進行互動比如tf.saved_model或者高階APIs。

Build and load a SaveModel

Simple save

建立SaveModel最簡單的方法就是使用tf.save_model.simple_save函式:

simple_save(session,
            export_dir,
            inputs={"x":x, "y":y},
            outputs={"z":z})

構造SaveModel的引數定義了輸入和輸出使得它可以被伺服器直接拿來進行使用,來預測或者是進行訓練。

Manually build a SaveModel

如果你的使用環境沒辦法使用tf.saved_model.simple_save,可以使用人工的builder APIs來構造一個SaveModel。tf.saved_model.builder.SaveMOdelBuilder類提供了儲存多個MetaGraphDef的功能。一個MetaGraph是一個數據流圖,再增加和它相關的variable、assets和signatures。一個MetaGraphDef是關於MetaGraph的一個protocol buffers(序列化標準)的表示。一個signature就是計算圖的輸入和輸出。如果assets需要被儲存或者寫或者複製到磁碟,那麼可以在熟悉新增MetaGraphDef時提供這些資源。如果多個MetaGraphDef與同名資源相關聯,則只保留首個版本。每個新增到SaveModel中的MetaGraphDef都必須使用使用者定義的tag進行區分。通過指定tag來確定要restore哪個MetaGraphDef,tag一般是表示用於訓練或者是用於推斷,或者指明執行在什麼裝置上如GPU。

export_dir = ...
...
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph_and variables(sess,
                            [tag_constants.TRAINING],
                            signature_def_map=foo_signatures,
                            assets_collection=foo_assets,
                            strip_default_attrs=True)
...
# Add a second MetaGraphDef for inference
with tf.Session(graph=tf.Graph()) as sess:
    ...
    builder.add_meta_graph([tag_constants.SERVING], strip_default_attrs=True)
...
builder.save()

Forward compatibility via strip_deefault_attrs=True

如果計算圖中的ops沒有改變,那麼以下的教程讓你的模型restore以後具有前向傳播的能力。SaveModelBuilder類允許使用者來控制是否刪去ModeDefs中的一些default的屬性。SavedModelBuilder.add_meta_graph_and_variables和SavedModelBuilder.add_meta_graph方法都可以接受一個布林變數strip_default_attrs來控制是否刪掉。

如果strip_default_attrs是False,那麼輸出的tf.MetaGraphDef將要保留它其中的所有tf.NodeDef例項中的default的屬性。如下的一些情況也會導致這種前向傳播的能力失去,具體參考compatibility guidance

Loading a SavedModel in Python

Python版本的SavedModel的loader提供了儲存和載入模型的能力,load操作需要一下的資訊:

  • 載入計算圖和variable對應的Session
  • 定位MetaGraphDef的tag
  • SavedModel的儲存路徑

在一次載入中variable,assets和signature都作為MetaGraphDef的一部分一起載入到了提供的session中。

export_dir = ...
...
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, [tag_constants.TRAINING], export_dir)