1. 程式人生 > >【tensorflow入門教程二】資料集製作:使用TFRecords製作資料集並使用inceptionv3進行訓練

【tensorflow入門教程二】資料集製作:使用TFRecords製作資料集並使用inceptionv3進行訓練

這篇文章中,我們將探討深度學習中最基本的問題,影象分類中的資料集以及標籤的製作;以及使用Inceptionv3網路對其進行訓練。

PS:文末附博文配套程式碼以及資料集原圖的下載。

先上一張最後的訓練結果圖:

17flowers資料集

17flowers資料集包含有17種不同的花的圖片,每個種類的花都含有80張圖片,圖片的尺寸不唯一,但是都在500x500左右,所有這些一共組成了1360張圖片。該篇博文要做的就是使用tensorflow將其做成tfrecords格式的資料集檔案。


製作TFRecords資料集

首先定位到我們的原圖片的目錄,並使用陣列儲存所有類別。

cwd = 'D:\py project/tensorflow-tfrecord\jpg\\'
classes = {'daffodil', 'snowdrop', 'lilyvalley', 'bluebell', 'crocus', 'iris', 'tigerlily', 'tulip', 'fritiuary',
           'sunflower', 'daisy', 'coltsfoot', 'dandelion', 'cowslip', 'buttercup', 'windflower', 'pansy'}  # 花為 設定 17 類

這裡使用tf.python_io.TFRecordWriter的方式將所有圖片資料寫入到tfrecords檔案。

def createdata():
    filename="flower_train.tfrecords"      #要生成的檔名以及地址,不指定絕對地址的話就是在建立在工程目錄下
    writer = tf.python_io.TFRecordWriter(filename)  # 使用該函式建立一個tfrecord檔案
    height = 299    #將圖片儲存成為299x299的尺寸,方便進行之後的訓練
    width = 299
    for index, name in enumerate(classes):    #index即為花的類別的索引,若當前值index=0, name= 'corslip',則在標籤y=0時即表示這張圖屬於corslip
        class_path = cwd + name + '\\'    #定位到每一個花的類別目錄
        for img_name in os.listdir(class_path): # 以list的方式顯示目錄下的各個資料夾
            img_path = class_path + img_name  # 每一個圖片的地址
            img = Image.open(img_path)    # 匯入Image 包,開啟圖片
            img = img.resize((height, width))    
            img_raw = img.tobytes()  # 將圖片轉化為二進位制格式
            example = tf.train.Example(features=tf.train.Features(feature={    #寫的時候標籤類的資料形式為int64,圖片類的資料形式為Bytes
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 
                'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
            }))  # example物件對label和image資料進行封裝
            writer.write(example.SerializeToString())  # 序列化為字串
    writer.close()

執行上述程式碼之後我們將在當前工程的目錄下得到一個flower_train.tfrecords檔案。

接下來對我們的Tfrecords檔案進行讀取並解析成能使用的資料。

要對tfrecords檔案進行讀取,首先需要使用tf.train.string_input_producer建立一個佇列,並使用tf.TFRecordReader()讀取tfrecords檔案;之後使用 pasr_single_example對序列化的資料解析。

def read_and_decode(filename, batch_size):  # 讀取tfrecords
    filename_queue = tf.train.string_input_producer([filename])  # 生成一個queue佇列
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)  # 返回檔名和檔案
    features = tf.parse_single_example(serialized_example,  #對序列化資料進行解析
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw': tf.FixedLenFeature([], tf.string),
                                       })  # 將image資料和label取出來

    img = tf.decode_raw(features['img_raw'], tf.uint8) #將圖片解析成uint8格式的資料
    img = tf.reshape(img, [299, 299, 3])  # 解碼後需要reshape為299*299的3通道圖片
    img = tf.cast(img, tf.float32) * (1. / 255)  # 將tensor資料轉化為float32格式,後面的*(1./255)是必須的,不然生成的圖片會反相。
    label = tf.cast(features['label'], tf.int64)  # 將label標籤轉化為int64格式
    label = tf.one_hot(label, 17)   #對標籤做one hot處理:假如共有4個類,若標籤為3,做one hot之後則為[0 0 0 1],若標籤為0,則[1 0 0 0]
    # img_batch, label_batch = tf.train.batch([img,label],batch_size,1,50)    #按序輸出
    img_batch, label_batch = tf.train.shuffle_batch([img,label],batch_size,500,100)     #打亂排序輸出batch
    return img_batch, label_batch
要讀取每個record是固定位元組數的二進位制檔案,需要tf.FixedLengthRecordReader與tf.decode_raw操作一起使用。該decode_raw操作從string型別轉換為uint8 tensor。其中的tf.train.shuffle_batch定義如下:
def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
                  num_threads=1, seed=None, enqueue_many=False, shapes=None,
                  allow_smaller_final_batch=False, shared_name=None, name=None):
batch_size為佇列一次輸出的資料大小,capactiy為佇列中儲存的最大資料數量,min_after_dequeue為出隊後佇列中的元素最小數量。其中capacity的值須大於min_after_dequeue。num_threads為該函式執行的執行緒數,即使用幾個執行緒從佇列中取資料。

測試一下以上程式碼能不能再次讀取我們的圖片:

if __name__ == "__main__":
    createdata()
    init_op = tf.global_variables_initializer()
    image, label = read_and_decode("flower_train.tfrecords", 32)    #該處得到的為tensor,需要sess.run才能得到實際的資料
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()        #從佇列中取資料需要先建立一個Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)    #並建立執行緒開始從佇列中讀取資料
        for i in range(32):
            example, l = sess.run([image, label])  # 取出image和label
            plt.imshow(example[i, :, :, :])
            plt.show()
            print(l[i])
            print(example.shape)
        coord.request_stop()    #結束佇列
        coord.join(threads)

在print處打個斷點,可以看到如下結果:



至此,資料集的製作及解析便處理完畢。

接下來使用inceptionv3網路對其進行訓練。inception v3的網路結構及構建方法請檢視我之前的部落格(即開頭給的連結),這裡給出主體函式部分的程式碼(loss的定義以及optimizer的定義):


其得到的最終訓練結果如下所示:


本篇文章的配套程式碼請點選下面的連結進行下載:

本篇博文大致如此,下篇文章見