1. 程式人生 > >Tensorflow兩種資料讀取方法應用、對比及注意事項

Tensorflow兩種資料讀取方法應用、對比及注意事項

本文對比介紹了兩種Tensorflow針對大資料集的資料讀取方法,具體來說是:

方法1:tf.train.slice_input_producer+流水線(這裡用這個API指代這一類,其實還有其他的API)。

方法2:Dataset方法,據說是Tensorflow 1.3版本之後引入的新API。使用起來比方法1更方便直觀。

關於這兩種方法的具體介紹,這裡不再贅述,建議參考何之源的這兩篇文章:

https://zhuanlan.zhihu.com/p/27238630

https://zhuanlan.zhihu.com/p/30751039

下面針對一組簡單的示例資料集,分別用這兩種方法實現資料讀取,以進行對比學習。

資料集非常簡單,就是下面這三張jpg圖片檔名分別為:a.jpg, b.jpg, c.jpg。來源是何之源新書《21個專案玩轉深度學習——基於TensorFlow的實踐詳解》第二章。


注意,這三張圖片的解析度不同,稍後會提到,這是使用中需要注意的一點。我們的目的是使用上述兩種方法把這三張圖分批讀入。

方法1,程式碼如下:

# coding:utf-8
# blog.csdn.net/foreseerwang
# QQ: 50834

import tensorflow as tf
import numpy as np

def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_jpeg(image_string, channels=3)
  image_resized = tf.image.resize_images(image_decoded,
                  tf.convert_to_tensor([28, 28], dtype=tf.int32))
  return image_resized, label

filename = ['A.jpg', 'B.jpg', 'C.jpg']
labels = [1,2,3]

images_tensor = tf.convert_to_tensor(filename, dtype=tf.string)
labels_tensor = tf.convert_to_tensor([1,2,3], dtype=tf.int64)
file_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=1)

filenames = file_queue[0]
labels = file_queue[1]
image_resized, labels = _parse_function(filenames, labels)

value = tf.train.shuffle_batch([image_resized, labels], batch_size=2, capacity=5000,
                       min_after_dequeue=1000)

batNum = 0
with tf.Session() as sess:
    tf.local_variables_initializer().run()

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        while not coord.should_stop():
            image_data = sess.run(value)
            batNum += 1
            print('** Batch %d' % batNum)
            print(image_data[0].shape)
            print(image_data[1])
            print('\n')

    except tf.errors.OutOfRangeError:
        print('end!')

    finally:
        coord.request_stop()

    coord.join(threads)

輸出:

** Batch 1
(2, 28, 28, 3)
[1 3]

end!

注意,這裡有個問題:一共3張圖片,要求每個batch有兩張,因為3不能被2整除,最後剩下一張圖片,沒能輸出。

方法2,程式碼如下:

# coding:utf-8
# blog.csdn.net/foreseerwang
# QQ: 50834

import tensorflow as tf

def _parse_function(filename, label):
  image_string = tf.read_file(filename)
  image_decoded = tf.image.decode_jpeg(image_string, channels=3)
  image_resized = tf.image.resize_images(image_decoded,
                  tf.convert_to_tensor([28, 28], dtype=tf.int32))
  return image_resized, label

filename = ['A.jpg', 'B.jpg', 'C.jpg']
labels = [1,2,3]

filename_tensor = tf.convert_to_tensor(filename, dtype=tf.string)
labels_tensor = tf.convert_to_tensor(labels, dtype=tf.int64)

dataset = tf.contrib.data.Dataset.from_tensor_slices((filename_tensor, labels_tensor))
dataset = dataset.map(_parse_function)

dataset = dataset.shuffle(buffer_size=10).batch(2).repeat(1)

iterator = dataset.make_one_shot_iterator()
one_batch = iterator.get_next()

batNum = 0
with tf.Session() as sess:
    try:
        while True:
            databatch = sess.run(one_batch)
            batNum += 1
            print('** Batch %d' % batNum)
            print(databatch[0].shape)
            print(databatch[1])
            print('\n')
            
    except tf.errors.OutOfRangeError:
            print("end!")

輸出(因為有shuffle,每次執行結果可能不同):

** Batch 1
(2, 28, 28, 3)
[1 3]

** Batch 2
(1, 28, 28, 3)
[2]

end!

這裡就看到方法2的優勢了,都是3張圖片,按照每個batch 2張輸出,Dataset方法就可以把所有資料都輸出,最後一點殘餘都不剩下。

這些程式碼都比較簡單,就不進一步解讀了。需要說明的是:

1. 以上程式碼均在Tensorflow 1.3版本下執行通過。據說在1.4版本以後,Dataset API被提高了層級,變為:tf.data.Dataset,在使用中還請注意。

2. 上面_parse_function()子函式中的tf.image.resize_images操作是必需的,必須把所有圖片修改為相同尺寸,否則在進行batch操作時會出錯。當然,要求是每個batch裡的資料維度必須相同,如果batch_size=1,那可以不進行圖片resize。

3. _parse_function()子函式其實是一個通用模組,在實際應用中根據需要完全可能是其它更復雜的資料讀取和處理過程。

通過以上程式碼可以看出,方法2 Dataset方法起碼具有如下兩個優勢:

1. 不用進行復雜的流水線管理,沒有coord/threads那些語句,應該是系統內自動管理了;

2. 可以把資料集完整的送出來,而不會留下因為資料條數無法對batch_size整除而剩下的尾巴資料(我不確定是否有什麼方法可以解決這個問題,如有人知道,還請指教。謝謝!)。

