1. 程式人生 > >TensorFlow學習筆記(6)讀取數據

TensorFlow學習筆記(6)讀取數據

官網 com 記錄 mat target 項目 AD 包含 技術

Overview 之前幾次推送的全部例程,使用的都是tensorflow預處理過的數據集,直接載入即可。例如:技術分享圖片

然而實際中我們使用的通常不會是這種超級經典的數據集,如果我們有一組圖像存儲在磁盤上面,如何以mini-batch的形式把它們讀取進來然後高效的送進網絡訓練?這次推送我們首先用tensorflow最底層的API處理這個問題,後面推送介紹高層API。高層API是對底層的進一步封裝,用戶可以不必關心過多細節。不過了解一下比較底層的API還是有好處的。當你有一組自己的數據的時候,你需要經過以下兩個步驟:(1)將全部數據寫入一個後綴 .tfredords 的文件。

這個步驟涉及讀入->預處理->寫入tfrecords,對你的數據是什麽格式沒有要求。例如,如果你手中是圖像數據,那用opencv/PIL等接口讀入;如果是matlab數據(mat文件),那可以用h5py協議讀入,等等。不管如何讀入,最終都要寫入到統一的tfrecords文件中,以便用tensorflow提供的接口高效讀取。

(2)以mini-batch的形式從tfrecords中讀取數據,送到模型的placeholder中支持網絡訓練。

技術分享圖片

實驗設置 代碼中使用的數據是存在磁盤中的400張png圖像,也傳到了github上面,存在my_data路徑下面。部分如下:

技術分享圖片

代碼實現以下功能:制備tfrecords形式的數據集,然後再以mini-batch讀入,為了測試讀入是否成功,把讀入的數據顯示在tensorboard上面。技術分享圖片

制備tfrecords數據集 在上次推送中(Tensorboard),大部分代碼都是遵循API接口的固定“模式”寫就可以,這次也主要以這種方式進行,而不過多討論背後的理論細節。兩個輔助函數定義這倆輔助函數的目的完全是不想讓後面的代碼太冗長

技術分享圖片

讀取圖像&寫入tfrecords文件

技術分享圖片

幾點說明(1)讀取圖像文件的時候用到了glob和opencv兩個包。glob是將路徑下全部文件名一次性存到一個list中,方面後面逐個讀取;opencv則只是利用imread接口讀取圖像文件的。(2)和一切文件操作一樣,向tfrecords文件中寫入內容也需要建立一個writer對象,創建這個對象的是函數 tf.python_io.TFRecordWriter(3)feature是我們創建的一個字典對象,這裏面可以包含你想記錄的任何信息。在這裏我們存入了三對鍵值(key: value):image_raw(圖像數據,這個是核心內容),heigh(高),width(寬)。你也可以加入更多的信息,例如,通道數目,文件名等等。這些信息在後面讀取數據的時候都可以一並讀取出來。比如:在主程序中,你需要用到圖像的尺寸參數,那麽你可以將圖像和尺寸參數一起讀出。

(4)註意數據格式。圖像數據本身是8bit的,因此我們用前面定義的輔助函數 _bytes_feature_ 把數據轉化成tensorflow要求的tf.train.BytesList格式存入。實際中還會碰到圖像本身是以float形式存儲的,代碼就需要相應的變動,這個下次推送再說。

技術分享圖片

從tfrecords中載入nimi-batch定義函數:讀取一個樣本

技術分享圖片

幾點說明:(1)整個代碼過程很煩雜,因為是調用的底層API,不過都是固定寫法,其中的內部原理主頁菌一知半解,不敢在這裏隨便講(2)特別註意這裏這個字典對象的定義方式首先,這裏的三個key要和前面制備tfrecords時候一致;其次,註意數據格式,image_raw是8bit存儲的,所以讀取的時候限定tf.string類型,同理,height和width要限定tf.int64

(3)如前文所說,字典中存入的信息都可以通過key來讀取,上面的代碼只讀取了圖像信息,如果想獲取height的值,可以補充這樣一句代碼:

height = tf.decode_raw(features[‘height‘], tf.int64)

然後在函數返回值中把height也返回即可

(4)每一個樣本是以一維的形式從數據流中抽取出來的,所以需要reshape成原始尺寸

定義mini-batch

技術分享圖片

用前面定義的read_record獲取一個樣本,然後用tf.train.shuffle_batch來封裝一個mini-batch。tf.train.shuffle_batch會多次通過read_record抽取樣本,並且開辟一塊內存空間建立隊列(queue),將樣本洗牌打亂,空間開辟越大,數據混亂度會越高。控制洗牌的參數是capacity和min_after_dequeue,官網文檔中給出了這倆參數的取值建議,我粘貼到了代碼註釋中。註意:從最開始介紹tensorflow的時候主頁菌就在強調一個事情:任何東西在用Session運行之前都是沒有實際值的。這裏也不例外。在主程序部分,每一個step都要這麽一句代碼:

batch = sess.run(data_batch)

這個batch才是實際的數據,是可以feed給placeholder的

主程序部分 我們的主程序是讀取mini-batch然後用tensorboard顯示。

技術分享圖片

說明:有四行代碼必不可少session開頭的兩行:coord = tf.train.Coordinator()

threads = tf.train.start_queue_runners(sess=sess, coord=coord)

session結尾的兩行

coord.request_stop()

coord.join(threads)

至於內部機理,文檔寫的太模糊,主頁菌缺少計算機基礎理論知識,並沒有看懂

技術分享圖片

總結 我相信你可能已經看暈了......這部分太過瑣碎,細節很多,官方文檔裏面寫的也很模糊,對內部機理解釋的不到位。面對這種情況,主頁菌最初選擇的方法就是,親自嘗試,用幾乎一整天的時間摸索出了這一套代碼的套路。雖然對機理還是一知半解,但是對代碼思路十分清晰了,在自己的項目中能夠迅速擼出一套數據預處理的代碼。所以,主頁菌的建議就是,親自調通一套demo!技術分享圖片

下期預告

這次推送的數據是8bit的,然而如果我想用float格式存儲怎麽辦?(或者原始數據就是float格式的,總不能截斷成8bit來存儲吧......)雖然這部分內容不多,但是由於這次推送信息量夠大了,還是放到下次單獨說吧。艾伯特(http://www.aibbt.com/)國內第一家人工智能門戶

本次推送對應的源碼:

http://www.aibbt.com/a/19073.html

TensorFlow學習筆記(6)讀取數據