1. 程式人生 > >基於TensorFlow的Cats vs. Dogs(貓狗大戰)實現和詳解(1)

基於TensorFlow的Cats vs. Dogs(貓狗大戰)實現和詳解(1)

2017.5.29

  官方的MNIST例子裡面訓練資料的下載和匯入都是用已經寫好的指令碼完成的,至於裡面實現細節也沒高興去看原始碼,感覺寫得太正式,我這個初學者不好理解。於是在優酷上找到了KevinRush這麼一個播主,裡面的視訊教程講得挺清晰的,於是跟著視訊做了一個貓狗大戰的影象識別程式。

一、貓狗大戰資料集

  Cats vs. Dogs(貓狗大戰)是Kaggle大資料競賽某一年的一道賽題,利用給定的資料集,用演算法實現貓和狗的識別。
  資料集可以從Kaggle官網上下載:

Kaggle官網

  資料集由訓練資料和測試資料組成,訓練資料包含貓和狗各12500張圖片,測試資料包含12500張貓和狗的圖片。
  
訓練資料

  為了以後查閱時不用翻視訊(優酷廣告真心長=.=),這裡把視訊裡的內容重寫一下,也當做是複習。

二、TensorFlow的實現

  我電腦配的環境是win10(64位) + Python3.5.3 + CUDA 8.0 + cudnn 5.1 + tensorflow-gpu 1.1.0 + Pycharm。
  首先在Pycharm上新建Cats_vs_Dogs工程,工程目錄結構為:
這裡寫圖片描述

  • data資料夾下包含testtrain兩個子資料夾,分別用於存放測試資料和訓練資料,從官網上下載的資料直接解壓到相應的資料夾下即可
  • logs資料夾用於存放我們訓練時的模型結構以及訓練引數
  • input_data.py
    負責實現讀取資料,生成批次(batch)
  • model.py負責實現我們的神經網路模型
  • training.py負責實現模型的訓練以及評估

接下來分成資料讀取、模型構造、模型訓練、測試模型四個部分來講。原始碼從文章末尾的連結下載。

1. 訓練資料的讀取——input_data.py

import tensorflow as tf
import numpy as np
import os

  首先是匯入模組。
  tensorflow和numpy不用多說,其中os模組包含作業系統相關的功能,可以處理檔案和目錄這些我們日常手動需要做的操作。因為我們需要獲取test目錄下的檔案,所以要匯入os模組。

# 獲取檔案路徑和標籤
def get_files(file_dir):
    # file_dir: 資料夾路徑
    # return: 亂序後的圖片和標籤

    cats = []
    label_cats = []
    dogs = []
    label_dogs = []
    # 載入資料路徑並寫入標籤值
    for file in os.listdir(file_dir):
        name = file.split(sep='.')
        if name[0] == 'cat':
            cats.append(file_dir + file)
            label_cats.append(0)
        else:
            dogs.append(file_dir + file)
            label_dogs.append(1)
    print("There are %d cats\nThere are %d dogs" % (len(cats), len(dogs)))

    # 打亂檔案順序
    image_list = np.hstack((cats, dogs))
    label_list = np.hstack((label_cats, label_dogs))
    temp = np.array([image_list, label_list])
    temp = temp.transpose()     # 轉置
    np.random.shuffle(temp)

    image_list = list(temp[:, 0])
    label_list = list(temp[:, 1])
    label_list = [int(i) for i in label_list]

    return image_list, label_list

  函式get_files(file_dir)的功能是獲取給定路徑file_dir下的所有的訓練資料(包括圖片和標籤),以list的形式返回。
  由於訓練資料前12500張是貓,後12500張是狗,如果直接按這個順序訓練,訓練效果可能會受影響(我自己猜的),所以需要將順序打亂,至於是讀取資料的時候亂序還是訓練的時候亂序可以自己選擇(視訊裡說在這裡亂序速度比較快)。因為圖片和標籤是一一對應的,所以要整合到一起亂序。
  這裡先用np.hstack()方法將貓和狗圖片和標籤整合到一起,得到image_listlabel_listhstack((a,b))的功能是將a和b以水平的方式連線,比如原來catsdogs是長度為12500的向量,執行了hstack(cats, dogs)後,image_list的長度為25000,同理label_list的長度也為25000。接著將一一對應的image_listlabel_list再合併一次。temp的大小是2×25000,經過轉置(變成25000×2),然後使用np.random.shuffle()方法進行亂序。
  最後從temp中分別取出亂序後的image_listlabel_list列向量,作為函式的返回值。這裡要注意,因為label_list裡面的資料型別是字串型別,所以加上label_list = [int(i) for i in label_list]這麼一行將其轉為int型別。
  

