1. 程式人生 > >TensorFlow資料讀取方法

TensorFlow資料讀取方法

轉自:http://honggang.io/2016/08/19/tensorflow-data-reading/

引言

Tensorflow的資料讀取有三種方式:

  • Preloaded data: 預載入資料
  • Feeding: Python產生資料,再把資料餵給後端。
  • Reading from file: 從檔案中直接讀取

這三種有讀取方式有什麼區別呢? 我們首先要知道TensorFlow(TF)是怎麼樣工作的。

TF的核心是用C++寫的,這樣的好處是執行快,缺點是呼叫不靈活。而Python恰好相反,所以結合兩種語言的優勢。涉及計算的核心運算元和執行框架是用C++寫的,並提供API給Python。Python呼叫這些API,設計訓練模型(Graph),再將設計好的Graph給後端去執行。簡而言之,Python的角色是Design,C++是Run。

Preload與Feeding

Preload

12345678910 import tensorflow as tf# 設計Graphx1 = tf.constant([2, 3, 4])x2 = tf.constant([4, 0, 1])y = tf.add(x1, x2)# 開啟一個session --> 計算ywith tf.Session() as sess: print sess.run(y)

在設計Graph的時候,x1和x2就被定義成了兩個有值的列表,在計算y的時候直接取x1和x2的值。

Feeding

1234567891011121314 import
tensorflow as tf# 設計Graphx1 = tf.placeholder(tf.int16)x2 = tf.placeholder(tf.int16)y = tf.add(x1, x2)# 用Python產生資料li1 = [2, 3, 4]li2 = [4, 0, 1]# 開啟一個session --> 喂資料 --> 計算ywith tf.Session() as sess: print sess.run(y, feed_dict={x1: li1, x2: li2})

在這裡x1, x2只是佔位符,沒有具體的值,那麼執行的時候去哪取值呢?這時候就要用到sess.run()

中的feed_dict引數,將Python產生的資料餵給後端,並計算y。

兩種方法的區別

Preload:
將資料直接內嵌到Graph中,再把Graph傳入Session中執行。當資料量比較大時,Graph的傳輸會遇到效率問題。
Feeding:
用佔位符替代資料,待執行的時候填充資料。

Reading From File

前兩種方法很方便,但是遇到大型資料的時候就會很吃力,即使是Feeding,中間環節的增加也是不小的開銷,比如資料型別轉換等等。最優的方案就是在Graph定義好檔案讀取的方法,讓TF自己去從檔案中讀取資料,並解碼成可使用的樣本集。

AnimatedFileQueues

在上圖中,首先由一個單執行緒把檔名堆入佇列,兩個Reader同時從佇列中取檔名並讀取資料,Decoder將讀出的資料解碼後堆入樣本佇列,最後單個或批量取出樣本(圖中沒有展示樣本出列)。我們這裡通過三段程式碼逐步實現上圖的資料流,這裡我們不使用隨機,讓結果更清晰。

檔案準備

1234567 $ echo -e "Alpha1,A1\nAlpha2,A2\nAlpha3,A3" > A.csv$ echo -e "Bee1,B1\nBee2,B2\nBee3,B3" > B.csv$ echo -e "Sea1,C1\nSea2,C2\nSea3,C3" > C.csv$ cat A.csvAlpha1,A1Alpha2,A2Alpha3,A3

單個Reader,單個樣本

1234567891011121314151617181920212223242526272829303132 import tensorflow as tf# 生成一個先入先出佇列和一個QueueRunnerfilenames = ['A.csv', 'B.csv', 'C.csv']filename_queue = tf.train.string_input_producer(filenames, shuffle=False)# 定義Readerreader = tf.TextLineReader()key, value = reader.read(filename_queue)# 定義Decoderexample, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])# 執行Graphwith tf.Session() as sess: coord = tf.train.Coordinator() #建立一個協調器,管理執行緒 threads = tf.train.start_queue_runners(coord=coord) #啟動QueueRunner, 此時檔名佇列已經進隊。 for i in range(10): print example.eval() #取樣本的時候,一個Reader先從檔名佇列中取出檔名,讀出資料,Decoder解析後進入樣本佇列。 coord.request_stop() coord.join(threads)# outptAlpha1Alpha2Alpha3Bee1Bee2Bee3Sea1Sea2Sea3Alpha1

