1. 程式人生 > >slim 讀取並使用預訓練模型 inception_v3 遷移學習

slim 讀取並使用預訓練模型 inception_v3 遷移學習

轉自:https://blog.csdn.net/amanfromearth/article/details/79155926#commentBox

在使用Tensorflow做讀取並finetune的時候,發現在讀取官方給的inception_v3預訓練模型總是出現各種錯誤,現記錄其正確的讀取方式和各種錯誤做法: 

關鍵程式碼如下:

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import inception_v3
from research.slim.preprocessing import inception_preprocessing
Pretrained_model_dir = '/Users/apple/tensorflow_model/models-master/research/slim/pre_train/inception_v3.ckpt'

image_size = 299

# 讀取網路
with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
    imgPath = 'test.jpg'
    testImage_string = tf.gfile.FastGFile(imgPath, 'rb').read()
    testImage = tf.image.decode_jpeg(testImage_string, channels=3)
    processed_image = inception_preprocessing.preprocess_image(testImage, image_size, image_size, is_training=False)
    processed_images = tf.expand_dims(processed_image, 0)
    logits, end_points = inception_v3.inception_v3(processed_images, num_classes=128, is_training=False)
    w1 = tf.Variable(tf.truncated_normal([128, 5], stddev=tf.sqrt(1/128)))
    b1 = tf.Variable(tf.zeros([5]))
    logits = tf.nn.leaky_relu(tf.matmul(logits, w1) + b1)

with tf.Session() as sess:
     # 先初始化所有變數,避免有些變數未讀取而產生錯誤
     init = tf.global_variables_initializer()
     sess.run(init)
     #載入預訓練模型
     print('Loading model check point from {:s}'.format(Pretrained_model_dir))

     #這裡的exclusions是不需要讀取預訓練模型中的Logits,因為預設的類別數目是1000,當你的類別數目不是1000的時候,如果還要讀取的話,就會報錯
     exclusions = ['InceptionV3/Logits',
                   'InceptionV3/AuxLogits']
     #建立一個列表,包含除了exclusions之外所有需要讀取的變數
     inception_except_logits = slim.get_variables_to_restore(exclude=exclusions)
     #建立一個從預訓練模型checkpoint中讀取上述列表中的相應變數的引數的函式
     init_fn = slim.assign_from_checkpoint_fn(Pretrained_model_dir, inception_except_logits,ignore_missing_vars=True)
     #執行該函式
     init_fn(sess)
     print('Loaded.')
     out = sess.run(logits)
     print(out.shape)
     print(out)

其中可能會出現的錯誤如下: 
錯誤1

  • 1
  • 2
  • 3

原因: 
預訓練模型中的類別數class_num=1000,這裡輸入的class_num=5,當讀取完整模型的時候當然會出錯。 
解決方案: 
選擇不讀取包含類別數的Logits層和AuxLogits層:

  • 1
  • 2

錯誤2 
Tensor name “xxxx” not found in checkpoint files 

  • 1
  • 2
  • 3
  • 4

這裡的Tensor name可以是所有inception_v3中變數的名字,出現這種情況的各種原因和解決方案是: 
1.建立圖的時候沒有用arg_scope,是這樣建立的:

  • 1

解決方案: 
在這裡加上arg_scope,裡面呼叫的是庫中自帶的inception_v3_arg_scope

  • 1
  • 2

2.在讀取checkpoint的時候未初始化所有變數,即未執行

  • 1
  • 2

這樣會導致有一些checkpoint中不存在的變數未被初始化,比如使用Momentum時的每一層的Momentum引數等。

3.使用slim.assign_from_checkpoint_fn()函式時,沒有新增ignore_missing_vars=True屬性,由於預設ignore_missing_vars=False,所以,當使用非SGD的optimizer的時候(如Momentum、RMSProp等)時,會提示Momentum或者RMSProp的引數在checkpoint中無法找到,如: 
使用Momentum時:

  • 1
  • 2
  • 3
  • 4

使用RMSProp時:

  • 1
  • 2
  • 3
  • 4

解決方法很簡單,就是把ignore_missing_vars=True

  • 1

注意:一定要在之前的步驟都完成之後才能設成True,不然如果變數名稱全部出錯的話,會忽視掉checkpoint中所有的變數,從而不讀取任何引數。

以上就是我碰見的問題,希望有所幫助。