# 生成相同大小的批次
def get_batch(image, label, image_W, image_H, batch_size, capacity):
    # image, label: 要生成batch的影象和標籤list
    # image_W, image_H: 圖片的寬高
    # batch_size: 每個batch有多少張圖片
    # capacity: 佇列容量
    # return: 影象和標籤的batch

    # 將python.list型別轉換成tf能夠識別的格式
    image = tf.cast(image, tf.string)
    label = tf.cast(label, tf.int32)

    # 生成佇列
    input_queue = tf.train.slice_input_producer([image, label])

    image_contents = tf.read_file(input_queue[0])
    label = input_queue[1]
    image = tf.image.decode_jpeg(image_contents, channels=3)

    # 統一圖片大小
    # 視訊方法
    # image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
    # 我的方法
    image = tf.image.resize_images(image, [image_H, image_W], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    image = tf.cast(image, tf.float32)
    # image = tf.image.per_image_standardization(image)   # 標準化資料
    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              num_threads=64,   # 執行緒
                                              capacity=capacity)

    # 這行多餘?
    # label_batch = tf.reshape(label_batch, [batch_size])

    return image_batch, label_batch

  函式get_batch()用於將圖片分批次,因為一次性將所有25000張圖片載入記憶體不現實也不必要,所以將圖片分成不同批次進行訓練。這裡傳入的imagelabel引數就是函式get_files()返回的image_listlabel_list,是python中的list型別,所以需要將其轉為TensorFlow可以識別的tensor格式。
  這裡使用佇列來獲取資料,因為佇列操作牽扯到執行緒,我自己對這塊也不懂,,所以只從大體上理解了一下,想要系統學習可以去官方文件看看,這裡引用了一張圖解釋。
  
佇列

  我認為大體上可以這麼理解:每次訓練時,從佇列中取一個batch送到網路進行訓練,然後又有新的圖片從訓練庫中注入佇列,這樣迴圈往復。佇列相當於起到了訓練庫到網路模型間資料管道的作用,訓練資料通過佇列送入網路。(我也不確定這麼理解對不對,歡迎指正)

  繼續看程式,我們使用slice_input_producer()來建立一個佇列,將imagelabel放入一個list中當做引數傳給該函式。然後從佇列中取得imagelabel,要注意,用read_file()讀取圖片之後,要按照圖片格式進行解碼。本例程中訓練資料是jpg格式的,所以使用decode_jpeg()解碼器,如果是其他格式,就要用其他解碼器,具體可以從官方API中查詢。注意decode出來的資料型別是uint8,之後模型卷積層裡面conv2d()要求輸入資料為float32型別,所以如果刪掉標準化步驟之後需要進行型別轉換。

  因為訓練庫中圖片大小是不一樣的,所以還需要將圖片裁剪成相同大小(img_Wimg_H)。視訊中是用resize_image_with_crop_or_pad()方法來裁剪圖片,這種方法是從影象中心向四周裁剪,如果圖片超過規定尺寸,最後只會剩中間區域的一部分,可能一隻狗只剩下軀幹,頭都不見了,用這樣的圖片訓練結果肯定會受到影響。所以這裡我稍微改動了一下,使用resize_images()對影象進行縮放,而不是裁剪,採用NEAREST_NEIGHBOR插值方法(其他幾種插值方法出來的結果影象是花的,具體原因不知道)。

  縮放之後視訊中還進行了per_image_standardization (標準化)步驟,但加了這步之後,得到的圖片是花的,雖然各個通道單獨提出來是正常的,三通道一起就不對了,刪了標準化這步結果正常,所以這裡把標準化步驟註釋掉了。

  然後用tf.train.batch()方法獲取batch,還有一種方法是tf.train.shuffle_batch(),因為之前我們已經亂序過了,這裡用普通的batch()就好。視訊中獲取batch後還對label進行了一下reshape()操作,在我看來這步是多餘的,從batch()方法中獲取的大小已經符合我們的要求了,註釋掉也沒什麼影響,能正常獲取圖片。

  最後將得到的image_batchlabel_batch返回。image_batch是一個4D的tensor,[batch, width, height, channels],label_batch是一個1D的tensor,[batch]。

  可以用下面的程式碼測試獲取圖片是否成功,因為之前將圖片轉為float32了,因此這裡imshow()出來的圖片色彩會有點奇怪,因為本來imshow()是顯示uint8型別的資料(灰度值在uint8型別下是0~255,轉為float32後會超出這個範圍,所以色彩有點奇怪),不過這不影響後面模型的訓練。

# TEST
import matplotlib.pyplot as plt

BATCH_SIZE = 2
CAPACITY = 256
IMG_W = 208
IMG_H = 208

train_dir = "data\\train\\"
image_list, label_list = get_files(train_dir)
image_batch, label_batch = get_batch(image_list, label_list, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)

with tf.Session() as sess:
    i = 0
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        while not coord.should_stop() and i < 1:
            img, label = sess.run([image_batch, label_batch])

            for j in np.arange(BATCH_SIZE):
                print("label: %d" % label[j])
                plt.imshow(img[j, :, :, :])
                plt.show()
            i += 1
    except tf.errors.OutOfRangeError:
        print("done!")
    finally:
        coord.request_stop()
    coord.join(threads)

  鑑於篇幅原因,其他部分見下一篇部落格。

參考