
Loading and Using a Pretrained Model in TensorFlow: inception_v3 as an Example

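While fine-tuning with TensorFlow, I kept hitting errors when loading the official pretrained inception_v3 model. This post records the correct way to load it, together with the mistakes that cause those errors.
The key code is as follows: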

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import inception_v3

.....................................................

# Build the network
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
    logits, end_points = inception_v3.inception_v3(imgs, num_classes=class_num, is_training=is_training_pl)

....................................................

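with tf.Session() as sess:
    # Initialize all variables first, so that variables which are not restored do not cause errors later
    init = tf.global_variables_initializer()
    sess.run(init)
    # Load the pretrained model
    print('Loading model check point from {:s}'.format(Pretrained_model_dir))
    # exclusions lists the scopes that must not be restored from the pretrained model.
    # The Logits layers in the checkpoint were sized for ImageNet, so restoring them
    # raises an error whenever your number of classes is different.
    exclusions = ['InceptionV3/Logits', 'InceptionV3/AuxLogits']
    # Build the list of all variables to restore, i.e. everything outside exclusions
    inception_except_logits = slim.get_variables_to_restore(exclude=exclusions)
    # Create a function that assigns the values stored in the checkpoint to the variables in that list
    init_fn = slim.assign_from_checkpoint_fn(Pretrained_model_dir, inception_except_logits, ignore_missing_vars=True)
    # Run that function
    init_fn(sess)
    print('Loaded.')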

The dotted lines stand for code omitted because it is not relevant to this post.
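To make the order of operations concrete, here is a minimal plain-Python model of the restore step. It is only a sketch of the behaviour described above, not slim's implementation: variables are modelled as a dict from name to value, the helper `restore_from_checkpoint` and the toy values are hypothetical, and scope exclusion is approximated as a simple name-prefix test.

```python
def restore_from_checkpoint(graph_vars, checkpoint, exclusions, ignore_missing_vars=False):
    """Toy model of the restore step (hypothetical helper, not a slim API).

    graph_vars: dict name -> value, already initialized.
    checkpoint: dict name -> value, as stored in the pretrained model.
    exclusions: scope prefixes that must not be restored.
    """
    restored = dict(graph_vars)  # start from the initialized values
    for name in graph_vars:
        if any(name.startswith(scope) for scope in exclusions):
            continue  # excluded scope: keep the freshly initialized value
        if name not in checkpoint:
            if ignore_missing_vars:
                continue  # e.g. optimizer slots: keep the initialized value
            raise KeyError('Tensor name "%s" not found in checkpoint' % name)
        restored[name] = checkpoint[name]
    return restored


# Illustrative names and values only.
graph_vars = {
    'InceptionV3/Conv2d_1a_3x3/weights': 'init',
    'InceptionV3/Logits/Conv2d_1c_1x1/biases': 'init',
    'InceptionV3/Conv2d_1a_3x3/weights/Momentum': 'init',
}
checkpoint = {
    'InceptionV3/Conv2d_1a_3x3/weights': 'pretrained',
    'InceptionV3/Logits/Conv2d_1c_1x1/biases': 'pretrained-1001',
}
exclusions = ['InceptionV3/Logits', 'InceptionV3/AuxLogits']

result = restore_from_checkpoint(graph_vars, checkpoint, exclusions, ignore_missing_vars=True)
for name, value in sorted(result.items()):
    print(name, '->', value)
```

Only the backbone weight ends up with the pretrained value. The excluded Logits variable and the optimizer slot keep their initialized values, which is why initialization has to happen before the restore.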

The errors you may run into are listed below.
Error 1

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [5] rhs shape= [1001]
     [[Node: save_1/Assign_8 = Assign[T=DT_FLOAT, _class=["loc:@InceptionV3/AuxLogits/Conv2d_2b_1x1/biases"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](InceptionV3/AuxLogits/Conv2d_2b_1x1/biases, save_1/RestoreV2_8/_2319)]]

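Cause:
The pretrained checkpoint was trained on ImageNet and its classifier layers have 1001 outputs (rhs shape= [1001] in the message above), while class_num=5 was passed here, so restoring the complete model fails on those layers.
Solution:
Do not restore the Logits and AuxLogits layers, which are the ones whose shapes depend on the number of classes: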

exclusions = ['InceptionV3/Logits','InceptionV3/AuxLogits']
inception_except_logits = slim.get_variables_to_restore(exclude=exclusions)
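One way to see in advance which variables would trigger this error is to compare shapes between the graph and the checkpoint. The sketch below does that over plain dicts. The helper name `find_shape_mismatches` and the sample entries are assumptions for illustration (only the AuxLogits biases entry is taken from the error above). In TensorFlow 1.x the checkpoint side could probably be filled from `tf.train.list_variables`, but that wiring is not shown here.

```python
def find_shape_mismatches(graph_shapes, checkpoint_shapes):
    """Return {name: (graph_shape, checkpoint_shape)} for variables
    present on both sides whose shapes differ."""
    mismatches = {}
    for name, shape in graph_shapes.items():
        ckpt_shape = checkpoint_shapes.get(name)
        if ckpt_shape is not None and list(ckpt_shape) != list(shape):
            mismatches[name] = (list(shape), list(ckpt_shape))
    return mismatches


# Illustrative shapes: a 5-class graph against the 1001-class checkpoint.
graph_shapes = {
    'InceptionV3/AuxLogits/Conv2d_2b_1x1/biases': [5],
    'InceptionV3/Conv2d_1a_3x3/weights': [3, 3, 3, 32],
}
checkpoint_shapes = {
    'InceptionV3/AuxLogits/Conv2d_2b_1x1/biases': [1001],
    'InceptionV3/Conv2d_1a_3x3/weights': [3, 3, 3, 32],
}

mismatches = find_shape_mismatches(graph_shapes, checkpoint_shapes)
for name, (lhs, rhs) in mismatches.items():
    print('%s: lhs shape=%s rhs shape=%s' % (name, lhs, rhs))
```

Every name reported this way sits under a scope that belongs in exclusions.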

Error 2
Tensor name "xxxx" not found in checkpoint files

NotFoundError (see above for traceback): Tensor name "InceptionV3/Mixed_6c/Branch_2/Conv2d_0b_7x1/biases" not found in checkpoint files E:\DeepLearning\TensorFlow\Inception\inception_v3_2016_08_28\inception_v3.ckpt
     [[Node: save_1/RestoreV2_180 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_180/tensor_names, save_1/RestoreV2_180/shape_and_slices)]]
     [[Node: save_1/RestoreV2_277/_109 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_854_save_1/RestoreV2_277", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

The tensor name here can be that of any variable in inception_v3. The possible causes and their solutions are:
1. The graph was built without arg_scope, like this:

logits, end_points = inception_v3.inception_v3(imgs, num_classes=class_num, is_training=is_training_pl)

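Solution:
Wrap the call in arg_scope, using the inception_v3_arg_scope that ships with the library. That scope enables batch normalization on the convolution layers, so they create BatchNorm variables, which the checkpoint contains, instead of the biases variables reported as missing above: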

with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
    logits, end_points = inception_v3.inception_v3(imgs, num_classes=class_num, is_training=is_training_pl)

2. Not all variables were initialized before reading the checkpoint, that is, the following was not run:

init = tf.global_variables_initializer()
sess.run(init)

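As a result, variables that do not exist in the checkpoint are never initialized, for example the per-layer Momentum slots created when using the Momentum optimizer.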

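3. slim.assign_from_checkpoint_fn() was called without ignore_missing_vars=True. The default is ignore_missing_vars=False, so with an optimizer that keeps extra per-variable state (Momentum, RMSProp, and so on, as opposed to plain SGD), the restore reports that the Momentum or RMSProp variables cannot be found in the checkpoint. For example,
with Momentum: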

NotFoundError (see above for traceback): Tensor name "InceptionV3/Mixed_6e/Branch_2/Conv2d_0c_1x7/BatchNorm/beta/Momentum" not found in checkpoint files E:\DeepLearning\TensorFlow\Inception\inception_v3_2016_08_28\inception_v3.ckpt
     [[Node: save_1/RestoreV2_397 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_397/tensor_names, save_1/RestoreV2_397/shape_and_slices)]]
     [[Node: save_1/RestoreV2_122/_2185 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_2096_save_1/RestoreV2_122", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

With RMSProp:

NotFoundError (see above for traceback): Tensor name "InceptionV3/Mixed_6b/Branch_1/Conv2d_0b_1x7/BatchNorm/beta/RMSProp" not found in checkpoint files E:\DeepLearning\TensorFlow\Inception\inception_v3_2016_08_28\inception_v3.ckpt
     [[Node: save_1/RestoreV2_257 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_257/tensor_names, save_1/RestoreV2_257/shape_and_slices)]]
     [[Node: save_1/Assign_463/_3950 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_3478_save_1/Assign_463", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]

The fix is simple: set ignore_missing_vars=True

init_fn = slim.assign_from_checkpoint_fn(Pretrained_model_dir, inception_except_logits,ignore_missing_vars=True)

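Note: only set this to True after the previous steps are all done. Otherwise, if every variable name is wrong, all of the variables in the checkpoint are silently skipped and no parameters are loaded at all.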
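A quick guard against that failure mode is to count how many graph variables actually have a counterpart in the checkpoint before trusting ignore_missing_vars=True. The following is a sketch over plain lists of names. The helper `summarize_missing`, the slot suffixes, and the sample names are assumptions chosen to mirror the errors above, and the real name lists would have to come from your own graph and checkpoint.

```python
# Assumed optimizer slot suffixes, taken from the error messages above.
OPTIMIZER_SLOT_SUFFIXES = ('/Momentum', '/RMSProp')


def summarize_missing(graph_names, checkpoint_names):
    """Split graph variable names into matched, optimizer slots,
    and other names that are missing from the checkpoint."""
    checkpoint_names = set(checkpoint_names)
    matched, slots, suspicious = [], [], []
    for name in graph_names:
        if name in checkpoint_names:
            matched.append(name)
        elif name.endswith(OPTIMIZER_SLOT_SUFFIXES):
            slots.append(name)  # expected to be absent from a pretrained checkpoint
        else:
            suspicious.append(name)  # likely a naming problem, e.g. a missing arg_scope
    return matched, slots, suspicious


# Sample names mirroring the errors above (illustrative only).
graph_names = [
    'InceptionV3/Mixed_6e/Branch_2/Conv2d_0c_1x7/BatchNorm/beta',
    'InceptionV3/Mixed_6e/Branch_2/Conv2d_0c_1x7/BatchNorm/beta/Momentum',
    'InceptionV3/Mixed_6c/Branch_2/Conv2d_0b_7x1/biases',
]
checkpoint_names = [
    'InceptionV3/Mixed_6e/Branch_2/Conv2d_0c_1x7/BatchNorm/beta',
]

matched, slots, suspicious = summarize_missing(graph_names, checkpoint_names)
print('matched: %d, optimizer slots: %d, suspicious: %d'
      % (len(matched), len(slots), len(suspicious)))
for name in suspicious:
    print('not in checkpoint:', name)
```

If the suspicious list is non-empty, or the matched count is close to zero, fix the graph construction first and only then turn on ignore_missing_vars=True.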

Those are the problems I ran into; I hope this helps.