1. 程式人生 > >TensorFlow中資料讀取之tfrecords

TensorFlow中資料讀取之tfrecords

關於Tensorflow讀取資料,官網給出了三種方法:

  • 供給資料(Feeding): 在TensorFlow程式執行的每一步, 讓Python程式碼來供給資料。
  • 從檔案讀取資料: 在TensorFlow圖的起始, 讓一個輸入管線從檔案中讀取資料。
  • 預載入資料: 在TensorFlow圖中定義常量或變數來儲存所有資料(僅適用於資料量比較小的情況)。

對於資料量較小而言,可能一般選擇直接將資料載入進記憶體,然後再分batch輸入網路進行訓練(tip:使用這種方法時,結合yield 使用更為簡潔,大家自己嘗試一下吧,我就不贅述了)。但是,如果資料量較大,這樣的方法就不適用了,因為太耗記憶體,所以這時最好使用tensorflow提供的佇列queue,也就是第二種方法 從檔案讀取資料。對於一些特定的讀取,比如csv檔案格式,官網有相關的描述,在這兒我介紹一種比較通用,高效的讀取方法(官網介紹的少),即使用tensorflow內定標準格式——TFRecords

TFRecords
TFRecords其實是一種二進位制檔案,雖然它不如其他格式好理解,但是它能更好的利用記憶體,更方便複製和移動,並且不需要單獨的標籤檔案。

TFRecords檔案包含了tf.train.Example 協議記憶體塊(protocol buffer)(協議記憶體塊包含了欄位 Features)。我們可以寫一段程式碼獲取你的資料, 將資料填入到Example協議記憶體塊(protocol buffer),將協議記憶體塊序列化為一個字串, 並且通過tf.python_io.TFRecordWriter 寫入到TFRecords檔案。

從TFRecords檔案中讀取資料, 可以使用tf.TFRecordReader

tf.parse_single_example解析器。這個操作可以將Example協議記憶體塊(protocol buffer)解析為張量。

生成TFRecords檔案

存入TFRecords檔案需要資料先存入名為example的protocol buffer,然後將其serialize成為string才能寫入。example中包含features,用於描述資料型別:bytes,float,int64。

我們使用tf.train.Example來定義我們要填入的資料格式,然後使用tf.python_io.TFRecordWriter來寫入。

writer = tf.python_io.TFRecordWriter(out_name)
        
#對每條資料分別獲得文件,問題,答案三個值,並將相應單詞轉化為索引 #呼叫Example和Features函式將資料格式化儲存起來。注意Features傳入的引數應該是一個字典,方便後續讀資料時的操作 example = tf.train.Example( features = tf.train.Features( feature = { 'document': tf.train.Feature( int64_list=tf.train.Int64List(value=document)), 'query': tf.train.Feature( int64_list=tf.train.Int64List(value=query)), 'answer': tf.train.Feature( int64_list=tf.train.Int64List(value=answer)) })) #寫資料 serialized = example.SerializeToString() writer.write(serialized)

也可以用extend的方式:

example = tf.train.Example()
example.features.feature["context"].int64_list.value.extend(context_transformed) 
example.features.feature["utterance"].int64_list.value.extend(utterance_transformed) example.features.feature["context_len"].int64_list.value.extend([context_len]) example.features.feature["utterance_len"].int64_list.value.extend([utterance_len]) writer = tf.python_io.TFRecordWriter(output_filename) writer.write(example.SerializeToString()) writer.close()

讀取tfrecords檔案

 

首先用tf.train.string_input_producer讀取tfrecords檔案的list建立FIFO序列,可以申明num_epoches和shuffle引數表示需要讀取資料的次數以及時候將tfrecords檔案讀入順序打亂,然後定義TFRecordReader讀取上面的序列返回下一個record,用tf.parse_single_example對讀取到TFRecords檔案進行解碼,根據儲存的serialize example和feature字典返回feature所對應的值。此時獲得的值都是string,需要進一步解碼為所需的資料型別。把影象資料的string reshape成原始影象後可以進行preprocessing操作。此外,還可以通過tf.train.batch或者tf.train.shuffle_batch將影象生成batch序列。

