Tensorflow讀取並使用預訓練模型:以inception_v3為例
在使用Tensorflow做讀取並finetune的時候,發現在讀取官方給的inception_v3預訓練模型總是出現各種錯誤,現記錄其正確的讀取方式和各種錯誤做法:
關鍵程式碼如下:
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import inception_v3
.....................................................
# 讀取網路
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
logits, end_points = inception_v3.inception_v3(imgs, num_classes=class_num, is_training=is_training_pl)
....................................................
with tf.Session() as sess:
# 先初始化所有變數,避免有些變數未讀取而產生錯誤
init = tf.global_variables_initializer()
sess.run(init)
#載入預訓練模型
print('Loading model check point from {:s}'.format(Pretrained_model_dir))
#這裡的exclusions是不需要讀取預訓練模型中的Logits,因為預設的類別數目是1000,當你的類別數目不是1000的時候,如果還要讀取的話,就會報錯
exclusions = ['InceptionV3/Logits' ,
'InceptionV3/AuxLogits']
#建立一個列表,包含除了exclusions之外所有需要讀取的變數
inception_except_logits = slim.get_variables_to_restore(exclude=exclusions)
#建立一個從預訓練模型checkpoint中讀取上述列表中的相應變數的引數的函式
init_fn = slim.assign_from_checkpoint_fn(Pretrained_model_dir, inception_except_logits,ignore_missing_vars=True )
#執行該函式
init_fn(sess)
print('Loaded.')
其中的…………………………..省略了一些與本文無關的程式碼。
其中可能會出現的錯誤如下:
錯誤1
InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [5] rhs shape= [1001]
[[Node: save_1/Assign_8 = Assign[T=DT_FLOAT, _class=["loc:@InceptionV3/AuxLogits/Conv2d_2b_1x1/biases"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/device:GPU:0"](InceptionV3/AuxLogits/Conv2d_2b_1x1/biases, save_1/RestoreV2_8/_2319)]]
原因:
預訓練模型中的類別數class_num=1000,這裡輸入的class_num=5,當讀取完整模型的時候當然會出錯。
解決方案:
選擇不讀取包含類別數的Logits層和AuxLogits層:
exclusions = ['InceptionV3/Logits','InceptionV3/AuxLogits']
inception_except_logits = slim.get_variables_to_restore(exclude=exclusions)
錯誤2
Tensor name “xxxx” not found in checkpoint files
NotFoundError (see above for traceback): Tensor name "InceptionV3/Mixed_6c/Branch_2/Conv2d_0b_7x1/biases" not found in checkpoint files E:\DeepLearning\TensorFlow\Inception\inception_v3_2016_08_28\inception_v3.ckpt
[[Node: save_1/RestoreV2_180 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_180/tensor_names, save_1/RestoreV2_180/shape_and_slices)]]
[[Node: save_1/RestoreV2_277/_109 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_854_save_1/RestoreV2_277", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
這裡的Tensor name可以是所有inception_v3中變數的名字,出現這種情況的各種原因和解決方案是:
1.建立圖的時候沒有用arg_scope,是這樣建立的:
logits, end_points = inception_v3.inception_v3(imgs, num_classes=class_num, is_training=is_training_pl)
解決方案:
在這裡加上arg_scope,裡面呼叫的是庫中自帶的inception_v3_arg_scope
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
logits, end_points = inception_v3.inception_v3(imgs, num_classes=class_num, is_training=is_training_pl)
2.在讀取checkpoint的時候未初始化所有變數,即未執行
init = tf.global_variables_initializer()
sess.run(init)
這樣會導致有一些checkpoint中不存在的變數未被初始化,比如使用Momentum時的每一層的Momentum引數等。
3.使用slim.assign_from_checkpoint_fn()
函式時,沒有新增ignore_missing_vars=True
屬性,由於預設ignore_missing_vars=False,所以,當使用非SGD的optimizer的時候(如Momentum、RMSProp等)時,會提示Momentum或者RMSProp的引數在checkpoint中無法找到,如:
使用Momentum時:
NotFoundError (see above for traceback): Tensor name "InceptionV3/Mixed_6e/Branch_2/Conv2d_0c_1x7/BatchNorm/beta/Momentum" not found in checkpoint files E:\DeepLearning\TensorFlow\Inception\inception_v3_2016_08_28\inception_v3.ckpt
[[Node: save_1/RestoreV2_397 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_397/tensor_names, save_1/RestoreV2_397/shape_and_slices)]]
[[Node: save_1/RestoreV2_122/_2185 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_2096_save_1/RestoreV2_122", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
使用RMSProp時:
NotFoundError (see above for traceback): Tensor name "InceptionV3/Mixed_6b/Branch_1/Conv2d_0b_1x7/BatchNorm/beta/RMSProp" not found in checkpoint files E:\DeepLearning\TensorFlow\Inception\inception_v3_2016_08_28\inception_v3.ckpt
[[Node: save_1/RestoreV2_257 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save_1/Const_0_0, save_1/RestoreV2_257/tensor_names, save_1/RestoreV2_257/shape_and_slices)]]
[[Node: save_1/Assign_463/_3950 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device_incarnation=1, tensor_name="edge_3478_save_1/Assign_463", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
解決方法很簡單,就是把ignore_missing_vars=True
init_fn = slim.assign_from_checkpoint_fn(Pretrained_model_dir, inception_except_logits,ignore_missing_vars=True)
注意:一定要在之前的步驟都完成之後才能設成True,不然如果變數名稱全部出錯的話,會忽視掉checkpoint中所有的變數,從而不讀取任何引數。
以上就是我碰見的問題,希望有所幫助。