1. 程式人生 > >Semantic Segmentation DeepLab v3 讀取資料集(TFRecord)程式碼詳解

Semantic Segmentation DeepLab v3 讀取資料集(TFRecord)程式碼詳解

本文主要介紹谷歌官方在Github TensorFlow中開源的官方程式碼DeepLab在讀取TFRecord格式資料集所使用的方法。

配置DeepLab v3

首先,需要將整個工程拉取到本地的workspace。

2. 將原始碼拉取到自己的workspace中。

git clone https://github.com/tensorflow/models.git

3. 測試是否安裝配置成功。

# From tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
# From tensorflow/models/research/
python deeplab/model_test.py

讀取資料集程式碼分析

讀取資料集部分的程式碼出現在以下檔案中,以PASCAL_VOC的TFRecord格式資料集進行訓練過程為例:

(1)deeplab/train.py

(2)deeplab/datasets/segmentation_dataset.py

(3)deeplab/utils/input_generator.py

1. 輸入指令引數

在train.py中可看到以下程式碼,即需要輸入3個引數:train_logdir、tf_initial_checkpoint、dataset_dir。

if __name__ == '__main__':
  flags.mark_flag_as_required('train_logdir')
  flags.mark_flag_as_required('tf_initial_checkpoint')
  flags.mark_flag_as_required('dataset_dir')
  tf.app.run()

其中

train_logdir="/deeplab/datasets/pascal_voc_seg/exp/train_on_train_set/train"(訓練結束後的checkpoint存放路徑)

tf_initial_checkpoint="/deeplab/datasets/cityscapes/deeplabv3_mnv2_pascal_trainval/ model.ckpt-30000.index"(預訓練好的checkpoint路徑)

dataset_dir="/deeplab/datasets/pascal_voc_seg/tfrecord"(資料集路徑)

2. 通過指令輸入的引數,獲得一個slim.Dataset的例項

2.1 呼叫segmentation_dataset.py中的get_dataset()函式。

dataset = segmentation_dataset.get_dataset(
     FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)

輸入引數如下:

FLAGS.dataset= 'pascal_voc_seg'

FLAGS.train_split= 'train'

FLAGS.dataset_dir='/deeplab/datasets/pascal_voc_seg/tfrecord'(即在1中輸入的dataset_dir引數)

2.2 在segmentation_dataset.py中的get_dataset()函式,定義如下:

def get_dataset(dataset_name, split_name, dataset_dir):

(1)首先,進行兩個判斷。輸入的引數中,dataset_name必須是pascal_voc_seg、cityscapes、ade20k其中的一個,否則報錯;接著獲取資料集的基本資訊,如果輸入的split_name不是train、train_aug、trainval、val其中的一個,則報錯。

if dataset_name not in _DATASETS_INFORMATION:
    raise ValueError('The specified dataset is not supported yet.')

  splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes

  if split_name not in splits_to_sizes:
    raise ValueError('data split name %s not recognized' % split_name)

PASCAL_VOC資料集的基本資訊如下:

_PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
    splits_to_sizes={
        'train': 1464,
        'train_aug': 10582,
        'trainval': 2913,
        'val': 1449,
    },
    num_classes=21,
    ignore_label=255,
)

(2)接著獲取得到num_classes = 21, ignore_label = 255。

num_classes = _DATASETS_INFORMATION[dataset_name].num_classes
  ignore_label = _DATASETS_INFORMATION[dataset_name].ignore_label

(3)接下來,獲得資料格式file_pattern,由兩部分拼接而成。因為tfrecord格式的命名為train-*,所以file_pattern=/deeplab/datasets/pascal_voc_seg/tfrecord/train-*。

file_pattern = _FILE_PATTERN
file_pattern = os.path.join(dataset_dir, file_pattern % split_name)

(4)宣告TF-Examples的解碼方式。

