1. 程式人生 > >深度學習中基於tensorflow_slim進行復雜模型訓練一之tensorflow_slim基本介紹

深度學習中基於tensorflow_slim進行復雜模型訓練一之tensorflow_slim基本介紹

最近在進行微表情識別,但是目前沒有查到比較有效的模型方式,考慮使用inception_v3的模型進行開發,但是該模的構造過程比較複雜,訓練更是麻煩,因此考慮基於tensorflow_slim的模組進行二次訓練,首先介紹一下關於tensorflow_slim的基本模組。

tensorflow_slim的模組主要包括以下幾個部分deployment ,nets,dataset, preprocessing, scripts。其中scripts中主要介紹瞭如何使用各模型,相當於tensorflow_slim的使用字典。下面分別介紹剩下幾個資料夾的作用。

1. dataset:

該資料夾主要儲存了資料的讀取方式,定義了資料讀取的檔案型別是tfrecord,檔名的格式是‘flower_%s_*.tfrecord’,檔案的train部分資料的多少,檔案的validation部分資料的大小(train:3200,validation:350),以及讀取tfrecord格式資料的操作方式,下面表示tfrecord格式讀取的結構。

keys_to_features = {
    'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/format': tf.FixedLenFeature((), tf.string, default_value='jpg'),
    'image/class/label': tf.FixedLenFeature(
        [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
}

接著定義了一個decord用來解碼TFrecords形式的資料,並且將decord, redaer(tf.TFRecordReader),num, label一起放入到dataset中供後續使用。

dataset_factory:用來統一管理dataset中的資料,根據傳入的資料名呼叫對應的dataset中的指令碼。

2. preprocessing: 對資料進行預處理,該處理主要包括一些旋轉,裁剪,隨機增減光強等操作,不一樣的資料處理不一樣的資料集,最後通過preprocessing_factory控制呼叫的資料預處理型別。

 在資料處理時要先給出資料的大小,因為使用tfrecord對資料進行讀取時沒有讀取資料的大小,因此需要對讀取出的資料採用rand_crop或者tf.image.resize_image的方式指定讀取出來的資料大小,(個人建議使用tf.image.resize_image方式,因為採用rand_crop的方式有可能會出現裁剪後的照片比需要的小)

3. nets : 該資料夾裡面包含了比較經典的網路結構,有inception, Alexnet, cifarnet,  mobilenet, vgg 等一些提前訓練好的結構。其中每一個模型的結構以一個指令碼的形式存在,在每個指令碼中定義了預設的輸入大小,並且返回了模型最終的輸出和各個節點的名字對應的值。

同樣在該資料夾下也存在一個net_factory, 該指令碼的作用是根據傳入的模型名字找到對應的模型指令碼,通過傳入的資料結構(data目錄下的指令碼生成的)讀取對應的num_label確定最後輸出的大小。最終返回一個呼叫已寫好的網路結構的介面。

4. deployment 中主要是一些基於單機單GPU和單機多GPU的一些配置引數,由於本人目前使用的是單機單GPU所以沒有進行過深入的瞭解。

5 .  train_image_classifier。 該指令碼的主要作用是為了對整個模型進行訓練,其中涉及的引數較多,在本篇部落格中不一一介紹其作用,當介紹如何使用時會詳細介紹各引數的作用。在該指令碼中主要有以下幾個模組:

一、各種預處理

資料讀取:

dataset = dataset_factory.get_dataset(
    FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)

網路讀取

network_fn = nets_factory.get_network_fn(
    FLAGS.model_name,
    num_classes=(dataset.num_classes - FLAGS.labels_offset),
    weight_decay=FLAGS.weight_decay,
    is_training=True)

讀取資料的處理:

with tf.device(deploy_config.inputs_device()):
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        num_readers=FLAGS.num_readers,
        common_queue_capacity=20 * FLAGS.batch_size,
        common_queue_min=10 * FLAGS.batch_size)
    [image, label] = provider.get(['image', 'label'])
    label -= FLAGS.labels_offset
    train_image_size = FLAGS.train_image_size or network_fn.default_image_size

在DatasetDateProvider內部呼叫了parallel_read方法,該方法主要通過string_input_producer,RandomShuffleQueue,FIFOQueue的方法採用佇列的方式對tfrecorf型別的資料進行讀取。

資料的預處理
preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
image_preprocessing_fn = preprocessing_factory.get_preprocessing(
    preprocessing_name,
    is_training=True)  
image = image_preprocessing_fn(image, train_image_size, train_image_size)

最後對處理好的資料呼叫batch的方式分批處理,並通過one_hot_encoding的形式對標籤進行編碼,並呼叫slim.prefetch_queue.prefetch_queue的方法將資料處理成批佇列。

二、利用網路和資料構建交叉熵(clone_fn): 該函式主要是通過處理好的資料呼叫預訓練的網路得出最後的結果並構交叉熵。

三、 通過create_clones函式構建多個clones . clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue]),在每個GPU上都進行訓練。

四、 定義優化器

learning_rate根據引數自行調整。

五、 計算中的損失和梯度:

total_loss, clones_gradients = model_deploy.optimize_clones(),該方法中回傳入需要計算的引數列表,同時
在該函式中呼叫了total_loss = tf.add_n(clones_losses, name='total_loss')表明其計算的是總損失函式。

六、 進行梯度更新:grad_updates = optimizer.apply_gradients(clones_gradients, global_step=global_step), 該方法對梯度進行了更新。compute_graddients()和apply_gradients()的方法一起使用的效果相當於minimize()

七、 構建train_op並進行train:  通過train_tensor = tf.identity(total_loss, name='train_op')構建了需要的訓練張量,通過  slim.learning.train進行了訓練。

 

八、 在train方法中通過以下幾個關鍵性的語句分別完成對資料的訓練,儲存和佇列的開啟

total_loss, should_stop = train_step_fn(sess, train_op, global_step, train_step_kwargs)
sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
threads = sv.start_queue_runners(sess)

九、 需要計算的引數列表:首先通過variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)得到模型中所有的引數,再去掉傳入的不需要計算的引數,得到需要計算的引數。

十、 模型恢復部分:主要是 _get_init_fn()函式,呼叫ckpt檔案然後通過saver.restore完成。

十一、模型轉化,在slim檔案下面存在一個export_interference_graph,該指令碼的作用是將網路結構儲存為.pb的形式,然後在通過free_graph的方法結合儲存的模型引數將其儲存為可以直接使用的.pb檔案

下圖為自己理解的流程圖,供參考。