
16. Saving and Restoring TensorFlow Model Parameters

The simplest way to save and restore a model is to use a tf.train.Saver() object. It adds save and restore ops to the graph for all of the variables, or for those specified in a list. The tf.train.Saver() object provides methods to run these ops, and lets you specify the paths for reading and writing the checkpoint files.

I. The tf.train.Saver() class

tf.train.Saver(
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)

1. Constructor arguments

  • var_list
    • Specifies the variables that will be saved and restored. If None, defaults to the list of all saveable objects. It can be passed as a dict or a list:
    • A dict of names to variables: the keys are the names that will be used to save or restore the variables in the checkpoint files.
    • A list of variables: the variables will be keyed with their op name in the checkpoint files.
  • For example:
v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# 1. Pass them as a list; such a list can also be used to save or load
#    only a subset of the variables:
saver = tf.train.Saver([v1, v2])

# 2. Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# 3. Passing a list is equivalent to passing a dict with the variable
#    op names as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})

# 4. Rename variables when saving or loading:
v1 = tf.Variable(..., name='other_v1')
v2 = tf.Variable(..., name='other_v2')
saver = tf.train.Saver({'v1': v1, 'v2': v2})  # saved/restored under 'v1', 'v2'
print(v1.name)  # prints: other_v1:0
  • max_to_keep
    • Indicates the maximum number of recent checkpoint files to keep.
    • As new files are created, older files are deleted.
    • If None or 0, all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent checkpoint files are kept).
    • Setting max_to_keep=1 keeps only the latest model. Alternatively, calling save() with global_step=None always writes to the same filename, which also keeps only the latest model (see the sketch below).
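
A minimal sketch of both options; the variable v and the path /tmp/demo are placeholders for illustration:

import os
import tensorflow as tf

v = tf.get_variable("v", shape=[1], initializer=tf.zeros_initializer())
saver = tf.train.Saver(max_to_keep=1)  # keep only the most recent checkpoint

os.makedirs('/tmp/demo', exist_ok=True)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, '/tmp/demo/model', global_step=10)  # writes model-10.*
    saver.save(sess, '/tmp/demo/model', global_step=20)  # model-10.* is deleted
    saver.save(sess, '/tmp/demo/model')  # global_step=None: always overwrites model.*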

2. Common methods

# Returns a string, path at which the variables were saved.
save(
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True
)

# The variables to restore do not have to have been initialized, as restoring is itself a way to initialize variables.
restore(
    sess,
    save_path
)
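
Continuing the sketch above, restore() can run without any prior initialization; a minimal example (the save path comes from the earlier sketch and is illustrative):

import tensorflow as tf

v = tf.get_variable("v", shape=[1], initializer=tf.zeros_initializer())
saver = tf.train.Saver()

with tf.Session() as sess:
    # No sess.run(init_op) needed: restore() itself initializes the variables.
    saver.restore(sess, '/tmp/demo/model')
    print(sess.run(v))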

II. Saving and restoring parameters

1. Checkpoint files

  • Variables are saved in binary files that, essentially, contain a mapping from variable names to tensor values.
  • When you create a Saver object, you can optionally choose names for the variables in the checkpoint files. By default, the value of each variable's tf.Variable.name property is used. (It is the tensor values that are the model's parameters; the variable names are merely the keys they are stored under.)
  • With saver = tf.train.Saver(max_to_keep=3), the files written to the checkpoint directory are:
    • The checkpoint file, which stores a list of the paths of all model files in the directory.
    • The .data file, which stores our model (all the values of the weights, biases, and other saved variables).
    • The .index file, an index mapping the variable names into the .data file.
    • The .meta file, which stores the structure of the computation graph (all variables, operations, collections, etc.).

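For reference, with max_to_keep=3 and the step number appended to the name, the directory typically contains files like the following (the step values are illustrative):

checkpoint
my_model-900.data-00000-of-00001
my_model-900.index
my_model-900.meta
my_model-950.data-00000-of-00001
my_model-950.index
my_model-950.meta
my_model-1000.data-00000-of-00001
my_model-1000.index
my_model-1000.meta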

2. Saving & restoring variables

  • A boolean variable isTrain can be used to switch between the two phases: True for training, False for testing.
  • The tf.train.Saver() class supports renaming variables when restoring them (overriding the name argument the original variables were created with).
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import tensorflow as tf

# Create some variables.
w = tf.get_variable("weight", shape=[2], initializer=tf.zeros_initializer())
b = tf.get_variable("bias", shape=[3], initializer=tf.zeros_initializer())

inc_w = w.assign(w + 1)
dec_b = b.assign(b - 1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver(max_to_keep=3)

isTrain = False  # True for training, False for testing
train_steps = 1000
checkpoint_steps = 50
checkpoint_dir = 'checkpoint/save&restore/'
model_name = 'my_model'

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    if isTrain:
        # Do some work with the model.
        for step in range(train_steps):
            inc_w.op.run()
            dec_b.op.run()
            if (step + 1) % checkpoint_steps == 0:
                # Append the step number to the checkpoint name:
                saved_path = saver.save(
                    sess,
                    checkpoint_dir + model_name,
                    global_step=step + 1  # when None, only the latest model is kept
                )
    else:
        print('Before restore:')
        print(sess.run(w))
        print(sess.run(b))
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        # Get the latest model file
        if ckpt and ckpt.model_checkpoint_path:
            print("Successfully loaded %s." % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            # No checkpoint found; the variables keep their initialized values.
            pass
        print('After restore:')
        print(sess.run(w))
        print(sess.run(b))

# Test output
Before restore:
[ 0.  0.]
[ 0.  0.  0.]
Successfully loaded checkpoint/save&restore/my_model-1000.
After restore:
[ 1000.  1000.]
[-1000. -1000. -1000.]


# Conclusion: restore effectively re-initializes all of the variables
# (it overwrites whatever values they currently hold)

# Analysis of the conclusion
Although the official docs say that you do not need to run init_op before a restore, sess.run(init_op) is deliberately placed above the `if isTrain:` statement here (so it runs in both the training and testing phases) in order to verify the conclusion that restore effectively re-initializes all of the variables; the test output above confirms this.
# In practice, sess.run(init_op) can be moved inside the `if isTrain:` branch (so that it only runs during training).

3. Reading trainable parameter values & extracting a layer's features

sess = tf.Session()

# Returns all variables created with trainable=True in a var_list
var_list = tf.trainable_variables()

print("Trainable variables:------------------------")

# Print the index, shape, and name of every trainable parameter
for idx, v in enumerate(var_list):
    print("param {:3}: {:15}   {}".format(idx, str(v.get_shape()), v.name))


# Sample output for one network
Trainable variables:------------------------
  param   0: (5, 5, 3, 32)     conv2d/kernel:0
  param   1: (32,)             conv2d/bias:0
  param   2: (5, 5, 32, 64)    conv2d_1/kernel:0
  param   3: (64,)             conv2d_1/bias:0
  param   4: (3, 3, 64, 128)   conv2d_2/kernel:0
  param   5: (128,)            conv2d_2/bias:0
  param   6: (3, 3, 128, 128)  conv2d_3/kernel:0
  param   7: (128,)            conv2d_3/bias:0
  param   8: (4608, 1024)      dense/kernel:0
  param   9: (1024,)           dense/bias:0
  param  10: (1024, 512)       dense_1/kernel:0   ---> parameters of the dense2 layer
  param  11: (512,)            dense_1/bias:0
  param  12: (512, 5)          dense_2/kernel:0
  param  13: (5,)              dense_2/bias:0


# Extract the parameters W and b of the last fully connected layer
W = sess.run(var_list[12])
b = sess.run(var_list[13])

# Extract the output of the second fully connected layer as features
# (dense2 is that layer's output tensor; x and img are the input placeholder
#  and input image defined elsewhere in the network)
feature = sess.run(dense2, feed_dict={x: img})
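
If the Python handle dense2 is not in scope (for example, after restoring a saved graph), the layer's output tensor can be looked up by name instead. A minimal sketch, assuming the layer was built with tf.layers.dense and kept its default ReLU op name:

graph = tf.get_default_graph()
# "dense_1/Relu:0" is an assumed tensor name; verify it against the actual
# graph with [op.name for op in graph.get_operations()]
dense2_out = graph.get_tensor_by_name("dense_1/Relu:0")
feature = sess.run(dense2_out, feed_dict={x: img})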

III. Continuing training & fine-tuning a layer

1. Continuing training (all parameters)

# Define a global object for reading flag values; reference them in the
# program as, e.g., FLAGS.checkpoint_dir
FLAGS = tf.app.flags.FLAGS


# Define the command-line flags: the arguments are the flag's name, its
# default value, and its description
tf.app.flags.DEFINE_string(
    "checkpoint_dir",
    "/path/to/checkpoint_save_dir/",
    "Directory name to save the checkpoints [checkpoint]"
)
tf.app.flags.DEFINE_boolean(
    "continue_train",
    False,
    "True for continue training. [False]"
)

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if FLAGS.continue_train:
        # Automatically get the latest model file
        model_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        saver.restore(sess, model_file)
        print("Successfully loaded %s." % model_file)

2. Fine-tuning a layer

  • To change which of the network's weights and biases get updated, set the trainable argument to False on every variable that should stay fixed during training,
  • then continue training with the code above (a sketch follows this list).
    eg: my_non_trainable = tf.get_variable("my_non_trainable", shape=(3, 3), trainable=False)
  • Restore a meta checkpoint (TODO: to be summarized)
    • use the TF helper tf.train.import_meta_graph()