1. 程式人生 > >TensorFlow多線程輸入數據處理框架(三)——組合訓練數據

TensorFlow多線程輸入數據處理框架(三)——組合訓練數據

code lte 函數 auth cast desc 結構 save pca

參考書

《TensorFlow:實戰Google深度學習框架》(第2版)

通過TensorFlow提供的tf.train.batch和tf.train.shuffle_batch函數來將單個的樣例組織成batch的形式輸出。

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# coding=utf-8 

"""
@author: Li Tian
@contact: [email protected]
@software: pycharm
@file: sample_data_deal2.py
@time: 2019/2/4 11:15
@desc: 通過TensorFlow提供的tf.train.batch和tf.train.shuffle_batch函數來將單個的樣例組織成batch的形式輸出。
""" import tensorflow as tf # 使用tf.train.match_filenames_once函數獲取文件列表 files = tf.train.match_filenames_once(./data.tfrecords-*) # 通過tf.train.string_input_producer函數創建輸入隊列,輸入隊列中的文件列表為 # tf.train.match_filenames_once函數獲取的文件列表。這裏將shuffle參數設為False # 來避免隨機打亂讀文件的順序。但一般在解決真實問題時,會將shuffle參數設置為True filename_queue = tf.train.string_input_producer(files, shuffle=False)
# 如前面所示讀取並解析一個樣本 reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example( serialized_example, features={ i: tf.FixedLenFeature([], tf.int64), j: tf.FixedLenFeature([], tf.int64), } ) # 使用前面的方法讀取並解析得到的樣例。這裏假設Example結構中i表示一個樣例的特征向量
# 比如一張圖像的像素矩陣。而j表示該樣例對應的標簽。 example, label = features[i], features[j] # 一個batch中樣例的個數。 batch_size = 3 # 組合樣例的隊列中最多可以存儲的樣例個數。這個隊列如果太大,那麽需要占用很多內存資源; # 如果太小,那麽出隊操作可能會因為沒有數據而被阻礙(block),從而導致訓練效率降低。 # 一般來說這個隊列的大小會和每一個batch的大小相關,下面一行代碼給出了設置隊列大小的一種方式。 capacity = 1000 + 3 * batch_size # 使用tf.train.batch函數來組合樣例。[example, label]參數給出了需要組合的元素, # 一般example和label分別代表訓練樣本和這個樣本對應的正確標簽。batch_size參數給出了 # 每個batch中樣例的個數。capacity給出了隊列的最大容量。每當隊列長度等於容量時, # TensorFlow將暫停入隊操作,而只是等待元素出隊。當元素個數小於容量時, # TensorFlow將自動重新啟動入隊操作。 # example_batch, label_batch = tf.train.batch([example, label], batch_size=batch_size, capacity=capacity) # 使用tf.train.shuffle_batch函數來組合樣例。tf.train.shuffle_batch函數的參數 # 大部分都和tf.train.batch函數相似,但是min_after_dequeue參數是tf.train.shuffle_batch # 函數特有的。min_after_dequeue參數限制了出隊時隊列中元素的最少個數。當隊列中元素太少時, # 隨機打亂樣例順序的作用就不大了。所以tf.train.shuffle_batch函數提供了限制出隊時最少元素的個數 # 來保證隨機打亂順序的作用。當出隊函數被調用但是隊列中元素不夠時,出隊操作將等待更多的元素入隊 # 才會完成。如果min_after_dequeue參數被設定,capacity也應該相應調整來滿足性能需求。 example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue=30) with tf.Session() as sess: tf.local_variables_initializer().run() tf.global_variables_initializer().run() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) # 獲取並打印組合之後的樣例。在真實問題中,這個輸出一般會作為神經網絡的輸入。 for i in range(2): cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch]) print(cur_example_batch, cur_label_batch) coord.request_stop() coord.join(threads)

運行結果:

1. 使用tf.train.batch函數來組合樣例

技術分享圖片技術分享圖片?

2. 使用tf.train.shuffle_batch函數來組合樣例

技術分享圖片技術分享圖片?

3. 兩個函數的區別

tf.train.batch函數不會隨機打亂順序,而tf.train.shuffle_batch會隨機打亂順序。

TensorFlow多線程輸入數據處理框架(三)——組合訓練數據