TensorFlow學習筆記(9) TFRecord 輸入資料格式
TF提供了一種統一的格式來儲存資料,這個格式就是TFRecord。TFRecord檔案中的資料都是通過tf.train.Example Protocol Buffer的格式儲存的。tf.train.Example中包括一個從屬性名稱到取值的字典。其中屬性名稱為一個字串,取值為字串、實數列表或者整數列表。下面為一個具體的樣例程式將MNIST輸入資料轉化為TFRecord格式。
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data import numpy as np #生成整數型的屬性 def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) #生成字串型的屬性 def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) mnist = input_data.read_data_sets('MNIST_data', dtype=tf.uint8, one_hot=True) images = mnist.train.images labels = mnist.train.labels pixels = images.shape[1] num_examples = mnist.train.num_examples filename = 'mnist.tfrecords' #建立一個writer來寫TFRecord檔案 writer = tf.python_io.TFRecordWriter(filename) for index in range(num_examples): #將影象矩陣轉化為一個字串 images_raw = images[index].tobytes() example = tf.train.Example(features=tf.train.Features(feature={ 'pixels': _int64_feature(pixels), 'label': _int64_feature(np.argmax(labels[index])), 'image_raw': _bytes_feature(images_raw)})) writer.write(example.SerializeToString()) writer.close()
以下程式給出瞭如何讀取TFRecord檔案中的資料。
import tensorflow as tf #建立一個reader來讀取TFRecord檔案中的樣例 reader = tf.TFRecordReader() #建立一個佇列來維護輸入檔案列表 filename_queue = tf.train.string_input_producer('mnist.tfrecords') #從檔案中讀出一個樣例 _, serialized_example = reader.read(filename_queue) #解析讀入的一個樣例 features = tf.parse_single_example(serialized_example, features={ #tf.FixedLenFeature解析結果為tensor 'image_raw': tf.FixedLenFeature([], tf.string), 'pixels': tf.FixedLenFeature([], tf.int64), 'label': tf.FixedLenFeature([],tf.int64) }) #tf.decode_raw可以將字串解析成影象對應的畫素陣列 images = c(features['image_raw'], tf.unit8) labels = tf.cast(features['label'], tf.int32) pixels = tf.cast(features['pixels'], tf.int32) sess = tf.Session() #多執行緒。。 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) for i in range(10): image, label, pixel = sess.run([images, labels, pixels])
多執行緒輸入資料處理框架
為了避免影象預處理成為神經網路模型訓練效率的瓶頸,TF提供了一套多執行緒處理輸入資料的框架。
經典的輸入資料處理流程為:指定原始資料的檔案列表>建立檔案列表佇列>從檔案中讀取資料>資料預處理>整理成Batch作為神經網路的輸入。佇列不僅是一種資料結構,更提供了多執行緒機制,佇列也是TF中多執行緒輸入資料處理框架的基礎。比如多個執行緒可以同時向一個佇列中寫元素,或者同時讀取一個佇列中的元素。
佇列與多執行緒
佇列和變數都是計算圖上有狀態的節點。通過賦值修改變數的取值;通過Enqueue、EnqueueMany、Dequeue函式來修改佇列狀態。以下程式展示瞭如何使用這些函式來操作一個佇列。
import tensorflow as tf
#建立一個先進先出佇列,指定佇列中最多可以儲存兩個元素,並指定型別為整數。
q = tf.FIFOQueue(2, 'int32')
#使用enqueue_many函式來初始化佇列元素,在使用佇列之前要明確呼叫這個初始化過程
init = q.enqueue_many(([0, 10],))
#使用Dequeue函式將佇列中第一個元素出佇列,這個元素將被存在變數x
x = q.dequeue()
y = x + 1
#重新加入佇列
q_inc = q.enqueue([y])
with tf.Session() as sess:
#執行初始化佇列的操作
sess.run(init)
for _ in range(5):
v, _ = sess.run([x, q_inc])
print(v)
在TF中提供了FIFOQueue和RandomShuffleQueue兩種佇列。在上面的程式中展示了FIFOQueue佇列。而RandomShuffleQueue會將佇列中的元素打亂,每次enqueue_many操作得到的是從當前佇列中隨機選擇的一個元素。
TF提供了tf.Coordinator和tf.QueueRunner兩個類來完成多執行緒協同的功能。
tf.Coordinator主要用於協同多個執行緒一起停止,提供了should_stop,request_stop和join三個函式。啟動的程序只有當should_stop函式為True時則退出。每一個啟動的程序通過呼叫request_stop函式來通知其他執行緒退出。
import tensorflow as tf
import numpy as np
import threading
import time
#線上程中執行的程式,這個程式每隔1s判斷是否需要停止列印自己的id
def MyLoop(coord, worker_id):
while not coord.should_stop():
if np.random.rand() < 0.1:
print('Stoping from id:%d' % worker_id)
coord.request_stop()
else:
print('Working on id:%d'% worker_id)
time.sleep(1)
#宣告一個tf.train.Coordinator()類
coord = tf.train.Coordinator()
#建立五個執行緒
threads = [threading.Thread(target=MyLoop, args=(coord,i,))for i in range(5)]
#啟動所有的執行緒
for t in threads: t.start()
#等待所有執行緒退出
coord.join(threads)
tf.QueueRunner主要用於啟動多個執行緒來操作同一個佇列,這些執行緒可以通過tf.Coordinator來進行統一管理。比如,
import tensorflow as tf
#宣告佇列,100個元素,型別實數
queue = tf.FIFOQueue(100, 'float')
#定義佇列入隊操作
enqueue_op = queue.enqueue([tf.random_normal([1])])
#使用tf.train.QueueRunner建立多個執行緒的入隊操作
#第一個引數為被操作的佇列,第二個表示需要啟動五個執行緒,每個執行緒都是enqueue_op操作
qr = tf.train.QueueRunner(queue, [enqueue_op]*5)
#將定義過的qr加入tf計算圖指定的集合
#若沒有指定集合,則加入預設的集合tf.GraphKeys,QUEUE_RUNNERS
tf.train.add_queue_runner(qr)
#定義出佇列操作
out_tensor = queue.dequeue()
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for _ in range(3): print(sess.run(out_tensor)[0])
coord.request_stop()
coord.join(threads)