由於tf.train函式會在graph中增加tf.train.QueueRunner類,而這些類有一系列的enqueue選項使一個佇列在一個執行緒裡執行。為了填充佇列就需要用tf.train.start_queue_runners來為所有graph中的queue runner啟動執行緒,而為了管理這些執行緒就需要一個tf.train.Coordinator來在合適的時候終止這些執行緒。


因為在讀取資料之後我們可能還會進行一些額外的操作,使我們的資料格式滿足模型輸入,所以這裡會引入一些額外的函式來實現我們的目的。這裡介紹幾個個人感覺較重要常用的函式。不過還是推薦到官網API去查,或者有某種需求的時候到Stack Overflow上面搜一搜,一般都能找到滿足自己需求的函式。
1,string_input_producer(
string_tensor,
num_epochs=None,
shuffle=True,
seed=None,
capacity=32,
shared_name=None,
name=None,
cancel_op=None
)

其輸出是一個輸入管道的佇列,這裡需要注意的引數是num_epochs和shuffle。對於每個epoch其會將所有的檔案新增到檔案隊列當中,如果設定shuffle,則會對檔案順序進行打亂。其對檔案進行均勻取樣,而不會導致上下采樣。

2,shuffle_batch(
tensors,
batch_size,
capacity,
min_after_dequeue,
num_threads=1,
seed=None,
enqueue_many=False,
shapes=None,
allow_smaller_final_batch=False,
shared_name=None,
name=None
)

產生隨機打亂之後的batch資料

3,sparse_ops.serialize_sparse(sp_input, name=None): 返回一個字串的3-vector(1-D的tensor),分別表示索引、值、shape

4,deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None): 將多個稀疏的serialized_sparse合併成一個

基本的,一個Example中包含Features,Features裡包含Feature(這裡沒s)的字典。最後,Feature裡包含有一個 FloatList, 或者ByteList,或者Int64List

就這樣,我們把相關的資訊都存到了一個檔案中,而且讀取也很方便。

簡單的讀取小例子

for serialized_example in tf.python_io.tf_record_iterator("train.tfrecords"):
    example = tf.train.Example()
    example.ParseFromString(serialized_example)
    context = example.features.feature['context'].int64_list.value
    utterance = example.features.feature['utterance'].int64_list.value

使用佇列讀取

一旦生成了TFRecords檔案,為了高效地讀取資料,TF中使用佇列(queue)讀取資料。

def read_and_decode(filename):
    #根據檔名生成一個佇列
    filename_queue = tf.train.string_input_producer([filename])

    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)   #返回檔名和檔案
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'label': tf.FixedLenFeature([], tf.int64),
                                           'img_raw' : tf.FixedLenFeature([], tf.string),
                                       })

    img = tf.decode_raw(features['img_raw'], tf.uint8)
    img = tf.reshape(img, [224, 224, 3])
    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)

    return img, label

之後我們可以在訓練的時候這樣使用

img, label = read_and_decode("train.tfrecords")

#使用shuffle_batch可以隨機打亂輸入
img_batch, label_batch = tf.train.shuffle_batch([img, label],
                                                batch_size=30, capacity=2000,
                                                min_after_dequeue=1000)
init = tf.initialize_all_variables()

with tf.Session() as sess:
    sess.run(init)
   # 這是填充佇列的指令,如果不執行程式會等在佇列檔案的讀取處無法執行 coord
= tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  
for i in range(3): val, l= sess.run([img_batch, label_batch]) #我們也可以根據需要對val, l進行處理 #l = to_categorical(l, 12) print(val.shape, l)

注意:

第一,tensorflow裡的graph能夠記住狀態(state),這使得TFRecordReader能夠記住tfrecord的位置,並且始終能返回下一個。而這就要求我們在使用之前,必須初始化整個graph,這裡我們使用了函式tf.initialize_all_variables()來進行初始化。

第二,tensorflow中的佇列和普通的佇列差不多,不過它裡面的operation和tensor都是符號型的(symbolic),在呼叫sess.run()時才執行。

第三, TFRecordReader會一直彈出佇列中檔案的名字,直到佇列為空。

總結

  1. 生成tfrecord檔案
  2. 定義record reader解析tfrecord檔案
  3. 構造一個批生成器(batcher
  4. 構建其他的操作
  5. 初始化所有的操作
  6. 啟動QueueRunner

參考:

https://blog.csdn.net/u012759136/article/details/52232266

https://blog.csdn.net/liuchonge/article/details/73649251