單個Reader,多個樣本

12345678910111213141516171819202122232425262728293031323334 import tensorflow as tffilenames = ['A.csv', 'B.csv', 'C.csv']filename_queue = tf.train.string_input_producer(filenames, shuffle=False)reader = tf.TextLineReader()key, value = reader.read(filename_queue)example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']])# 使用tf.train.batch()會多加了一個樣本佇列和一個QueueRunner。Decoder解後資料會進入這個佇列,再批量出隊。# 雖然這裡只有一個Reader,但可以設定多執行緒,相應增加執行緒數會提高讀取速度,但並不是執行緒越多越好。example_batch, label_batch = tf.train.batch( [example, label], batch_size=5)with tf.Session() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for i in range(10): print example_batch.eval() coord.request_stop() coord.join(threads)# output# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']# ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1']# ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3']# ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2']# ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1']# ['Sea2' 'Sea3' 'Alpha1' 'Alpha2' 'Alpha3']# ['Bee1' 'Bee2' 'Bee3' 'Sea1' 'Sea2']# ['Sea3' 'Alpha1' 'Alpha2' 'Alpha3' 'Bee1']# ['Bee2' 'Bee3' 'Sea1' 'Sea2' 'Sea3']# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']

多Reader,多個樣本

1234567891011121314151617181920212223242526272829303132333435 import tensorflow as tffilenames = ['A.csv', 'B.csv', 'C.csv']filename_queue = tf.train.string_input_producer(filenames, shuffle=False)reader = tf.TextLineReader()key, value = reader.read(filename_queue)record_defaults = [['null'], ['null']]example_list = [tf.decode_csv(value, record_defaults=record_defaults) for _ in range(2)] # Reader設定為2# 使用tf.train.batch_join(),可以使用多個reader,並行讀取資料。每個Reader使用一個執行緒。example_batch, label_batch = tf.train.batch_join( example_list, batch_size=5)with tf.Session() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) for i in range(10): print example_batch.eval() coord.request_stop() coord.join(threads)# output# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']# ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1']# ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3']# ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2']# ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1']# ['Sea2' 'Sea3' 'Alpha1' 'Alpha2' 'Alpha3']# ['Bee1' 'Bee2' 'Bee3' 'Sea1' 'Sea2']# ['Sea3' 'Alpha1' 'Alpha2' 'Alpha3' 'Bee1']# ['Bee2' 'Bee3' 'Sea1' 'Sea2' 'Sea3']# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']

tf.train.batchtf.train.shuffle_batch函式是單個Reader讀取,但是可以多執行緒。tf.train.batch_jointf.train.shuffle_batch_join可設定多Reader讀取,每個Reader使用一個執行緒。至於兩種方法的效率,單Reader時,2個執行緒就達到了速度的極限。多Reader時,2個Reader就達到了極限。所以並不是執行緒越多越快,甚至更多的執行緒反而會使效率下降。

迭代控制

123456789101112131415161718192021222324252627282930313233343536 filenames = ['A.csv', 'B.csv', 'C.csv']filename_queue = tf.train.string_input_producer(filenames, shuffle=False, num_epochs=3) # num_epoch: 設定迭代數reader = tf.TextLineReader()key, value = reader.read(filename_queue)record_defaults = [['null'], ['null']]example_list = [tf.decode_csv(value, record_defaults=record_defaults) for _ in range(2)]example_batch, label_batch = tf.train.batch_join( example_list, batch_size=5)init_local_op = tf.initialize_local_variables()with tf.Session() as sess: sess.run(init_local_op) # 初始化本地變數 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) try: while not coord.should_stop(): print example_batch.eval() except tf.errors.OutOfRangeError: print('Epochs Complete!') finally: coord.request_stop() coord.join(threads) coord.request_stop() coord.join(threads)# output# ['Alpha1' 'Alpha2' 'Alpha3' 'Bee1' 'Bee2']# ['Bee3' 'Sea1' 'Sea2' 'Sea3' 'Alpha1']# ['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3']# ['Sea1' 'Sea2' 'Sea3' 'Alpha1' 'Alpha2']# ['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1']# Epochs Complete!

在迭代控制中,記得新增tf.initialize_local_variables(),官網教程沒有說明,但是如果不初始化,執行就會報錯。