tensorflow從已經訓練好的模型中,恢復(指定)權重(構建新變數、網路)並繼續訓練(finetuning)
假如要儲存或者恢復指定tensor,並且把儲存的graph恢復(插入)到當前的graph中呢?
總的來說,目前我會的是兩種方法,命名都是很關鍵!
兩種方式儲存模型,
1.儲存所有tensor,即整張圖的所有變數,
2.只儲存指定scope的變數
兩種方式恢復模型,
1.匯入模型的graph,用該graph的saver來restore變數
2.在新的程式碼段中寫好同樣的模型(變數名稱及scope的name要對應),用預設的graph的saver來restore指定scope的變數
兩種儲存方式:
1.儲存整張圖,所有變數
... init = tf.global_variables_initializer() saver = tf.train.Saver() config = tf.ConfigProto() config.gpu_options.allow_growth=True with tf.Session(config=config) as sess: sess.run(init) ... writer.add_graph(sess.graph) ... saved_path = saver.save(sess,saver_path) ...
2.儲存圖中的部分變數
... init = tf.global_variables_initializer() vgg_ref_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_feat_fc')#獲取指定scope的tensor saver = tf.train.Saver(vgg_ref_vars)#初始化saver時,傳入一個var_list的引數 config = tf.ConfigProto() config.gpu_options.allow_growth=True with tf.Session(config=config) as sess: sess.run(init) ... writer.add_graph(sess.graph) ... saved_path = saver.save(sess,saver_path) ...
兩種恢復方式:
1.匯入graph來恢復
... vgg_meta_path = params['vgg_meta_path'] # 字尾是'.ckpt.meta'的檔案 vgg_graph_weight = params['vgg_graph_weight'] # 字尾是'.ckpt'的檔案,裡面是各個tensor的值 saver_vgg = tf.train.import_meta_graph(vgg_meta_path) # 匯入graph到當前的預設graph中,返回匯入graph的saver x_vgg_feat = tf.get_collection('inputs_vgg')[0] #placeholder, [None, 4096],獲取輸入的placeholder feat_decode = tf.get_collection('feat_encode')[0] #[None, 1024],獲取要使用的tensor """ 以上兩個獲取tensor的方式也可以為: graph = tf.get_default_graph() centers = graph.get_tensor_by_name('loss/intra/center_loss/centers:0') 當然,前提是有tensor的名字 """ ... init = tf.global_variables_initializer() saver = tf.train.Saver() # 這個是當前新圖的saver config = tf.ConfigProto() config.gpu_options.allow_growth=True with tf.Session(config=config) as sess: sess.run(init) ... saver_vgg.restore(sess, vgg_graph_weight)#使用匯入圖的saver來恢復 ...
2.重寫一樣的graph,然後恢復指定scope的變數
def re_build():#重建儲存的那個graph
with tf.variable_scope('vgg_feat_fc'): #沒錯,這個scope要和需要恢復模型中的scope對應
...
return ...
...
vgg_ref_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_feat_fc') # 既然有這個scope,其實第1個方法中,匯入graph後,可以不用返回的vgg_saver,再新建一個指定var_list的vgg_saver就好了,恩,需要傳入一個var_list的引數
...
init = tf.global_variables_initializer()
saver_vgg = tf.train.Saver(vgg_ref_vars) # 這個是要恢復部分的saver
saver = tf.train.Saver() # 這個是當前新圖的saver
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
with tf.Session(config=config) as sess:
sess.run(init)
...
saver_vgg.restore(sess, vgg_graph_weight)#使用匯入圖的saver來恢復
...
總結一下,這裡的要點就是,在restore的時候,saver要和模型對應,如果直接用當前graph的saver = tf.train.Saver(),來恢復儲存模型的權重saver.restore(vgg_graph_weight),就會報錯,提示key/tensor ... not found之類的錯誤;
寫graph的時候,一定要注意寫好scope和tensor的name,合理插入variable_scope;
最方便的方式還是,用第1種方式來儲存模型,這樣就不用重寫程式碼段了,然後第1種方式恢復,不過為了穩妥,最好還是通過獲取var_list,指定saver的var_list,妥妥的!
最新發現,用第1種方式恢復時,要記得當前的graph和儲存的模型中沒有重名的tensor,否則當前graph的tensor name可能不是那個name,可能在後面加了"_1"....-_-||
在恢復圖基礎上構建新的網路(變數)並訓練(finetuning)(2017.11.9更新)
恢復模型graph和weights在上面已經說了,這裡的關鍵點是怎樣只恢復原圖的權重 ,並且使optimizer只更新新構造變數(指定層、變數)。
(以下code與上面沒聯絡)
"""1.Get input, output , saver and graph"""#從匯入圖中獲取需要的東西
meta_path_restore = model_dir + '/model_'+model_version+'.ckpt.meta'
model_path_restore = model_dir + '/model_'+model_version+'.ckpt'
saver_restore = tf.train.import_meta_graph(meta_path_restore) #獲取匯入圖的saver,便於後面恢復
graph_restore = tf.get_default_graph() #此時預設圖就是匯入的圖
#從匯入圖中獲取需要的tensor
#1. 用collection來獲取
input_x = tf.get_collection('inputs')[0]
input_is_training = tf.get_collection('is_training')[0]
output_feat_fused = tf.get_collection('feat_fused')[0]
#2. 用tensor的name來獲取
input_y = graph_restore.get_tensor_by_name('label_exp:0')
print('Get tensors...')
print('inputs shape: {}'.format(input_x.get_shape().as_list()))
print('input_is_training shape: {}'.format(input_is_training.get_shape().as_list()))
print('output_feat_fused shape: {}'.format(output_feat_fused.get_shape().as_list()))
"""2.Build new variable for fine tuning"""#構造新的variables用於後面的finetuning
graph_restore.clear_collection('feat_fused') #刪除以前的集合,假如finetuning後用新的代替原來的
graph_restore.clear_collection('prob')
#新增新的東西
if F_scale is not None and F_scale!=0:
print('F_scale is not None, value={}'.format(F_scale))
feat_fused = Net_normlize_scale(output_feat_fused, F_scale)
tf.add_to_collection('feat_fused',feat_fused)#重新新增到新集合
logits_fused = last_logits(feat_fused,input_is_training,7) # scope name是"final_logits"
"""3.Get acc and loss"""#構造損失
with tf.variable_scope('accuracy'):
accuracy,prediction = ...
with tf.variable_scope('loss'):
loss = ...
"""4.Build op for fine tuning"""
global_step = tf.Variable(0, trainable=False,name='global_step')
learning_rate = tf.train.exponential_decay(initial_lr,
global_step=global_step,
decay_steps=decay_steps,
staircase=True,
decay_rate=0.1)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
var_list = tf.contrib.framework.get_variables('final_logits')#關鍵!獲取指定scope下的變數
train_op = tf.train.MomentumOptimizer(learning_rate=learning_rate,momentum=0.9).minimize(loss,global_step=global_step,var_list=var_list) #只更新指定的variables
"""5.Begin training"""
init = tf.global_variables_initializer()
saver = tf.train.Saver()
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
with tf.Session(config=config) as sess:
sess.run(init)
saver_restore.restore(sess, model_path_restore) #這裡saver_restore對應匯入圖的saver, 如果用上面新的saver的話會報錯 因為多出了一些新的變數 在儲存的模型中是沒有那些權值的
sess.run(train_op, feed_dict)
.......
再說明下兩個關鍵點:
1. 如何在新圖的基礎上 只恢復 匯入圖的權重 ?
用匯入圖的saver: saver_restore
2. 如何只更新指定引數?
用var_list = tf.contrib.framework.get_variables(scope_name)獲取指定scope_name下的變數,
然後optimizer.minimize()時傳入指定var_list
附:如何知道tensor名字以及獲取指定變數?
1.獲取某個操作之後的輸出
用graph.get_operations()獲取所有op
比如<tf.Operation 'common_conv_xxx_net/common_conv_net/flatten/Reshape' type=Reshape>,
那麼output_pool_flatten = graph_restore.get_tensor_by_name('common_conv_xxx_net/common_conv_net/flatten/Reshape:0')就是那個位置經過flatten後的輸出了
2.獲取指定的var的值
用GraphKeys獲取變數
tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)返回指定集合的變數
比如 <tf.Variable 'common_conv_xxx_net/final_logits/logits/biases:0' shape=(7,) dtype=float32_ref>
那麼var_logits_biases = graph_restore.get_tensor_by_name('common_conv_xxx_net/final_logits/logits/biases:0')就是那個位置的biases了
3.獲取指定scope的collection
tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,scope='common_conv_xxx_net.final_logits')
注意後面的scope是xxx.xxx不是xxx/xxx