keys_to_features = {
      'image/encoded': tf.FixedLenFeature(
          (), tf.string, default_value=''),
      'image/filename': tf.FixedLenFeature(
          (), tf.string, default_value=''),
      'image/format': tf.FixedLenFeature(
          (), tf.string, default_value='jpeg'),
      'image/height': tf.FixedLenFeature(
          (), tf.int64, default_value=0),
      'image/width': tf.FixedLenFeature(
          (), tf.int64, default_value=0),
      'image/segmentation/class/encoded': tf.FixedLenFeature(
          (), tf.string, default_value=''),
      'image/segmentation/class/format': tf.FixedLenFeature(
          (), tf.string, default_value='png'),
  }
  items_to_handlers = {
      'image': tfexample_decoder.Image(
          image_key='image/encoded',
          format_key='image/format',
          channels=3),
      'image_name': tfexample_decoder.Tensor('image/filename'),
      'height': tfexample_decoder.Tensor('image/height'),
      'width': tfexample_decoder.Tensor('image/width'),
      'labels_class': tfexample_decoder.Image(
          image_key='image/segmentation/class/encoded',
          format_key='image/segmentation/class/format',
          channels=1),
  }

(5)將宣告好的兩個dict輸入TFExampleDecoder。

decoder = tfexample_decoder.TFExampleDecoder(
      keys_to_features, items_to_handlers)

(6)最後,返回一個slim.Dataset的例項。

return dataset.Dataset(
      data_sources=file_pattern,
      reader=tf.TFRecordReader,
      decoder=decoder,
      num_samples=splits_to_sizes[split_name],
      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
      ignore_label=ignore_label,
      num_classes=num_classes,
      name=dataset_name,
      multi_label=True)

其中具體的引數值如下:

file_pattern: /deeplab/datasets/pascal_voc_seg/tfrecord/train-*

tf.TFRecordReader: tf.TFRecordReader(讀取方式)

decoder: decoder

splits_to_sizes[split_name]: 1464(samples的個數)

_ITEMS_TO_DESCRIPTIONS: 一個dict,包含一些描述,如下:

_ITEMS_TO_DESCRIPTIONS = {
    'image': 'A color image of varying height and width.',
    'labels_class': ('A semantic segmentation label whose size matches image.'
                     'Its values range from 0 (background) to num_classes.'),
}

ignore_label: 255

num_classes: 21(包含一個背景)

dataset_name: pascal_voc_seg

multi_label: True

該slim.dataset的例項,也即2.1中的train.py通過呼叫該函式得到的dataset。

3. 獲得由tf.train.batch()生成的例項samples

3.1 在train.py中呼叫input_generator.py中的get()函式。

samples = input_generator.get(
          dataset,
          FLAGS.train_crop_size,
          clone_batch_size,
          min_resize_value=FLAGS.min_resize_value,
          max_resize_value=FLAGS.max_resize_value,
          resize_factor=FLAGS.resize_factor,
          min_scale_factor=FLAGS.min_scale_factor,
          max_scale_factor=FLAGS.max_scale_factor,
          scale_factor_step_size=FLAGS.scale_factor_step_size,
          dataset_split=FLAGS.train_split,
          is_training=True,
          model_variant=FLAGS.model_variant)

輸入引數如下:

dataset: 上一步獲得的slim.Dataset例項dataset

FLAGS.train_crop_size: [513, 513]

clone_batch_size: 8,計算程式碼如下,train_batch_size=8, num_clones=1

clone_batch_size = FLAGS.train_batch_size // config.num_clones

FLAGS.min_resize_value: 未找到

FLAGS.max_resize_value: 未找到

FLAGS.resize_factor: 未找到

FLAGS.min_scale_factor: 0.5

FLAGS.max_scale_factor: 2.

FLAGS.scale_factor_step_size: 0.25

FLAGS.train_split: train

is_training: True

FLAGS.model_variant: xception_65

3.2 在input_generator.py中的get()函式,其定義如下:

def get(dataset,
        crop_size,
        batch_size,
        min_resize_value=None,
        max_resize_value=None,
        resize_factor=None,
        min_scale_factor=1.,
        max_scale_factor=1.,
        scale_factor_step_size=0,
        num_readers=1,
        num_threads=1,
        dataset_split=None,
        is_training=True,
        model_variant=None):

(1)首先,兩個判斷,保證明確dataset的正確劃分和model_variant的宣告。

if dataset_split is None:
    raise ValueError('Unknown dataset split.')
  if model_variant is None:
    tf.logging.warning('Please specify a model_variant. See '
                       'feature_extractor.network_map for supported model '
                       'variants.')

(2)生成一個slim.dataset_data_provider的例項,dataset為之前獲得的slim.Dataset例項,num_readers = 1,num_epochs = None,shuffle = True。

data_provider = dataset_data_provider.DatasetDataProvider(
      dataset,
      num_readers=num_readers,
      num_epochs=None if is_training else 1,
      shuffle=is_training)

