Semantic Segmentation DeepLab v3 讀取資料集(TFRecord)程式碼詳解
本文主要介紹谷歌官方在Github TensorFlow中開源的官方程式碼DeepLab在讀取TFRecord格式資料集所使用的方法。
配置DeepLab v3
首先,需要將整個工程拉取到本地的workspace。
2. 將原始碼拉取到自己的workspace中。
git clone https://github.com/tensorflow/models.git
3. 測試是否安裝配置成功。
# From tensorflow/models/research/
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
# From tensorflow/models/research/ python deeplab/model_test.py
讀取資料集程式碼分析
讀取資料集部分的程式碼出現在以下檔案中,以PASCAL_VOC的TFRecord格式資料集進行訓練過程為例:
(1)deeplab/train.py
(2)deeplab/datasets/segmentation_dataset.py
(3)deeplab/utils/input_generator.py
1. 輸入指令引數
在train.py中可看到以下程式碼,即需要輸入3個引數:train_logdir、tf_initial_checkpoint、dataset_dir。
if __name__ == '__main__': flags.mark_flag_as_required('train_logdir') flags.mark_flag_as_required('tf_initial_checkpoint') flags.mark_flag_as_required('dataset_dir') tf.app.run()
其中
train_logdir="/deeplab/datasets/pascal_voc_seg/exp/train_on_train_set/train"(訓練結束後的checkpoint存放路徑)
tf_initial_checkpoint="/deeplab/datasets/cityscapes/deeplabv3_mnv2_pascal_trainval/ model.ckpt-30000.index"(預訓練好的checkpoint路徑)
dataset_dir="/deeplab/datasets/pascal_voc_seg/tfrecord"(資料集路徑)
2. 通過指令輸入的引數,獲得一個slim.Dataset的例項
2.1 呼叫segmentation_dataset.py中的get_dataset()函式。
dataset = segmentation_dataset.get_dataset(
FLAGS.dataset, FLAGS.train_split, dataset_dir=FLAGS.dataset_dir)
輸入引數如下:
FLAGS.dataset= 'pascal_voc_seg'
FLAGS.train_split= 'train'
FLAGS.dataset_dir='/deeplab/datasets/pascal_voc_seg/tfrecord'(即在1中輸入的dataset_dir引數)
2.2 在segmentation_dataset.py中的get_dataset()函式,定義如下:
def get_dataset(dataset_name, split_name, dataset_dir):
(1)首先,進行兩個判斷。輸入的引數中,dataset_name必須是pascal_voc_seg、cityscapes、ade20k其中的一個,否則報錯;接著獲取資料集的基本資訊,如果輸入的split_name不是train、train_aug、trainval、val其中的一個,則報錯。
if dataset_name not in _DATASETS_INFORMATION:
raise ValueError('The specified dataset is not supported yet.')
splits_to_sizes = _DATASETS_INFORMATION[dataset_name].splits_to_sizes
if split_name not in splits_to_sizes:
raise ValueError('data split name %s not recognized' % split_name)
PASCAL_VOC資料集的基本資訊如下:
_PASCAL_VOC_SEG_INFORMATION = DatasetDescriptor(
splits_to_sizes={
'train': 1464,
'train_aug': 10582,
'trainval': 2913,
'val': 1449,
},
num_classes=21,
ignore_label=255,
)
(2)接著獲取得到num_classes = 21, ignore_label = 255。
num_classes = _DATASETS_INFORMATION[dataset_name].num_classes
ignore_label = _DATASETS_INFORMATION[dataset_name].ignore_label
(3)接下來,獲得資料格式file_pattern,由兩部分拼接而成。因為tfrecord格式的命名為train-*,所以file_pattern=/deeplab/datasets/pascal_voc_seg/tfrecord/train-*。
file_pattern = _FILE_PATTERN
file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
(4)宣告TF-Examples的解碼方式。
keys_to_features = {
'image/encoded': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/filename': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/format': tf.FixedLenFeature(
(), tf.string, default_value='jpeg'),
'image/height': tf.FixedLenFeature(
(), tf.int64, default_value=0),
'image/width': tf.FixedLenFeature(
(), tf.int64, default_value=0),
'image/segmentation/class/encoded': tf.FixedLenFeature(
(), tf.string, default_value=''),
'image/segmentation/class/format': tf.FixedLenFeature(
(), tf.string, default_value='png'),
}
items_to_handlers = {
'image': tfexample_decoder.Image(
image_key='image/encoded',
format_key='image/format',
channels=3),
'image_name': tfexample_decoder.Tensor('image/filename'),
'height': tfexample_decoder.Tensor('image/height'),
'width': tfexample_decoder.Tensor('image/width'),
'labels_class': tfexample_decoder.Image(
image_key='image/segmentation/class/encoded',
format_key='image/segmentation/class/format',
channels=1),
}
(5)將宣告好的兩個dict輸入TFExampleDecoder。
decoder = tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers)
(6)最後,返回一個slim.Dataset的例項。
return dataset.Dataset(
data_sources=file_pattern,
reader=tf.TFRecordReader,
decoder=decoder,
num_samples=splits_to_sizes[split_name],
items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
ignore_label=ignore_label,
num_classes=num_classes,
name=dataset_name,
multi_label=True)
其中具體的引數值如下:
file_pattern: /deeplab/datasets/pascal_voc_seg/tfrecord/train-*
tf.TFRecordReader: tf.TFRecordReader(讀取方式)
decoder: decoder
splits_to_sizes[split_name]: 1464(samples的個數)
_ITEMS_TO_DESCRIPTIONS: 一個dict,包含一些描述,如下:
_ITEMS_TO_DESCRIPTIONS = {
'image': 'A color image of varying height and width.',
'labels_class': ('A semantic segmentation label whose size matches image.'
'Its values range from 0 (background) to num_classes.'),
}
ignore_label: 255
num_classes: 21(包含一個背景)
dataset_name: pascal_voc_seg
multi_label: True
該slim.dataset的例項,也即2.1中的train.py通過呼叫該函式得到的dataset。
3. 獲得由tf.train.batch()生成的例項samples
3.1 在train.py中呼叫input_generator.py中的get()函式。
samples = input_generator.get(
dataset,
FLAGS.train_crop_size,
clone_batch_size,
min_resize_value=FLAGS.min_resize_value,
max_resize_value=FLAGS.max_resize_value,
resize_factor=FLAGS.resize_factor,
min_scale_factor=FLAGS.min_scale_factor,
max_scale_factor=FLAGS.max_scale_factor,
scale_factor_step_size=FLAGS.scale_factor_step_size,
dataset_split=FLAGS.train_split,
is_training=True,
model_variant=FLAGS.model_variant)
輸入引數如下:
dataset: 上一步獲得的slim.Dataset例項dataset
FLAGS.train_crop_size: [513, 513]
clone_batch_size: 8,計算程式碼如下,train_batch_size=8, num_clones=1
clone_batch_size = FLAGS.train_batch_size // config.num_clones
FLAGS.min_resize_value: 未找到
FLAGS.max_resize_value: 未找到
FLAGS.resize_factor: 未找到
FLAGS.min_scale_factor: 0.5
FLAGS.max_scale_factor: 2.
FLAGS.scale_factor_step_size: 0.25
FLAGS.train_split: train
is_training: True
FLAGS.model_variant: xception_65
3.2 在input_generator.py中的get()函式,其定義如下:
def get(dataset,
crop_size,
batch_size,
min_resize_value=None,
max_resize_value=None,
resize_factor=None,
min_scale_factor=1.,
max_scale_factor=1.,
scale_factor_step_size=0,
num_readers=1,
num_threads=1,
dataset_split=None,
is_training=True,
model_variant=None):
(1)首先,兩個判斷,保證明確dataset的正確劃分和model_variant的宣告。
if dataset_split is None:
raise ValueError('Unknown dataset split.')
if model_variant is None:
tf.logging.warning('Please specify a model_variant. See '
'feature_extractor.network_map for supported model '
'variants.')
(2)生成一個slim.dataset_data_provider的例項,dataset為之前獲得的slim.Dataset例項,num_readers = 1,num_epochs = None,shuffle = True。
data_provider = dataset_data_provider.DatasetDataProvider(
dataset,
num_readers=num_readers,
num_epochs=None if is_training else 1,
shuffle=is_training)
(3)呼叫_get_data()函式。
image, label, image_name, height, width = _get_data(data_provider,
dataset_split)
def _get_data(data_provider, dataset_split):
_get_data()函式,通過slim.dataset_data_provider的get方法獲取到image、height、width,接著獲取到data_name。接下來,判斷是否為訓練/驗證過程,若是訓練/驗證過程,則獲取到label,否則label為None。最後,返回image、label、image_name、height、width這5個tensor給get()函式。
if common.LABELS_CLASS not in data_provider.list_items():
raise ValueError('Failed to find labels.')
image, height, width = data_provider.get(
[common.IMAGE, common.HEIGHT, common.WIDTH])
# Some datasets do not contain image_name.
if common.IMAGE_NAME in data_provider.list_items():
image_name, = data_provider.get([common.IMAGE_NAME])
else:
image_name = tf.constant('')
label = None
if dataset_split != common.TEST_SET:
label, = data_provider.get([common.LABELS_CLASS])
return image, label, image_name, height, width
(4)接著,繼續在get()函式中。判斷通過_get_data()函式返回的label是否為None,若不是None,則判斷維度是否為[, , 1]。若是2維,則擴維;若是3維且第三維是否為1,則跳過,否則報錯;最後將label維度設定為[None,None,1]。
if label is not None:
if label.shape.ndims == 2:
label = tf.expand_dims(label, 2)
elif label.shape.ndims == 3 and label.shape.dims[2] == 1:
pass
else:
raise ValueError('Input label shape must be [height, width], or '
'[height, width, 1].')
label.set_shape([None, None, 1])
(5)接著,呼叫input_process.py中的preprocess_image_and_label()函式,對訓練過程中用到的image和label進行操作,比如resize和歸一化等。
original_image, image, label = input_preprocess.preprocess_image_and_label(
image,
label,
crop_height=crop_size[0],
crop_width=crop_size[1],
min_resize_value=min_resize_value,
max_resize_value=max_resize_value,
resize_factor=resize_factor,
min_scale_factor=min_scale_factor,
max_scale_factor=max_scale_factor,
scale_factor_step_size=scale_factor_step_size,
ignore_label=dataset["ignore_label"],
is_training=is_training,
model_variant=model_variant)
其中,輸入引數如下:
image: image
label: label
crop_height: 513
crop_width: 513
min_resize_value: None
max_resize_value: None
resize_factor: None
min_scale_factor: 0.5
max_scale_factor: 2.
scale_factor_step_size: 0.25
ignore_label: 255
is_training: True
model_variant: xception_65
返回值如下:
original_image: resize後的image
image: resize後的經過預處理的image
label: resize後的經過預處理的label
(6)接著,宣告一個dict例項sample。分別將image, image_name, height, width傳入。若label不是None,則將label也傳入。若非訓練過程,則將original_image也傳入用於visualization過程。
sample = {
common.IMAGE: image,
common.IMAGE_NAME: image_name,
common.HEIGHT: height,
common.WIDTH: width
}
if label is not None:
sample[common.LABEL] = label
if not is_training:
sample[common.ORIGINAL_IMAGE] = original_image,
num_threads = 1
(7)最後,呼叫tf.train.batch(),返回一個samples。
return tf.train.batch(
sample,
batch_size=batch_size,
num_threads=num_threads,
capacity=32 * batch_size,
allow_smaller_final_batch=not is_training,
dynamic_pad=True)
4. 最後,在train.py中,呼叫slim.prefetch_queue.prefetch_queue()方法,生成輸入佇列
inputs_queue = prefetch_queue.prefetch_queue(
samples, capacity=128 * config.num_clones)
至此,讀取資料集的過程結束。
宣告
1. 本文為作者原創,如需轉載,請註明本文連結和作者ID:superkoma。
2. 創作本文目的是為了理解DeepLab v3讀取資料集的主要流程,方便在自己的資料集上進行訓練和驗證測試等。一般大家會將自己的資料集轉為TFRecord格式以滿足輸入要求,但是另一種思路是修改其原始碼讀取資料集的部分,使其能夠直接從一個包含影象路徑的list中直接讀取。作者後續會推出修改方式。