1. 程式人生 > >TensorFlow學習筆記(9) TFRecord 輸入資料格式

TensorFlow學習筆記(9) TFRecord 輸入資料格式

TF提供了一種統一的格式來儲存資料,這個格式就是TFRecord。TFRecord檔案中的資料都是通過tf.train.Example Protocol Buffer的格式儲存的。tf.train.Example中包括一個從屬性名稱到取值的字典。其中屬性名稱為一個字串,取值為字串、實數列表或者整數列表。下面為一個具體的樣例程式將MNIST輸入資料轉化為TFRecord格式。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np

#生成整數型的屬性
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

#生成字串型的屬性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

mnist = input_data.read_data_sets('MNIST_data', dtype=tf.uint8, one_hot=True)
images = mnist.train.images
labels = mnist.train.labels
pixels = images.shape[1]
num_examples = mnist.train.num_examples

filename = 'mnist.tfrecords'
#建立一個writer來寫TFRecord檔案
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
    #將影象矩陣轉化為一個字串
    images_raw = images[index].tobytes()
    example = tf.train.Example(features=tf.train.Features(feature={
        'pixels': _int64_feature(pixels),
        'label': _int64_feature(np.argmax(labels[index])),
        'image_raw': _bytes_feature(images_raw)}))
    writer.write(example.SerializeToString())
writer.close()

以下程式給出瞭如何讀取TFRecord檔案中的資料。

import tensorflow as tf

#建立一個reader來讀取TFRecord檔案中的樣例
reader = tf.TFRecordReader()
#建立一個佇列來維護輸入檔案列表
filename_queue = tf.train.string_input_producer('mnist.tfrecords')
#從檔案中讀出一個樣例
_, serialized_example = reader.read(filename_queue)
#解析讀入的一個樣例
features = tf.parse_single_example(serialized_example, features={
    #tf.FixedLenFeature解析結果為tensor
    'image_raw': tf.FixedLenFeature([], tf.string),
    'pixels': tf.FixedLenFeature([], tf.int64),
    'label': tf.FixedLenFeature([],tf.int64)
})

#tf.decode_raw可以將字串解析成影象對應的畫素陣列
images = c(features['image_raw'], tf.unit8)
labels = tf.cast(features['label'], tf.int32)
pixels = tf.cast(features['pixels'], tf.int32)

sess = tf.Session()
#多執行緒。。
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

for i in range(10):
    image, label, pixel = sess.run([images, labels, pixels])

多執行緒輸入資料處理框架

為了避免影象預處理成為神經網路模型訓練效率的瓶頸,TF提供了一套多執行緒處理輸入資料的框架。

經典的輸入資料處理流程為:指定原始資料的檔案列表>建立檔案列表佇列>從檔案中讀取資料>資料預處理>整理成Batch作為神經網路的輸入。佇列不僅是一種資料結構,更提供了多執行緒機制,佇列也是TF中多執行緒輸入資料處理框架的基礎。比如多個執行緒可以同時向一個佇列中寫元素,或者同時讀取一個佇列中的元素。

佇列與多執行緒

佇列和變數都是計算圖上有狀態的節點。通過賦值修改變數的取值;通過Enqueue、EnqueueMany、Dequeue函式來修改佇列狀態。以下程式展示瞭如何使用這些函式來操作一個佇列。

import tensorflow as tf

#建立一個先進先出佇列,指定佇列中最多可以儲存兩個元素,並指定型別為整數。
q = tf.FIFOQueue(2, 'int32')
#使用enqueue_many函式來初始化佇列元素,在使用佇列之前要明確呼叫這個初始化過程
init = q.enqueue_many(([0, 10],))
#使用Dequeue函式將佇列中第一個元素出佇列,這個元素將被存在變數x
x = q.dequeue()
y = x + 1
#重新加入佇列
q_inc = q.enqueue([y])

with tf.Session() as sess:
    #執行初始化佇列的操作
    sess.run(init)
    for _ in range(5):
        v, _ = sess.run([x, q_inc])
        print(v)

在TF中提供了FIFOQueue和RandomShuffleQueue兩種佇列。在上面的程式中展示了FIFOQueue佇列。而RandomShuffleQueue會將佇列中的元素打亂,每次enqueue_many操作得到的是從當前佇列中隨機選擇的一個元素。

TF提供了tf.Coordinator和tf.QueueRunner兩個類來完成多執行緒協同的功能。

tf.Coordinator主要用於協同多個執行緒一起停止,提供了should_stop,request_stop和join三個函式。啟動的程序只有當should_stop函式為True時則退出。每一個啟動的程序通過呼叫request_stop函式來通知其他執行緒退出。

import tensorflow as tf
import numpy as np
import threading
import time

#線上程中執行的程式,這個程式每隔1s判斷是否需要停止列印自己的id
def MyLoop(coord, worker_id):
    while not coord.should_stop():
        if np.random.rand() < 0.1:
            print('Stoping from id:%d' % worker_id)
            coord.request_stop()
        else:
            print('Working on id:%d'% worker_id)
        time.sleep(1)

#宣告一個tf.train.Coordinator()類
coord = tf.train.Coordinator()
#建立五個執行緒
threads = [threading.Thread(target=MyLoop, args=(coord,i,))for i in range(5)]
#啟動所有的執行緒
for t in threads: t.start()
#等待所有執行緒退出
coord.join(threads)

tf.QueueRunner主要用於啟動多個執行緒來操作同一個佇列,這些執行緒可以通過tf.Coordinator來進行統一管理。比如,

import tensorflow as tf

#宣告佇列,100個元素,型別實數
queue = tf.FIFOQueue(100, 'float')
#定義佇列入隊操作
enqueue_op = queue.enqueue([tf.random_normal([1])])
#使用tf.train.QueueRunner建立多個執行緒的入隊操作
#第一個引數為被操作的佇列,第二個表示需要啟動五個執行緒,每個執行緒都是enqueue_op操作
qr = tf.train.QueueRunner(queue, [enqueue_op]*5)

#將定義過的qr加入tf計算圖指定的集合
#若沒有指定集合,則加入預設的集合tf.GraphKeys,QUEUE_RUNNERS
tf.train.add_queue_runner(qr)
#定義出佇列操作
out_tensor = queue.dequeue()

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for _ in range(3): print(sess.run(out_tensor)[0])

    coord.request_stop()
    coord.join(threads)