(3)呼叫_get_data()函式。

image, label, image_name, height, width = _get_data(data_provider,
                                                      dataset_split)
def _get_data(data_provider, dataset_split):

_get_data()函式,通過slim.dataset_data_provider的get方法獲取到image、height、width,接著獲取到data_name。接下來,判斷是否為訓練/驗證過程,若是訓練/驗證過程,則獲取到label,否則label為None。最後,返回image、label、image_name、height、width這5個tensor給get()函式。

if common.LABELS_CLASS not in data_provider.list_items():
    raise ValueError('Failed to find labels.')

  image, height, width = data_provider.get(
      [common.IMAGE, common.HEIGHT, common.WIDTH])

  # Some datasets do not contain image_name.
  if common.IMAGE_NAME in data_provider.list_items():
    image_name, = data_provider.get([common.IMAGE_NAME])
  else:
    image_name = tf.constant('')

  label = None
  if dataset_split != common.TEST_SET:
    label, = data_provider.get([common.LABELS_CLASS])

  return image, label, image_name, height, width

(4)接著,繼續在get()函式中。判斷通過_get_data()函式返回的label是否為None,若不是None,則判斷維度是否為[, , 1]。若是2維,則擴維;若是3維且第三維是否為1,則跳過,否則報錯;最後將label維度設定為[None,None,1]。

if label is not None:
    if label.shape.ndims == 2:
      label = tf.expand_dims(label, 2)
    elif label.shape.ndims == 3 and label.shape.dims[2] == 1:
      pass
    else:
      raise ValueError('Input label shape must be [height, width], or '
                       '[height, width, 1].')

    label.set_shape([None, None, 1])

(5)接著,呼叫input_process.py中的preprocess_image_and_label()函式,對訓練過程中用到的image和label進行操作,比如resize和歸一化等。

original_image, image, label = input_preprocess.preprocess_image_and_label(
      image,
      label,
      crop_height=crop_size[0],
      crop_width=crop_size[1],
      min_resize_value=min_resize_value,
      max_resize_value=max_resize_value,
      resize_factor=resize_factor,
      min_scale_factor=min_scale_factor,
      max_scale_factor=max_scale_factor,
      scale_factor_step_size=scale_factor_step_size,
      ignore_label=dataset["ignore_label"],
      is_training=is_training,
      model_variant=model_variant)

其中,輸入引數如下:

image: image

label: label

crop_height: 513

crop_width: 513

min_resize_value: None

max_resize_value: None

resize_factor: None

min_scale_factor: 0.5

max_scale_factor: 2.

scale_factor_step_size: 0.25

ignore_label: 255

is_training: True

model_variant: xception_65

返回值如下:

original_image: resize後的image

image: resize後的經過預處理的image

label: resize後的經過預處理的label

(6)接著,宣告一個dict例項sample。分別將image, image_name, height, width傳入。若label不是None,則將label也傳入。若非訓練過程,則將original_image也傳入用於visualization過程。

sample = {
      common.IMAGE: image,
      common.IMAGE_NAME: image_name,
      common.HEIGHT: height,
      common.WIDTH: width
  }
  if label is not None:
    sample[common.LABEL] = label

  if not is_training:
    sample[common.ORIGINAL_IMAGE] = original_image,
    num_threads = 1

(7)最後,呼叫tf.train.batch(),返回一個samples。

  return tf.train.batch(
      sample,
      batch_size=batch_size,
      num_threads=num_threads,
      capacity=32 * batch_size,
      allow_smaller_final_batch=not is_training,
      dynamic_pad=True)

4. 最後,在train.py中,呼叫slim.prefetch_queue.prefetch_queue()方法,生成輸入佇列

inputs_queue = prefetch_queue.prefetch_queue(
          samples, capacity=128 * config.num_clones)

至此,讀取資料集的過程結束。

宣告

1. 本文為作者原創,如需轉載,請註明本文連結和作者ID:superkoma。

2. 創作本文目的是為了理解DeepLab v3讀取資料集的主要流程,方便在自己的資料集上進行訓練和驗證測試等。一般大家會將自己的資料集轉為TFRecord格式以滿足輸入要求,但是另一種思路是修改其原始碼讀取資料集的部分,使其能夠直接從一個包含影象路徑的list中直接讀取。作者後續會推出修改方式。