1. 程式人生 > >【TensorFlow】TFRecord資料集的製作:讀取、顯示及程式碼詳解

【TensorFlow】TFRecord資料集的製作:讀取、顯示及程式碼詳解

在跑通了官網的mnist和cifar10資料之後,筆者嘗試著製作自己的資料集,並儲存,讀入,顯示。 TensorFlow可以支援cifar10的資料格式, 也提供了標準的TFRecord 格式。

 tensorflow 讀取資料, 官網提供了以下三種方法:

1 Feeding: 在tensorflow程式執行的每一步, 用python程式碼線上提供資料;
2 Reader : 在一個計算圖(tf.graph)的開始前,將檔案讀入到流(queue)中;
3 在宣告tf.variable變數或numpy陣列時儲存資料。受限於記憶體大小,適用於資料較小的情況;

在本文,主要介紹第二種方法,利用tf.record標準介面來讀入檔案

準備圖片資料

筆者找了2類狗的圖片, 哈士奇和吉娃娃, 全部 resize成128 * 128大小
如下圖, 儲存地址為D:\Python\data\dog

每類中有10張圖片

現在利用這2 類 20張圖片製作TFRecord檔案

製作TFRECORD檔案

1 先聊一下tfrecord, 這是一種將影象資料和標籤放在一起的二進位制檔案,能更好的利用記憶體,在tensorflow中快速的複製,移動,讀取,儲存 等等..

這裡注意,tfrecord會根據你選擇輸入檔案的類,自動給每一類打上同樣的標籤
如在本例中,只有0,1 兩類

2 先上“製作TFRecord檔案”的程式碼,註釋附詳解

import os

import tensorflow as tf

from PIL import Image  #注意Image,後面會用到

import matplotlib.pyplot as plt

import numpy as np



cwd='D:\Python\data\dog\\'

classes={'husky','chihuahua'} #人為 設定 2 類

writer= tf.python_io.TFRecordWriter("dog_train.tfrecords") #要生成的檔案



for index,name in enumerate(classes):

    class_path=cwd+name+'\\'

    for img_name in os.listdir(class_path):

        img_path=class_path+img_name #每一個圖片的地址



        img=Image.open(img_path)

        img= img.resize((128,128))

        img_raw=img.tobytes()#將圖片轉化為二進位制格式

        example = tf.train.Example(features=tf.train.Features(feature={

            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),

            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))

        })) #example物件對label和image資料進行封裝

        writer.write(example.SerializeToString())  #序列化為字串



writer.close()


tf.train.Example 協議記憶體塊包含了Features欄位,通過feature將圖片的二進位制資料和label進行統一封裝, 然後將example協議記憶體塊轉化為字串, tf.python_io.TFRecordWriter 寫入到TFRecords檔案中。執行完這段程式碼後,會生成dog_train.tfrecords 檔案,如下圖

讀取TFRECORD檔案

在製作完tfrecord檔案後, 將該檔案讀入到資料流中。
程式碼如下

def read_and_decode(filename): # 讀入dog_train.tfrecords

    filename_queue = tf.train.string_input_producer([filename])#生成一個queue佇列



    reader = tf.TFRecordReader()

    _, serialized_example = reader.read(filename_queue)#返回檔名和檔案

    features = tf.parse_single_example(serialized_example,

                                       features={

                                           'label': tf.FixedLenFeature([], tf.int64),

                                           'img_raw' : tf.FixedLenFeature([], tf.string),

                                       })#將image資料和label取出來



    img = tf.decode_raw(features['img_raw'], tf.uint8)

    img = tf.reshape(img, [128, 128, 3])  #reshape為128*128的3通道圖片

    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #在流中丟擲img張量

    label = tf.cast(features['label'], tf.int32) #在流中丟擲label張量

    return img, label

注意,feature的屬性“label”和“img_raw”名稱要和製作時統一 ,返回的img資料和label資料一一對應。返回的img和label是2個 tf 張量,print出來 如下圖

顯示tfrecord格式的圖片

有些時候我們希望檢查分類是否有誤,或者在之後的網路訓練過程中可以監視,輸出圖片,來觀察分類等操作的結果,那麼我們就可以session回話中,將tfrecord的圖片從流中讀取出來,再儲存。 緊跟著一開始的程式碼寫:

filename_queue = tf.train.string_input_producer(["dog_train.tfrecords"]) #讀入流中

reader = tf.TFRecordReader()

_, serialized_example = reader.read(filename_queue)   #返回檔名和檔案

features = tf.parse_single_example(serialized_example,

                                   features={

                                       'label': tf.FixedLenFeature([], tf.int64),

                                       'img_raw' : tf.FixedLenFeature([], tf.string),

                                   })  #取出包含image和label的feature物件

image = tf.decode_raw(features['img_raw'], tf.uint8)

image = tf.reshape(image, [128, 128, 3])

label = tf.cast(features['label'], tf.int32)

with tf.Session() as sess: #開始一個會話

    init_op = tf.initialize_all_variables()

    sess.run(init_op)

    coord=tf.train.Coordinator()

    threads= tf.train.start_queue_runners(coord=coord)

    for i in range(20):

        example, l = sess.run([image,label])#在會話中取出image和label

        img=Image.fromarray(example, 'RGB')#這裡Image是之前提到的

        img.save(cwd+str(i)+'_''Label_'+str(l)+'.jpg')#存下圖片

        print(example, l)

    coord.request_stop()

    coord.join(threads)

程式碼執行完後, 從tfrecord中取出的檔案被儲存了。如下圖:

在這裡我們可以看到,圖片檔名的第一個數字表示在流中的順序(這裡沒有用shuffle), 第二個數字則是 每個圖片的label,吉娃娃都為0,哈士奇都為1。 由此可見,我們一開始製作tfrecord檔案時,圖片分類正確。

轉自:https://www.2cto.com/kf/201702/604326.html