1. 程式人生 > >TensorFlow 製作自己的TFRecord資料集 讀取、顯示及程式碼詳解

TensorFlow 製作自己的TFRecord資料集 讀取、顯示及程式碼詳解

準備圖片資料

筆者找了2類狗的圖片, 哈士奇和吉娃娃, 全部 resize成128 * 128大小
如下圖, 儲存地址為D:\Python\data\dog

每類中有10張圖片

 

現在利用這2 類 20張圖片製作TFRecord檔案

製作TFRECORD檔案

1 先聊一下tfrecord, 這是一種將影象資料和標籤放在一起的二進位制檔案,能更好的利用記憶體,在tensorflow中快速的複製,移動,讀取,儲存 等等..

這裡注意,tfrecord會根據你選擇輸入檔案的類,自動給每一類打上同樣的標籤
如在本例中,只有0,1 兩類

2 先上“製作TFRecord檔案”的程式碼,註釋附詳解


import os 
import tensorflow as tf 
from PIL import Image  #注意Image,後面會用到
import matplotlib.pyplot as plt 
import numpy as np
 
cwd='D:\Python\data\dog\\'
classes={'husky','chihuahua'} #人為 設定 2 類
writer= tf.python_io.TFRecordWriter("dog_train.tfrecords") #要生成的檔案
 
for index,name in enumerate(classes):
    class_path=cwd+name+'\\'
    for img_name in os.listdir(class_path): 
        img_path=class_path+img_name #每一個圖片的地址
 
        img=Image.open(img_path)
        img= img.resize((128,128))
        img_raw=img.tobytes()#將圖片轉化為二進位制格式
        example = tf.train.Example(features=tf.train.Features(feature={
            "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
        })) #example物件對label和image資料進行封裝
        writer.write(example.SerializeToString())  #序列化為字串
 
writer.close()

執行完這段程式碼後,會生成dog_train.tfrecords 檔案,如下圖

tf.train.Example 協議記憶體塊包含了Features欄位,通過feature將圖片的二進位制資料和label進行統一封裝, 然後將example協議記憶體塊轉化為字串, tf.python_io.TFRecordWriter 寫入到TFRecords檔案中。

 

讀取TFRECORD檔案

在製作完tfrecord檔案後, 將該檔案讀入到資料流中。
程式碼如下

def read_and_decode(filename): # 讀入dog_train.tfrecords
    filename_queue = tf.train.string_input_producer([filename])#生成一個queue佇列
 
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)#返回檔名和檔案
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })#將image資料和label取出來
 
    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [128, 128, 3])  #reshape為128*128的3通道圖片
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #在流中丟擲img張量
    label = tf.cast(features['label'], tf.int32) #在流中丟擲label張量
    return img, label

注意,feature的屬性“label”和“img_raw”名稱要和製作時統一 ,返回的img資料和label資料一一對應。返回的img和label是2個 tf 張量,print出來 如下圖

 

顯示tfrecord格式的圖片

有些時候我們希望檢查分類是否有誤,或者在之後的網路訓練過程中可以監視,輸出圖片,來觀察分類等操作的結果,那麼我們就可以session回話中,將tfrecord的圖片從流中讀取出來,再儲存。 緊跟著一開始的程式碼寫:

filename_queue = tf.train.string_input_producer(["dog_train.tfrecords"]) #讀入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)   #返回檔名和檔案
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                   })  #取出包含image和label的feature物件
image = tf.decode_raw(features['img_raw'], tf.uint8)
image = tf.reshape(image, [128, 128, 3])
label = tf.cast(features['label'], tf.int32)
with tf.Session() as sess: #開始一個會話
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    coord=tf.train.Coordinator()
    threads= tf.train.start_queue_runners(coord=coord)
    for i in range(20):
        example, l = sess.run([image,label])#在會話中取出image和label
        img=Image.fromarray(example, 'RGB')#這裡Image是之前提到的
        img.save(cwd+str(i)+'_''Label_'+str(l)+'.jpg')#存下圖片
        print(example, l)
    coord.request_stop()
    coord.join(threads)

程式碼執行完後, 從tfrecord中取出的檔案被儲存了。如下圖:

 

在這裡我們可以看到,圖片檔名的第一個數字表示在流中的順序(這裡沒有用shuffle), 第二個數字則是 每個圖片的label,吉娃娃都為0,哈士奇都為1。 由此可見,我們一開始製作tfrecord檔案時,圖片分類正確。