1. 程式人生 > >f.train.Coordinator 和入隊執行緒啟動器 tf.train.start_queue_runners

f.train.Coordinator 和入隊執行緒啟動器 tf.train.start_queue_runners

TensorFlow的Session物件是支援多執行緒的,可以在同一個會話(Session)中建立多個執行緒,並行執行。在Session中的所有執行緒都必須能被同步終止,異常必須能被正確捕獲並報告,會話終止的時候, 佇列必須能被正確地關閉。TensorFlow提供了兩個類來實現對Session中多執行緒的管理:tf.Coordinator和 tf.QueueRunner,這兩個類往往一起使用。

Coordinator類用來管理在Session中的多個執行緒,可以用來同時停止多個工作執行緒並且向那個在等待所有工作執行緒終止的程式報告異常,該執行緒捕獲到這個異常之後就會終止所有執行緒。使用 tf.train.Coordinator()來建立一個執行緒管理器(協調器)物件

QueueRunner類用來啟動tensor的入隊執行緒,可以用來啟動多個工作執行緒同時將多個tensor(訓練資料)推送入檔名稱佇列中,具體執行函式是 tf.train.start_queue_runners , 只有呼叫 tf.train.start_queue_runners 之後,才會真正把tensor推入記憶體序列中,供計算單元呼叫,否則會由於記憶體序列為空,資料流圖會處於一直等待狀態

tf中的資料讀取機制如下圖:

  1. 呼叫 tf.train.slice_input_producer,從 本地檔案裡抽取tensor,準備放入Filename Queue(檔名佇列)中;
  2. 呼叫 tf.train.batch,從檔名佇列中提取tensor,使用單個或多個執行緒,準備放入檔案佇列;
  3. 呼叫 tf.train.Coordinator() 來建立一個執行緒協調器,用來管理之後在Session中啟動的所有執行緒;
  4. 呼叫tf.train.start_queue_runners, 啟動入隊執行緒,由多個或單個執行緒,按照設定規則,把檔案讀入Filename Queue中。函式返回執行緒ID的列表,一般情況下,系統有多少個核,就會啟動多少個入隊執行緒(入隊具體使用多少個執行緒在tf.train.batch中定義);
  5. 檔案從 Filename Queue中讀入記憶體佇列的操作不用手動執行,由tf自動完成;
  6. 呼叫sess.run 來啟動資料出列和執行計算;
  7. 使用 coord.should_stop()來查詢是否應該終止所有執行緒,當檔案佇列(queue)中的所有檔案都已經讀取出列的時候,會丟擲一個 OutofRangeError 的異常,這時候就應該停止Sesson中的所有執行緒了;
  8. 使用coord.request_stop()來發出終止所有執行緒的命令,使用coord.join(threads)把執行緒加入主執行緒,等待threads結束。

以上對列(Queue)和 協調器(Coordinator)操作示例:

