1. 程式人生 > >TensorFlow TFRecord資料集的生成與顯示

TensorFlow TFRecord資料集的生成與顯示

TFRecord

  TensorFlow提供了TFRecord的格式來統一儲存資料,TFRecord格式是一種將影象資料和標籤放在一起的二進位制檔案,能更好的利用記憶體,在tensorflow中快速的複製,移動,讀取,儲存 等等。
  TFRecords檔案包含了tf.train.Example 協議記憶體塊(protocol buffer)(協議記憶體塊包含了欄位 Features)。我們可以寫一段程式碼獲取你的資料, 將資料填入到Example協議記憶體塊(protocol buffer),將協議記憶體塊序列化為一個字串, 並且通過tf.python_io.TFRecordWriter 寫入到TFRecords檔案。
從TFRecords檔案中讀取資料, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。這個操作可以將Example協議記憶體塊(protocol buffer)解析為Tensor。

Image to TFRecord

  首先我們使用TensorFlow提供的Flowers資料集做這個實驗,資料集在我本地的路徑為:

這裡寫圖片描述
這是一個五分類的資料,以類別的形式組織資料,這非常符合我們自己組織資料集的習慣。其中一個分類中大概有700張左右的圖片:
這裡寫圖片描述

現在我們就把上面的資料製作出TFRecord,在這裡需要說明下,TFRecord的生成要注意兩點:
1.很多時候,我們的圖片尺寸並不是統一的,所以在生成的TFRecord中需要包含影象的width和height這兩個資訊,這樣在解析圖片的時候,我們才能把二進位制的資料重新reshape成圖片;
2.TensorFlow官方的建議是一個TFRecord中最好圖片的數量為1000張左右,這個很好理解,如果我們有上萬張圖片,卻只打成一個包,這樣是很不利於多執行緒讀取的。所以我們需要根據影象資料自動去選擇到底打包幾個TFRecord出來。

我們可以用下面的程式碼實現這兩個目的:

import os 
import tensorflow as tf 
from PIL import Image  

#圖片路徑
cwd = 'F:\\flowersdata\\trainimages\\'
#檔案路徑
filepath = 'F:\\flowersdata\\tfrecord\\'
#存放圖片個數
bestnum = 1000
#第幾個圖片
num = 0
#第幾個TFRecord檔案
recordfilenum = 0
#類別
classes=['daisy',
         'dandelion',
         'roses'
, 'sunflowers', 'tulips'] #tfrecords格式檔名 ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum) writer= tf.python_io.TFRecordWriter(filepath+ftrecordfilename) #類別和路徑 for index,name in enumerate(classes): print(index) print(name) class_path=cwd+name+'\\' for img_name in os.listdir(class_path): num=num+1 if num>bestnum: num = 1 recordfilenum = recordfilenum + 1 #tfrecords格式檔名 ftrecordfilename = ("traindata.tfrecords-%.3d" % recordfilenum) writer= tf.python_io.TFRecordWriter(filepath+ftrecordfilename) #print('路徑',class_path) #print('第幾個圖片:',num) #print('檔案的個數',recordfilenum) #print('圖片名:',img_name) img_path = class_path+img_name #每一個圖片的地址 img=Image.open(img_path,'r') size = img.size print(size[1],size[0]) print(size) #print(img.mode) img_raw=img.tobytes()#將圖片轉化為二進位制格式 example = tf.train.Example( features=tf.train.Features(feature={ 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), 'img_width':tf.train.Feature(int64_list=tf.train.Int64List(value=[size[0]])), 'img_height':tf.train.Feature(int64_list=tf.train.Int64List(value=[size[1]])) })) writer.write(example.SerializeToString()) #序列化為字串 writer.close()

在上面的程式碼中,我們規定了一個TFRecord中只放1000張圖:

bestnum = 1000

並且將一張圖的4個資訊打包到TFRecord中,分別是:

example = tf.train.Example(
             features=tf.train.Features(feature={
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            'img_width':tf.train.Feature(int64_list=tf.train.Int64List(value=[size[0]])),
            'img_height':tf.train.Feature(int64_list=tf.train.Int64List(value=[size[1]]))
        })) 

這裡寫圖片描述

TFRecord to Image

在上面我們打包了四個TFRecord檔案,下面我們把這些資料讀取並顯示出來,看看製作的效果,這個過程很大一部分是和TensorFlow組織batch是一樣的了。

import tensorflow as tf 
from PIL import Image  
import matplotlib.pyplot as plt

#寫入圖片路徑
swd = 'F:\\flowersdata\\show\\'
#TFRecord檔案路徑
data_path = 'F:\\flowersdata\\tfrecord\\traindata.tfrecords-003'
# 獲取檔名列表
data_files = tf.gfile.Glob(data_path)
print(data_files)
# 檔名列表生成器

filename_queue = tf.train.string_input_producer(data_files,shuffle=True) 
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)   #返回檔名和檔案
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([], tf.int64),
                                       'img_raw' : tf.FixedLenFeature([], tf.string),
                                       'img_width': tf.FixedLenFeature([], tf.int64),
                                       'img_height': tf.FixedLenFeature([], tf.int64),
                                   })  #取出包含image和label的feature物件
#tf.decode_raw可以將字串解析成影象對應的畫素陣列
image = tf.decode_raw(features['img_raw'], tf.uint8)
height = tf.cast(features['img_height'],tf.int32)
width = tf.cast(features['img_width'],tf.int32)
label = tf.cast(features['label'], tf.int32)
channel = 3
image = tf.reshape(image, [height,width,channel])


with tf.Session() as sess: #開始一個會話
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    #啟動多執行緒
    coord=tf.train.Coordinator()
    threads= tf.train.start_queue_runners(coord=coord)
    for i in range(15):
        #image_down = np.asarray(image_down.eval(), dtype='uint8')
        plt.imshow(image.eval())
        plt.show()
        single,l = sess.run([image,label])#在會話中取出image和label
        img=Image.fromarray(single, 'RGB')#這裡Image是之前提到的
        img.save(swd+str(i)+'_''Label_'+str(l)+'.jpg')#存下圖片
        #print(single,l)
    coord.request_stop()
    coord.join(threads)

注意:
1.我們在使用reshape去將二進位制資料重新變成圖片的時候,用的就是之前打包進去的width和height,否則程式會出錯;

image = tf.reshape(image, [height,width,channel])

2.在圖片儲存時的命名方式為:mun_Label_calss id

這裡寫圖片描述

3.程式碼也可以實時show出當前的圖片:

這裡寫圖片描述

完整程式碼也可以點選這裡下載。