1. 程式人生 > >Tensorflow中使用tfrecord,佇列方式讀取資料

Tensorflow中使用tfrecord,佇列方式讀取資料

標準TensorFlow格式

      有一種儲存記錄的方法可以允許你講任意的資料轉換為TensorFlow所支援的格式, 這種方法可以使TensorFlow的資料集更容易與網路應用架構相匹配。這種建議的方法就是使用TFRecords檔案,TFRecords檔案包含了tf.train.Example 協議記憶體塊(protocol buffer)(協議記憶體塊包含了欄位 Features)。你可以寫一段程式碼獲取你的資料, 將資料填入到Example協議記憶體塊(protocol buffer),將協議記憶體塊序列化為一個字串, 並且通過tf.python_io.TFRecordWriter class寫入到TFRecords檔案。
     從TFRecords檔案中讀取資料, 可以使用tf.TFRecordReader的tf.parse_single_example解析器。這個parse_single_example操作可以將Example協議記憶體塊(protocol buffer)解析為張量。

     tfecord檔案中的資料是通過tf.train.Example Protocol Buffer的格式儲存的,下面是tf.train.Example的定義:

message Example {
 Features features = 1;
};

message Features{
 map<string,Feature> featrue = 1;
};

message Feature{
    oneof kind{
        BytesList bytes_list = 1;
        FloatList float_list = 2;
        Int64List int64_list = 3;
    }
};

從上述程式碼可以看到,ft.train.Example 的資料結構相對簡潔。tf.train.Example中包含了一個從屬性名稱到取值的字典,其中屬性名稱為一個字串,屬性的取值可以為字串(BytesList ),實數列表(FloatList )或整數列表(Int64List )。例如我們可以將解碼前的圖片作為字串,影象對應的類別標號作為整數列表。

測試例子

使用queue讀取圖片資料方法的大致思路分為三步:
1、根據資料集的具體儲存情況生成一個txt清單,清單上記載了每一張圖片的儲存地址還有一些相關資訊(如標籤、大小之類的)
2、根據第一步的清單記錄,讀取資料和資訊,並將這些資料和資訊按照一定的格式寫成Tensorflow的專用檔案格式(.tfrecords)

3、從.tfrecords檔案中批量的讀取資料供給模型使用

資料清單的生成

根據資料的儲存情況生成的資料清單,不同的情況寫的程式碼肯定也是不一樣的,這裡根據我的具體情況說一下過程和程式
我的資料儲存地址為:/Users/zhuxiaoxiansheng/Desktop/doc/SICA_data/YaleB

具體情況如下:

這裡第一張圖片的的Class01表示的是第一個類別,00000表示的是第一個類別裡的第一張,生成清單的程式如下:

##相關庫函式匯入
import os
import cv2 as cv
import tensorflow as tf 
from PIL import Image
import matplotlib.pyplot as plt
def getTrianList():
    root_dir = "/Users/zhuxiaoxiansheng/Desktop/doc/SICA_data/YaleB"  #資料儲存資料夾地址
    with open('/Users/zhuxiaoxiansheng/Desktop'+"/Yaledata.txt","w") as f:    #txt檔案生成地址
        for file in os.listdir(root_dir):
            if len(file) == 23:                     #圖片名長為23個位元組,避免讀入其他的檔案
                f.write(root_dir+'/'+file+" "+ file[11:13] +"\n")   #file[11:13]表示類別編號

生成的清單檔案是這樣的

生成tfrecords檔案

在得到txt清單檔案以後,根據這份檔案就可以進入流程式的步驟了,首先我們需要生成.tfrecords檔案,程式碼如下

def load_file(example_list_file):   #從清單中讀取地址和類別編號,這裡的輸入是清單儲存地址
   lines = np.genfromtxt(example_list_file,delimiter=" ",dtype=[('col1', 'S120'), ('col2', 'i8')])
   examples = []
   labels = []
   for example,label in lines:
       examples.append(example)
       labels.append(label)
   return np.asarray(examples),np.asarray(labels),len(lines)   