相關推薦

Tensorflow資料讀取方法應用對比注意事項

本文對比介紹了兩種Tensorflow針對大資料集的資料讀取方法,具體來說是:方法1:tf.train.slice_input_producer+流水線(這裡用這個API指代這一類,其實還有其他的API)。方法2:Dataset方法,據說是Tensorflow 1.3版本之後

Hbase表資料備份方法-匯入和匯出示例[未測試]

浪費了“黃金五年”的Java程式設計師,還有救嗎? >>>   

Ubuntu下protobuf的安裝編譯注意事項

Ubuntu下編譯protobuf 1.下載protobuf 下載地址:https://github.com/google/protobuf/releases 2.編譯protobuf 解壓下載的zip包,cd到protobuf的目錄下,執行以下指令: tar -xvf pro

tokuDB儲存引擎的安裝配置使用注意事項

平臺支援: TokuDB 只能在 64-bit Linux 下使用. 最小記憶體要求: TokuDB 需要至少1GB 的實體記憶體,當然這裡建議最好是 2GB . 1、安裝mysql(可以是mariadb、percona或oracle mysql) 安裝p

二叉樹常用操作演算法集解釋注意事項

二叉樹是一種常用的資料結構,在程式中也經常需要使用二叉樹,但是你所使用語言卻並不一定提供了二叉樹這種資料型別,所以為了方便使用,我們可以自己實現一個二叉樹的資料型別。在需要時就像使用其他已定義的型別一樣方便。 下面給出一些本人寫的演算法和解釋(基於C語言),希望對讀者寫一

MySql指令用法注意事項

MySql指令 查看錶結構 desc table_name; 修改某一列的資料型別 假設表stu有一列id原本的資料型別是int(11)將其型別改為varchar(20)的語句如下:

android應用 獲取本地指定型別檔案 的最優方法

分享一下我老師大神的人工智慧教程!零基礎,通俗易懂!http://blog.csdn.net/jiangjunshow 也歡迎大家轉載本篇文章。分享知識,造福人民,實現我們中華民族偉大復興!        

TensorFlow 資料讀取方法總結

作者:黑暗星球 原文地址:https://blog.csdn.net/u014061630/article/details/80712635 ====================下一篇:tf.data 官方教程==================== ==============

前端生成與讀取Excel的框架使用方法

xlsx 安裝依賴: 1. <script lang="javascript" src="dist/xlsx.full.min.js"></script> // 此檔案可以在上

資料結構——圖的遍歷方法

遍歷定義:從已給的圖中某一頂點出發,沿著一些邊,訪遍圖中所有的頂點,且使每個頂點僅被訪問一次,就叫做圖的遍歷。 遍歷實質:找每個頂點的鄰接點的過程。 圖的特點:圖中可能存在迴路,且圖的任一頂點都可能與其它頂點相通,在訪問完某個頂點之後可能會沿著某些邊又回到了曾經    &

android 分頁資料 LinkedHashSetArrayList去重方法

第一種(mDatasTemp為當前頁資料,deals為總資料) if (mDatasTemp != null && mDatasTemp.size() > 0) { boolean isEqual = false; for (UserD

springBoot使用Controller讀取配置檔案方式&讀取自定義配置方法

Controller 核心配置檔案 application.propertie web.msg=Hello! This is Controller demo; Controller:

python的影象讀取和處理方法

 1 skimage庫      from skimage import data_dir,io, data   img = data.camera()   #讀取data中預先儲存的影象camera                         #Anaconda 版本參

TensorFlow資料讀取方法

轉自:http://honggang.io/2016/08/19/tensorflow-data-reading/ 引言 Tensorflow的資料讀取有三種方式: Preloaded data: 預載入資料Feeding: Python產生資料,再把資料餵給後端

GET和POST基本請求方法的區別

完整性 總結 優化 返回 們的 傳遞參數 src 本質 大成 GET和POST是HTTP請求的兩種基本方法,要說它們的區別,接觸過WEB開發的人都能說出一二。 最直觀的區別就是GET把參數包含在URL中,POST通過request body傳遞參數。 你可能自己寫過

Apache 工作模式 :prefork worker

文章 serve 共享 生產 servers sta 能力 一個 請求 前言 1·最近這幾篇文章都在講Apache的一些安全與優化,這些針對服務器都是很重要的,掌握這些不僅提升了服務器的安全,還讓服務器的性能大大的提高。這樣就可以讓客戶有很好的體驗感。2·今天這篇文章依

map的遍歷方法

str string get set entryset map.entry map val entry Set<String> keySet=maps.keySet();for(String string:keySet){ maps.get(string)} f

angualr routerLink 傳參方法引數的使用

1.路徑:http://localhost:8080/#/product?id=1 <a [routerLink]="['/product']" [queryParams]="{id:1}">詳情</a> ts獲取查詢引數: import { Component,

Vue router-link 傳參方法引數的使用

1.路徑:http://localhost:8080/#/detail?detail_id=1 <router-link :to="{path:'/detail',query: {detail_id: id}}">跳轉</router-link>  (id是引數

numpy transpose 和 swapaxes轉置方法

轉自:https://blog.csdn.net/caca95/article/details/81428673 transpose: 已3維舉例, 0 代表 x 軸,1 代表 y 軸,2 代表 z 軸, transpose(0,1,2)  可以理解維 tanspose(