1. 程式人生 > >tensorflow 16:資料讀取(以cifar10_input.py為例)

tensorflow 16:資料讀取(以cifar10_input.py為例)

資料讀取概述

TensorFlow程式讀取資料一共有3種方法:

  • 供給資料(Feeding): 在TensorFlow程式執行的每一步, 讓Python程式碼來供給資料。
  • 從檔案讀取資料: 在TensorFlow圖的起始, 讓一個輸入管線從檔案中讀取資料。
  • 預載入資料: 在TensorFlow圖中定義常量或變數來儲存所有資料(僅適用於資料量比較小的情況)。

目前我用過的主要是第一種,就是提供feed_dict來向計算圖喂資料。第三種比較少用。

本篇部落格主要講第二種。

從檔案讀取的流水線

下圖來自文末的參考資料《tensorflow資料讀取》。 在這裡插入圖片描述

注意這個流水線有兩個佇列。一個是檔案佇列,由檔名生成。生成的時候可以指定亂序,長度可以長於檔案個數(這時佇列內就會有重複)。

第二個佇列是讀出的樣本佇列。

兩個佇列之間的部分由多個讀取執行緒組成,每個執行緒包括reader、decoder、與處理組成。

注意:樣本佇列最終以計算圖節點的形式接入計算圖,計算圖根據依賴自動去獲取資料,不用手動餵了。

程式碼檔案說明

包含以下幾個檔案:

檔名 說明
構建計算圖,包括inference、train、loss,同時返回了流水線讀取資料的label和image節點。
cifar10_input.py 構建從檔案讀取資料的流水線
cifar10_input_test.py 測試cifar10_input.py中的reader
cifar10_train.py 訓練程式碼
cifar10_multi_gpu_train.py 多GPU訓練程式碼
cifar10_eval.py 評估訓練程式碼

cifar10_input.py: inputs分解

cifar10_input.py對外提供了兩個介面:inputs和distorted_inputs。區別就是後者回對影象做一些隨機翻轉、裁剪、亮度調整等處理,相當於資料增廣,前者原樣返回。

def inputs(eval_data, data_dir, batch_size):
  if not eval_data:
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin'
% i) for i in xrange(1, 6)] num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN else: filenames = [os.path.join(data_dir, 'test_batch.bin')] num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL for f in filenames: if not tf.gfile.Exists(f): raise ValueError('Failed to find file: ' + f) with tf.name_scope('input'): # 1. 建立檔名佇列 filename_queue = tf.train.string_input_producer(filenames) # 2. 建立reader和decoder,增加圖片預處理 read_input = read_cifar10(filename_queue) reshaped_image = tf.cast(read_input.uint8image, tf.float32) height = IMAGE_SIZE width = IMAGE_SIZE # 將原本32*32的圖片,轉換為24*24 resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, height, width) # Subtract off the mean and divide by the variance of the pixels. float_image = tf.image.per_image_standardization(resized_image) # Set the shapes of tensors. float_image.set_shape([height, width, 3]) read_input.label.set_shape([1]) # Ensure that the random shuffling has good mixing properties. min_fraction_of_examples_in_queue = 0.4 min_queue_examples = int(num_examples_per_epoch * min_fraction_of_examples_in_queue) # 3. 建立佇列,按batch獲取image和label return _generate_image_and_label_batch(float_image, read_input.label, min_queue_examples, batch_size, shuffle=False)

這個函式可以分為三部分:

  1. 建立檔名佇列,對上本文開頭圖片左邊的部分
  2. 建立reader和decoder,增加圖片預處理,對應開頭圖片兩個佇列之間的部分.這裡有呼叫tf.image.per_image_standardization對圖片歸一化。
  3. 建立佇列,按batch獲取image和label, 對應開頭圖片最右側的佇列

最終要的是兩處函式呼叫,即呼叫read_cifar10()和_generate_image_and_label_batch()

先看read_cifar10(),這個函式用於建立reader和decoder。

