Tensorflow中使用tfrecord,佇列方式讀取資料
標準TensorFlow格式
有一種儲存記錄的方法可以允許你講任意的資料轉換為TensorFlow所支援的格式, 這種方法可以使TensorFlow的資料集更容易與網路應用架構相匹配。這種建議的方法就是使用TFRecords檔案,TFRecords檔案包含了tf.train.Example 協議記憶體塊(protocol buffer)(協議記憶體塊包含了欄位 Features)。你可以寫一段程式碼獲取你的資料, 將資料填入到Example協議記憶體塊(protocol buffer),將協議記憶體塊序列化為一個字串, 並且通過tf.python_io.TFRecordWriter class寫入到TFRecords檔案。從TFRecords檔案中讀取資料, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。這個parse_single_example操作可以將Example協議記憶體塊(protocol buffer)解析為張量。
tfecord檔案中的資料是通過tf.train.Example Protocol Buffer的格式儲存的,下面是tf.train.Example的定義:
message Example {
Features features = 1;
};
message Features{
map<string,Feature> featrue = 1;
};
message Feature{
oneof kind{
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
從上述程式碼可以看到,ft.train.Example 的資料結構相對簡潔。tf.train.Example中包含了一個從屬性名稱到取值的字典,其中屬性名稱為一個字串,屬性的取值可以為字串(BytesList ),實數列表(FloatList )或整數列表(Int64List )。例如我們可以將解碼前的圖片作為字串,影象對應的類別標號作為整數列表。
測試例子
使用queue讀取圖片資料方法的大致思路分為三步:
1、根據資料集的具體儲存情況生成一個txt清單,清單上記載了每一張圖片的儲存地址還有一些相關資訊(如標籤、大小之類的)
2、根據第一步的清單記錄,讀取資料和資訊,並將這些資料和資訊按照一定的格式寫成Tensorflow的專用檔案格式(.tfrecords)
3、從.tfrecords檔案中批量的讀取資料供給模型使用
資料清單的生成
根據資料的儲存情況生成的資料清單,不同的情況寫的程式碼肯定也是不一樣的,這裡根據我的具體情況說一下過程和程式我的資料儲存地址為:/Users/zhuxiaoxiansheng/Desktop/doc/SICA_data/YaleB
具體情況如下:
這裡第一張圖片的的Class01表示的是第一個類別,00000表示的是第一個類別裡的第一張,生成清單的程式如下:
##相關庫函式匯入
import os
import cv2 as cv
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
def getTrianList():
root_dir = "/Users/zhuxiaoxiansheng/Desktop/doc/SICA_data/YaleB" #資料儲存資料夾地址
with open('/Users/zhuxiaoxiansheng/Desktop'+"/Yaledata.txt","w") as f: #txt檔案生成地址
for file in os.listdir(root_dir):
if len(file) == 23: #圖片名長為23個位元組,避免讀入其他的檔案
f.write(root_dir+'/'+file+" "+ file[11:13] +"\n") #file[11:13]表示類別編號
生成的清單檔案是這樣的
生成tfrecords檔案
在得到txt清單檔案以後,根據這份檔案就可以進入流程式的步驟了,首先我們需要生成.tfrecords檔案,程式碼如下
def load_file(example_list_file): #從清單中讀取地址和類別編號,這裡的輸入是清單儲存地址
lines = np.genfromtxt(example_list_file,delimiter=" ",dtype=[('col1', 'S120'), ('col2', 'i8')])
examples = []
labels = []
for example,label in lines:
examples.append(example)
labels.append(label)
return np.asarray(examples),np.asarray(labels),len(lines)
def trans2tfRecord(trainFile,output_dir): #生成tfrecords檔案
_examples,_labels,examples_num = load_file(trainFile)
filename = output_dir + '.tfrecords'
writer = tf.python_io.TFRecordWriter(filename)
for i,[example,label] in enumerate(zip(_examples,_labels)):
example = example.decode("UTF-8")
image = cv.imread(example)
image = cv.resize(image,(192,168)) #這裡的格式需要注意,一定要儘量保證圖片的大小一致
image_raw = image.tostring() #將圖片矩陣轉化為字串格式
example = tf.train.Example(features=tf.train.Features(feature={
'image_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
}))
writer.write(example.SerializeToString())
writer.close() #寫入完成,關閉指標
return filename #返回檔案地址
這裡生成的是.tfrecords不好開啟,就不展示了
從tfrecords檔案中讀取資料
設定從tfrecords檔案中讀取檔案方式的函式如下:
def read_tfRecord(file_tfRecord): #輸入是.tfrecords檔案地址
queue = tf.train.string_input_producer([file_tfRecord])
reader = tf.TFRecordReader()
_,serialized_example = reader.read(queue)
features = tf.parse_single_example(
serialized_example,
features={
'image_raw':tf.FixedLenFeature([], tf.string),
'label':tf.FixedLenFeature([], tf.int64)
}
)
image = tf.decode_raw(features['image_raw'],tf.uint8)
image = tf.reshape(image,[192,168,3])
image = tf.cast(image, tf.float32)
image = tf.image.per_image_standardization(image)
label = tf.cast(features['label'], tf.int64) 這裡設定了讀取資訊的格式
return image,label
測試程式碼
上面就是主要的程式碼,其中特別要注意的就是以下兩句,非常重要:
coord=tf.train.Coordinator() #建立一個協調器,管理執行緒
threads=tf.train.start_queue_runners(coord=coord) #啟動QueueRunner, 此時檔名佇列已經進隊
這兩句實現的功能就是建立執行緒並使用QueueRunner物件來提取資料。簡單來說:使用tf.train函式新增QueueRunner到tensorflow中。在執行任何訓練步驟之前,需要呼叫tf.train.start_queue_runners函式,否則tensorflow將一直掛起。
tf.train.start_queue_runners 這個函式將會啟動輸入管道的執行緒,填充樣本到佇列中,以便出隊操作可以從佇列中拿到樣本。這種情況下最好配合使用一個tf.train.Coordinator,這樣可以在發生錯誤的情況下正確地關閉這些執行緒。如果你對訓練迭代數做了限制,那麼需要使用一個訓練迭代數計數器,並且需要被初始化。if __name__ == '__main__':
getTrianList()
dataroad = "/Users/zhuxiaoxiansheng/Desktop/Yaledata.txt"
outputdir = "/Users/zhuxiaoxiansheng/Desktop/Yaledata"
trainroad = trans2tfRecord(dataroad,outputdir)
traindata,trainlabel = read_tfRecord(trainroad)
image_batch,label_batch = tf.train.shuffle_batch([traindata,trainlabel],
batch_size=100,capacity=2000,min_after_dequeue = 1000)
with tf.Session() as sess:
sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord = coord)
train_steps = 10
try:
while not coord.should_stop(): # 如果執行緒應該停止則返回True
example,label = sess.run([image_batch,label_batch])
print(example.shape,label)
train_steps -= 1
print(train_steps)
if train_steps <= 0:
coord.request_stop() # 請求該執行緒停止
except tf.errors.OutOfRangeError:
print ('Done training -- epoch limit reached')
finally:
# When done, ask the threads to stop. 請求該執行緒停止
coord.request_stop()
# And wait for them to actually do it. 等待被指定的執行緒終止
coord.join(threads)
如果成功的話會有下面的輸出(輸出結果就截自己的圖吧):
]。