[python] view plain copy print?
  1. # -*- coding:utf-8 -*-
  2. import tensorflow as tf  
  3. import numpy as np  
  4. # 樣本個數
  5. sample_num=5
  6. # 設定迭代次數
  7. epoch_num = 2
  8. # 設定一個批次中包含樣本個數
  9. batch_size = 3
  10. # 計算每一輪epoch中含有的batch個數
  11. batch_total = int(sample_num/batch_size)+1
  12. # 生成4個數據和標籤
  13. def generate_data(sample_num=sample_num):  
  14.     labels = np.asarray(range(0, sample_num))  
  15.     images = np.random.random([sample_num, 2242243])  
  16.     print(‘image size {},label size :{}’.format(images.shape, labels.shape))  
  17.     return images,labels  
  18. def get_batch_data(batch_size=batch_size):  
  19.     images, label = generate_data()  
  20.     # 資料型別轉換為tf.float32
  21.     images = tf.cast(images, tf.float32)  
  22.     label = tf.cast(label, tf.int32)  
  23.     #從tensor列表中按順序或隨機抽取一個tensor準備放入檔名稱佇列
  24.     input_queue = tf.train.slice_input_producer([images, label], num_epochs=epoch_num, shuffle=False)  
  25.     #從檔名稱佇列中讀取檔案準備放入檔案佇列
  26.     image_batch, label_batch = tf.train.batch(input_queue, batch_size=batch_size, num_threads=2, capacity=64, allow_smaller_final_batch=False)  
  27.     return image_batch, label_batch  
  28. image_batch, label_batch = get_batch_data(batch_size=batch_size)  
  29. with tf.Session() as sess:  
  30.     # 先執行初始化工作
  31.     sess.run(tf.global_variables_initializer())  
  32.     sess.run(tf.local_variables_initializer())  
  33.     # 開啟一個協調器
  34.     coord = tf.train.Coordinator()  
  35.     # 使用start_queue_runners 啟動佇列填充
  36.     threads = tf.train.start_queue_runners(sess, coord)  
  37.     try:  
  38.         whilenot coord.should_stop():  
  39.             print‘************’
  40.             # 獲取每一個batch中batch_size個樣本和標籤
  41.             image_batch_v, label_batch_v = sess.run([image_batch, label_batch])  
  42.             print(image_batch_v.shape, label_batch_v)  
  43.     except tf.errors.OutOfRangeError:  #如果讀取到檔案佇列末尾會丟擲此異常
  44.         print(“done! now lets kill all the threads……”)  
  45.     finally:  
  46.         # 協調器coord發出所有執行緒終止訊號
  47.         coord.request_stop()  
  48.         print(‘all threads are asked to stop!’)  
  49.     coord.join(threads) #把開啟的執行緒加入主執行緒,等待threads結束
  50.     print(‘all threads are stopped!’)  
  1. # -*- coding:utf-8 -*-
  2. import tensorflow as tf
  3. import numpy as np
  4. # 樣本個數
  5. sample_num=5
  6. # 設定迭代次數
  7. epoch_num = 2
  8. # 設定一個批次中包含樣本個數
  9. batch_size = 3
  10. # 計算每一輪epoch中含有的batch個數
  11. batch_total = int(sample_num/batch_size)+1
  12. # 生成4個數據和標籤
  13. def generate_data(sample_num=sample_num):
  14. labels = np.asarray(range(0, sample_num))
  15. images = np.random.random([sample_num, 224, 224, 3])
  16. print('image size {},label size :{}'.format(images.shape, labels.shape))
  17. return images,labels
  18. def get_batch_data(batch_size=batch_size):
  19. images, label = generate_data()
  20. # 資料型別轉換為tf.float32
  21. images = tf.cast(images, tf.float32)
  22. label = tf.cast(label, tf.int32)
  23. #從tensor列表中按順序或隨機抽取一個tensor準備放入檔名稱佇列
  24. input_queue = tf.train.slice_input_producer([images, label], num_epochs=epoch_num, shuffle=False)
  25. #從檔名稱佇列中讀取檔案準備放入檔案佇列
  26. image_batch, label_batch = tf.train.batch(input_queue, batch_size=batch_size, num_threads=2, capacity=64, allow_smaller_final_batch=False)
  27. return image_batch, label_batch
  28. image_batch, label_batch = get_batch_data(batch_size=batch_size)
  29. with tf.Session() as sess:
  30. # 先執行初始化工作
  31. sess.run(tf.global_variables_initializer())
  32. sess.run(tf.local_variables_initializer())
  33. # 開啟一個協調器
  34. coord = tf.train.Coordinator()
  35. # 使用start_queue_runners 啟動佇列填充
  36. threads = tf.train.start_queue_runners(sess, coord)
  37. try:
  38. while not coord.should_stop():
  39. print '************'
  40. # 獲取每一個batch中batch_size個樣本和標籤
  41. image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
  42. print(image_batch_v.shape, label_batch_v)
  43. except tf.errors.OutOfRangeError: #如果讀取到檔案佇列末尾會丟擲此異常
  44. print("done! now lets kill all the threads……")
  45. finally:
  46. # 協調器coord發出所有執行緒終止訊號
  47. coord.request_stop()
  48. print('all threads are asked to stop!')
  49. coord.join(threads) #把開啟的執行緒加入主執行緒,等待threads結束
  50. print('all threads are stopped!')

輸出:

[python] view plain copy print?
  1. ************  
  2. ((32242243), array([012], dtype=int32))  
  3. ************  
  4. ((32242243), array([340], dtype=int32))  
  5. ************  
  6. ((32242243), array([123], dtype=int32))  
  7. ************  
  8. done! now lets kill all the threads……  
  9. all threads are asked to stop!  
  10. all threads are stopped!  
  1. ************
  2. ((3, 224, 224, 3), array([0, 1, 2], dtype=int32))
  3. ************
  4. ((3, 224, 224, 3), array([3, 4, 0], dtype=int32))
  5. ************
  6. ((3, 224, 224, 3), array([1, 2, 3], dtype=int32))
  7. ************
  8. done! now lets kill all the threads……
  9. all threads are asked to stop!
  10. all threads are stopped!

以上程式在 tf.train.slice_input_producer 函式中設定了 num_epochs 的數量, 所以在檔案佇列末尾有結束標誌,讀到這個結束標誌的時候丟擲 OutofRangeError 異常,就可以結束各個執行緒了。

如果不設定 num_epochs 的數量,則檔案佇列是無限迴圈的,沒有結束標誌,程式會一直執行下去。