1. 程式人生 > >tensorflow中的tf.train.batch詳解

tensorflow中的tf.train.batch詳解

官方文件連結:https://tensorflow.google.cn/versions/r1.8/api_docs/python/tf/train/batch

tf.train.batch(
    tensors,
    batch_size,
    num_threads=1,
    capacity=32,
    enqueue_many=False,
    shapes=None,
    dynamic_pad=False,
    allow_smaller_final_batch=False,
    shared_name=None,
    name=None
)

函式功能:利用一個tensor的列表或字典來獲取一個batch資料

引數介紹:

  • tensors:一個列表或字典的tensor用來進行入隊
  • batch_size:設定每次從佇列中獲取出隊資料的數量
  • num_threads:用來控制入隊tensors執行緒的數量,如果num_threads大於1,則batch操作將是非確定性的,輸出的batch可能會亂序
  • capacity:一個整數,用來設定佇列中元素的最大數量
  • enqueue_many:在tensors中的tensor是否是單個樣本
  • shapes:可選,每個樣本的shape,預設是tensors的shape
  • dynamic_pad:Boolean值.允許輸入變數的shape,出隊後會自動填補維度,來保持與batch內的shapes相同
  • allow_samller_final_batch:可選,Boolean值,如果為True佇列中的樣本數量小於batch_size時,出隊的數量會以最終遺留下來的樣本進行出隊,如果為Flalse,小於batch_size的樣本不會做出隊處理
  • shared_name:可選,通過設定該引數,可以對多個會話共享佇列
  • name:可選,操作的名字

從陣列中每次獲取一個batch_size的資料

import numpy as np
import tensorflow as tf

def next_batch():
    datasets =  np.asarray(range(0,20))
    input_queue = tf.train.slice_input_producer([datasets],shuffle=False,num_epochs=1)
    data_batchs = tf.train.batch(input_queue,batch_size=5,num_threads=1,
                                             capacity=20,allow_smaller_final_batch=False)
    return data_batchs

if __name__ == "__main__":
    data_batchs = next_batch()
    sess = tf.Session()
    sess.run(tf.initialize_local_variables())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess,coord)
    try:
        while not coord.should_stop():
            data = sess.run([data_batchs])
            print(data)
    except tf.errors.OutOfRangeError:
        print("complete")
    finally:
        coord.request_stop()
    coord.join(threads)
    sess.close()

注意:tf.train.batch這個函式的實現是使用queue,queue的QueueRunner被新增到當前計算圖的"QUEUE_RUNNER"集合中,所在使用初始化器的時候,需要使用tf.initialize_local_variables(),如果使用tf.global_varialbes_initialize()時,會報: Attempting to use uninitialized value