1. 程式人生 > >tensorflow中的佇列和執行緒

tensorflow中的佇列和執行緒

一、佇列

tensorflow中主要有FIFOQueue和RandomShuffleQueue兩種佇列,下面就詳細介紹這兩種佇列的使用方法和應用場景。

1、FIFOQueue

FIFOQueue是先進先出佇列,主要是針對一些序列樣本。如:在使用迴圈神經網路的時候,需要處理語音、文字、視訊等序列資訊的時候,我們希望處理的時候能夠按照順序進行,這時候就需要使用FIFOQueue佇列。

    #先入先出佇列,初始化佇列,設定佇列大小5
    q = tf.FIFOQueue(5,"float")
    #入隊操作
    init = q.enqueue_many(([1,2,3,4,5],))
    #定義出隊操作
    x = q.dequeue()
    y = x + 1
    #將出隊的元素加1,然後再加入到佇列中
    q_in = q.enqueue([y])
    #建立會話
    with tf.Session() as sess:
        sess.run(init)
        #執行3次q_in操作
        for i in range(3):
            sess.run(q_in)
        #獲取佇列的長度
        que_len = sess.run(q.size())
        #將佇列中的所有元素執行出隊操作
        for i in range(que_len):
            print(sess.run(q.dequeue()))

2、RandomShuffleQueue

RandomShuffleQueue是隨機佇列,佇列在執行出隊操作的時候,是以隨機的順序進行的。隨機佇列一般應用在我們訓練模型的時候,希望可以無序的獲取樣本來進行訓練,如:在訓練影象分類模型的時候,需要輸入的樣本是無序的,就可以利用多執行緒來讀取樣本,將樣本放到隨機佇列中,然後再利用主執行緒每次從隨機佇列中獲取一個batch進行模型的訓練。

    #初始化一個隨機佇列,設定佇列大小為10,最小長度為2
    q = tf.RandomShuffleQueue(capacity=10,min_after_dequeue=2,dtypes="float")
    #建立會話
    with tf.Session() as sess:
        #定義10次入隊操作
        for i in range(10):
            sess.run(q.enqueue(i))
        #定義8次出隊操作
        for i in range(8):
            print(sess.run(q.dequeue()))

注意:在使用隨機佇列的時候,我們設定了佇列的容量為10,最小長度為2。當佇列的長度已經等於佇列的容量(10)再執行入隊操作或佇列的長度已經等於最小長度(2)再執行出隊操作時,程式會發生阻斷,即程式在執行,但是沒有任何輸出,如下圖:

定義了10次出隊操作,當隊列出隊8次之後,就被阻斷了。我們可以通過設定會話在執行時的等待時間來解除阻斷:

    #初始化一個隨機佇列,設定佇列大小為10,最小長度為2
    q = tf.RandomShuffleQueue(capacity=10,min_after_dequeue=2,dtypes="float")
    #建立會話
    with tf.Session() as sess:
        #定義10次入隊操作
        for i in range(10):
            sess.run(q.enqueue(i))
        #設定會話執行時等待時間,等待時長為5s
        run_options = tf.RunOptions(timeout_in_ms=5000)
        #定義10次出隊操作
        for i in range(10):
            try:
                #當佇列進入阻斷之後,超時就丟擲異常
                print(sess.run(q.dequeue(),options=run_options))
            except tf.errors.DeadlineExceededError:
                print("out of range")
                #退出迴圈
                break

當隊列出隊第9次的時候,進入阻斷狀態時,我們可以通過DeadlineExceededError來捕獲阻斷資訊。

二、佇列管理器

在訓練模型的時候,我們需要將樣本從硬碟讀取到記憶體之後,才能進行訓練。會話中可以執行多個執行緒,我們可以在佇列管理器中建立一系列新的執行緒進行入隊操作,主執行緒可以利用佇列中的資料進行訓練,而不需要等到所有的樣本都讀取完成之後才開始訓練,即資料的讀取和模型的訓練是非同步的,這樣可以節省不少時間。

    #建立佇列,設定佇列的大小為1000
    q = tf.FIFOQueue(1000,"float")
    #定義計數器
    counter = tf.Variable(0.0)
    #給計數器加1
    increment_op = tf.assign_add(counter,tf.constant(1.0))
    #佇列入隊操作
    enque_op = q.enqueue([counter])
    #建立佇列管理器
    qr = tf.train.QueueRunner(q,enqueue_ops=[increment_op,enque_op]*1)
    #建立會話
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        #啟動入隊執行緒
        enqueue_threads = qr.create_threads(sess,start=True)
        #主執行緒
        for i in range(10):
            #定義出隊操作
            print(sess.run(q.dequeue()))

程式結束的時候,還報了一個tensorflow.python.framework.errors_impl.CancelledError: Enqueue operation was cancelled的異常。那是因為主執行緒已經完成了,入隊執行緒還在繼續執行導致程式沒法結束從而報錯。由於計數器加1操作和入隊操作不同步,可能會由於計數器還沒來得及進行加1操作就再次被執行入隊操作,從而導致多次入隊同樣的數字,也就是為什麼出隊的時候會出現同樣的數字。

三、協調器

為了避免上述異常的發生,我們可以通過協調器來實現執行緒間的同步,來終止其他執行緒。

    #建立佇列,設定佇列的大小為1000
    q = tf.FIFOQueue(1000,"float")
    #定義計數器
    counter = tf.Variable(0.0)
    #給計數器加1
    increment_op = tf.assign_add(counter,tf.constant(1.0))
    #佇列入隊操作
    enque_op = q.enqueue([counter])
    #建立佇列管理器
    qr = tf.train.QueueRunner(q,enqueue_ops=[increment_op,enque_op]*1)
    #建立會話
    with tf.Session() as sess:
        #初始化變數
        sess.run(tf.global_variables_initializer())
        #建立一個執行緒協調器
        coord = tf.train.Coordinator()
        #啟動入隊執行緒
        enqueue_threads = qr.create_threads(sess,coord=coord,start=True)
        #主執行緒執行出隊操作
        for i in range(10):
            print(sess.run(q.dequeue()))
        #通知其他執行緒關閉
        coord.request_stop()
        #等待其他執行緒結束,當其他執行緒都關閉之後,函式才返回結果
        coord.join(enqueue_threads)

通過上面的結果可以發現,程式能夠正常的結束。但是,當關閉執行緒之後再執行出隊操作,就會報OutOfRangeError的錯誤,程式碼如下

        coord.request_stop()
        for i in range(10):
            print(sess.run(q.dequeue()))
        coord.join(enqueue_threads)

對於這種情況,我們可以通過OutOfRangeError來捕獲這個錯誤資訊

        coord.request_stop()
        for i in range(10):
            try:
                print(sess.run(q.dequeue()))
            except tf.errors.OutOfRangeError:
                #退出迴圈
                break
        coord.join(enqueue_threads)