def trans2tfRecord(trainFile,output_dir):    #生成tfrecords檔案
    _examples,_labels,examples_num = load_file(trainFile)
    filename = output_dir + '.tfrecords'
    writer = tf.python_io.TFRecordWriter(filename)
    for i,[example,label] in enumerate(zip(_examples,_labels)):
        example = example.decode("UTF-8")
        image = cv.imread(example)
        image = cv.resize(image,(192,168))    #這裡的格式需要注意,一定要儘量保證圖片的大小一致
        image_raw = image.tostring()          #將圖片矩陣轉化為字串格式
        example = tf.train.Example(features=tf.train.Features(feature={
                'image_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
                'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))                        
                }))
        writer.write(example.SerializeToString()) 
    writer.close()     #寫入完成,關閉指標
    return filename    #返回檔案地址

這裡生成的是.tfrecords不好開啟,就不展示了

從tfrecords檔案中讀取資料

設定從tfrecords檔案中讀取檔案方式的函式如下:

def read_tfRecord(file_tfRecord):     #輸入是.tfrecords檔案地址
    queue = tf.train.string_input_producer([file_tfRecord])
    reader = tf.TFRecordReader()
    _,serialized_example = reader.read(queue)
    features = tf.parse_single_example(
            serialized_example,
            features={
          'image_raw':tf.FixedLenFeature([], tf.string),   
          'label':tf.FixedLenFeature([], tf.int64)
                    }
            )
    image = tf.decode_raw(features['image_raw'],tf.uint8)
    image = tf.reshape(image,[192,168,3])
    image = tf.cast(image, tf.float32)
    image = tf.image.per_image_standardization(image)
    label = tf.cast(features['label'], tf.int64)   這裡設定了讀取資訊的格式
    return image,label

測試程式碼

上面就是主要的程式碼,其中特別要注意的就是以下兩句,非常重要:

coord=tf.train.Coordinator() #建立一個協調器,管理執行緒
threads=tf.train.start_queue_runners(coord=coord) #啟動QueueRunner, 此時檔名佇列已經進隊

這兩句實現的功能就是建立執行緒並使用QueueRunner物件來提取資料。簡單來說:使用tf.train函式新增QueueRunner到tensorflow中。在執行任何訓練步驟之前,需要呼叫tf.train.start_queue_runners函式,否則tensorflow將一直掛起。

tf.train.start_queue_runners 這個函式將會啟動輸入管道的執行緒,填充樣本到佇列中,以便出隊操作可以從佇列中拿到樣本。這種情況下最好配合使用一個tf.train.Coordinator,這樣可以在發生錯誤的情況下正確地關閉這些執行緒。如果你對訓練迭代數做了限制,那麼需要使用一個訓練迭代數計數器,並且需要被初始化。

if __name__ == '__main__':
    getTrianList()
    dataroad = "/Users/zhuxiaoxiansheng/Desktop/Yaledata.txt"
    outputdir = "/Users/zhuxiaoxiansheng/Desktop/Yaledata"

    trainroad = trans2tfRecord(dataroad,outputdir)
    traindata,trainlabel = read_tfRecord(trainroad)
    image_batch,label_batch = tf.train.shuffle_batch([traindata,trainlabel],
                                            batch_size=100,capacity=2000,min_after_dequeue = 1000) 

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())  
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess,coord = coord)
        train_steps = 10  

        try:  
            while not coord.should_stop():  # 如果執行緒應該停止則返回True  
                example,label = sess.run([image_batch,label_batch])  
                print(example.shape,label)  

                train_steps -= 1  
                print(train_steps)  
                if train_steps <= 0:  
                    coord.request_stop()    # 請求該執行緒停止  

        except tf.errors.OutOfRangeError:  
            print ('Done training -- epoch limit reached')  
        finally:  
            # When done, ask the threads to stop. 請求該執行緒停止  
            coord.request_stop()  
        # And wait for them to actually do it. 等待被指定的執行緒終止  
        coord.join(threads)  

如果成功的話會有下面的輸出(輸出結果就截自己的圖吧):

]。