def read_cifar10(filename_queue):
  class CIFAR10Record(object):
    pass
  result = CIFAR10Record()

  # 定義圖片格式.
  label_bytes = 1  # 2 for CIFAR-100
  result.height = 32
  result.width = 32
  result.depth = 3
  image_bytes = result.height * result.width * result.depth
  # Every record consists of a label followed by the image, with a
  # fixed number of bytes for each.
  record_bytes = label_bytes + image_bytes

  # Read a record, getting filenames from the filename_queue.  No
  # header or footer in the CIFAR-10 format, so we leave header_bytes
  # and footer_bytes at their default of 0.
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  result.key, value = reader.read(filename_queue)

  # Convert from a string to a vector of uint8 that is record_bytes long.
  record_bytes = tf.decode_raw(value, tf.uint8)

  # The first bytes represent the label, which we convert from uint8->int32.
  result.label = tf.cast(
      tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
  depth_major = tf.reshape(
      tf.strided_slice(record_bytes, [label_bytes],
                       [label_bytes + image_bytes]),
      [result.depth, result.height, result.width])
  # Convert from [depth, height, width] to [height, width, depth].
  result.uint8image = tf.transpose(depth_major, [1, 2, 0])

  return result

注意這裡用的reader是tf.FixedLengthRecordReader,用的decoder是tf.decode_raw。如果是別的格式的檔案(如cvs),需要選擇別的reader和decoder。這裡返回的result各成員都是tensor,不是普通檔案,需要執行計算圖才能獲得實際內容。每次讀取一個樣本,有意cifar10的檔案是多個圖片在一個bin檔案裡,下次會從上次讀取的位置接著讀。

另外一個重要的函式是_generate_image_and_label_batch(),它的任務主要是建立按batch獲取圖片的佇列,需要上面建立好的result作為輸入。

def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, shuffle):
  """Construct a queued batch of images and labels.

  Args:
    image: 3-D Tensor of [height, width, 3] of type.float32.
    label: 1-D Tensor of type.int32
    min_queue_examples: int32, minimum number of samples to retain
      in the queue that provides of batches of examples.
    batch_size: Number of images per batch.
    shuffle: boolean indicating whether to use a shuffling queue.

  Returns:
    images: Images. 4D tensor of [batch_size, height, width, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  # Create a queue that shuffles the examples, and then
  # read 'batch_size' images + labels from the example queue.
  num_preprocess_threads = 16
  if shuffle:
    images, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)
  else:
    images, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size)

  # Display the training images in the visualizer.
  tf.summary.image('images', images)

  return images, tf.reshape(label_batch, [batch_size])

根據是否打亂順序,這個函式會選擇呼叫 tf.train.shuffle_batch()還是tf.train.batch() 返回兩個tensor,一個是images和labels,數量就是傳入的batch_size控制的。

cifar10_input.py:distorted_inputs分解

def distorted_inputs(data_dir, batch_size):
  """Construct distorted input for CIFAR training using the Reader ops.

  Args:
    data_dir: Path to the CIFAR-10 data directory.
    batch_size: Number of images per batch.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
               for i in xrange(1, 6)]
  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # 1. 建立檔名佇列.
  filename_queue = tf.train.string_input_producer(filenames)

  with tf.name_scope('data_augmentation'):
    # 2. 建立reader和decoder,增加圖片預處理
    read_input = read_cifar10(filename_queue)
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)

    height = IMAGE_SIZE
    width = IMAGE_SIZE

    # Image processing for training the network. Note the many random
    # distortions applied to the image.

    # 隨機裁剪
    distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

    # 隨機左右翻轉
    distorted_image = tf.image.random_flip_left_right(distorted_image)

    # 隨機調整亮度和對比度
    distorted_image = tf.image.random_brightness(distorted_image,
                                                 max_delta=63)
    distorted_image = tf.image.random_contrast(distorted_image,
                                               lower=0.2, upper=1.8)

    # 標準化(減去均值畫素除以標準差).
    float_image = tf.image.per_image_standardization(distorted_image)

    # Set the shapes of tensors.
    float_image.set_shape([height, width, 3])
    read_input.label.set_shape([1])

    # Ensure that the random shuffling has good mixing properties.
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                             min_fraction_of_examples_in_queue)
    print ('Filling queue with %d CIFAR images before starting to train. '
           'This will take a few minutes.' % min_queue_examples)

  # 3. 建立佇列,按batch獲取image和label
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size,
                                         shuffle=True)

可以看到,這個函式的整體流程和input基本一致,只是多了在decoder之後的預處理,對影象做了很多轉換,起到資料增廣的目的。

參考資料