1. 程式人生 > >tensorflow 恢復部分引數、載入指定引數

tensorflow 恢復部分引數、載入指定引數

多分類採用與訓練模型輸出不匹配,我們需要載入部分預訓練模型的引數。

我們先看一下如何儲存和讀入預訓練模型。

#一般實驗情況下儲存的時候,都是用的saver類來儲存,如下
saver = tf.train.Saver()
saver.save(sess,"model.ckpt")

#載入時的程式碼
saver.restore(sess,"model.ckpt")

#前面的描述相當於是儲存了所有的引數,然後載入所有的引數。
#但是目前的情況有所變化了,不能載入所有的引數,最後一層的引數不一樣了,需要隨機初始化。
#首先對每一層新增name scope,如下:

with name_scope('conv1'):
        xxx
with name_scope('conv2'):
        xxx
with name_scope('fc1'):
        xxx
with name_scope('output'):
        xxx
#然後根據變數的名字,選擇載入哪些變數,

#得到該網路中,所有可以載入的引數
variables = tf.contrib.framework.get_variables_to_restore()
#刪除output層中的引數
variables_to_resotre = [v for v in varialbes if v.name.split('/')[0]!='output']
#構建這部分引數的
saversaver = tf.train.Saver(variables_to_restore)
saver.restore(sess,'model.ckpt')

#在tensorflow中,有多種方式可以得到變數的資訊:
tf.contrib.framework.get_variables_to_restore()
tf.all_variables()tf.trainable_varialbes()

 

 

 

 

多分類採用與訓練模型輸出不匹配解決方法:

利用tf.contrib.framework.get_variables_to_restore()函式,程式碼如下

variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['resnet50/fc'])
saver = tf.train.Saver(variables_to_restore)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, param_path)

 

exclude=['resnet50/fc']表示載入預訓練引數中除了resnet50/fc這一層之外的其他所有引數。

include=["inceptionv3"]表示只加載inceptionv3這一層的所有引數。

param_path是你預訓練引數儲存地址。

注:如果不止一個層引數需要丟棄,exclue=['a', 'b']即可。調優訓練(fine_tuning)時最好把前面曾trainable設為False,只訓練最後一層。