1. 程式人生 > >tensorflow實現inception Net資料增強

tensorflow實現inception Net資料增強

在CNN中,為了增大資料量避免模型的過擬合,通常都會對訓練資料做資料增強處理,這篇文章主要介紹在Inception Net中是如何做資料增強的,tensorflow官方通過slim已經實現了VGG、Inception、LeNet網路的資料增強的,官網連結如下:https://github.com/tensorflow/models/tree/master/research/slim/preprocessing,Inception Net資料增強主要包括以下幾個部分:

1、將圖片的畫素縮放到[0,1)

#將圖片的畫素值縮放到[0,1)
image = tf.image.convert_image_dtype(image, dtype=tf.float32)

2、隨機裁剪圖片

  • image:一個3維的圖片tensor,資料的取值範圍在[0,1],即表示已經做了歸一化後的圖片
  • bbox:box的邊框,[ymin,xmin,ymax,xmax],預設是[0,0,1,1]表示使用的是整張圖片
  • min_object_covered:在隨機裁剪圖片的時候必須要包括box邊框的比例
  • aspect_ratio_range:隨機裁剪的圖片,寬/高的比例需要滿足的範圍
  • area_range:隨機裁剪的圖片需要佔圖片的比例
  • max_attempts:隨機裁剪,嘗試的最多次數,超過最大嘗試次數返回整張圖片

返回的是一張裁剪之後的圖片和隨機裁剪所選的區域,後面會通過tensorflow對這個區域進行標註,裁剪的圖片就是從這個區域中選擇的。

def distorted_bounding_box_crop(image,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
    sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
        tf.shape(image),
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)
    bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box

    # Crop the image to the specified bounding box.
    cropped_image = tf.slice(image, bbox_begin, bbox_size)
    return cropped_image, distort_bbox

3、隨機變化圖片的引數(翻轉、亮度、飽和度、色度、對比度)

在隨機變化圖片引數的時候,tensorflow提供了一種fast_mode模型,從名字上理解就是快速模型,在fast_mode模型中沒有做比較耗時處理的色度變換和對比度變換,由於引數變換的順序也會影響最終生成的圖片,在變換引數的時候也還加入了不同變化順序。

a、隨機水平翻轉

    # 隨機水平翻轉
    distorted_image = tf.image.random_flip_left_right(distorted_image)

b、亮度、飽和度、色度、對比度的隨機變換

    if fast_mode:
      if color_ordering == 0:
        image = tf.image.random_brightness(image, max_delta=32. / 255.)
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
      else:
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        image = tf.image.random_brightness(image, max_delta=32. / 255.)
    else:
      if color_ordering == 0:
        image = tf.image.random_brightness(image, max_delta=32. / 255.)
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        image = tf.image.random_hue(image, max_delta=0.2)
        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
      elif color_ordering == 1:
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        image = tf.image.random_brightness(image, max_delta=32. / 255.)
        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
        image = tf.image.random_hue(image, max_delta=0.2)
      elif color_ordering == 2:
        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
        image = tf.image.random_hue(image, max_delta=0.2)
        image = tf.image.random_brightness(image, max_delta=32. / 255.)
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
      elif color_ordering == 3:
        image = tf.image.random_hue(image, max_delta=0.2)
        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
        image = tf.image.random_brightness(image, max_delta=32. / 255.)
      else:
        raise ValueError('color_ordering must be in [0, 3]')

4、將圖片的畫素轉換到[-1,1]區間內

    distorted_image = tf.subtract(distorted_image, 0.5)#減去0.5
    distorted_image = tf.multiply(distorted_image, 2.0)#乘以2

第一步將圖片轉換到[0,1]區間內,通過除以255,通過最後兩步可以將圖片轉換到[-1,1]區間內。將圖片轉換到[0,1]區間,其實就是做了一個0(最小值)/255(最大值)的變化,將圖片轉換到[-1,1]區間相當於做了128(最小值)/128(最大值),實驗證明[-1,1]區間比[0,1]的準確率要高一些。

注意:對圖片的畫素區間做了轉換之後,在預測圖片類標的時候,也需要對圖片的畫素區間進行相同的縮放。

5、測試

import tensorflow as tf
from inception_preprocessing import preprocess_image
import matplotlib.pyplot as plt
plt.rcParams["font.sans-serif"]=["SimHei"]

if __name__ == "__main__":
    #將圖片轉換為bytes資料
    image_raw = tf.gfile.FastGFile("cat.4.jpg",mode="rb").read()
    #將圖片轉為tensor
    img = tf.image.decode_jpeg(image_raw)
    with tf.Session() as sess:
        plt.subplot(221)
        plt.imshow(sess.run(img))
        plt.title("原圖")
        for i in range(3):
            #進行圖片預處理的時候,需要指定圖片的大小
            pre_img = preprocess_image(img,227,227,is_training=True,fast_mode=True)
            plt.subplot(2,2,i+2)
            #將圖片的畫素轉換到[0,255]內
            pre_img = tf.convert_to_tensor(sess.run(pre_img)*128+128,dtype=tf.uint8)
            plt.imshow(sess.run(pre_img))
            plt.title("預處理圖片%d"%(i+1))
        plt.show()