1. 程式人生 > >Tensorflow 資料讀取 tf.data.Dataset API 相關介紹

Tensorflow 資料讀取 tf.data.Dataset API 相關介紹

介紹

tf.1.4及以後新出的tf.data.Dataset API 中,使用的資料讀取方式有點類似於pytorch中的Dataloader,大大簡化了資料讀取。下面是程式碼例項。

# coding=utf-8
import os
import numpy as np
import glob

import tensorflow as tf
import tensorflow.contrib.eager as tfe

"""資料讀取: Dataset API的介紹"""
"""
1. Dataset API 支援tensorflow新出的Eager模式
            Eager模式:迭代時可直接取值,而不是tensor。但在tf 1.4的標準版中,沒有eager模式,而是在nightly version
2. 通過Dataset類可以例項化出一個Iterator
3. Dataset 可以看成是相同型別元素的有序列表。這裡的元素可以是向量,字串,圖片,或者tuple,dict等
4. 從Dataset中取出元素:
            需要例項化一個Interator,然後對Iterator進行迭代
5. Dataset支援一類特殊的操作: Transformation. 一個Dataset通過Transformation變成一個新的Dataset。
    我們可以通過Transformation完成 資料變換, 打亂, 組成batch, 生成epoch 等操作
    常用的Transformation:
                (1) map
                (2) batch
                (3) shuffle
                (4) repeat
6. dataset的建立方法:
    (1) tf.data.Dataset.from_tensor_slices
    (2) tf.data.TextLineDataset(): 輸入是一個檔案列表,輸出是一個dataset。dataset中的每一個元素就對應了檔案中的一行。
                                    可以用這個函式來讀取csv檔案
    (3) tf.data.FixedLengthRecordDataset(): 通常用來讀取以二進位制形式儲存的檔案,如CIFAR10資料集
    (4) tf.data.TFRecordDataset(): 用來讀取tfrecord檔案,dataset中的每一個元素就是一個TFExample
"""


def eager_dataset():
    """
    以eager模式讀取資料集
    :return: 
    """
    dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
    iterator = tfe.Iterator(dataset)
    for one_element in iterator:
        print(one_element)


def non_eager_dataset():
    """
    以非eager的方式讀取資料集
    :return: 
    """
    # from_tensor_slices: 切分傳入Tensor的第一個維度,生成相應的dataset
    dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))

    """非eager模式"""
    # 建立一個iterator,且是一個one shot iterator,即只能從頭到尾讀取一次
    iterator = dataset.make_one_shot_iterator()
    # 非Eager模式:one_element是一個tensor,而不是個實際的值
    one_element = iterator.get_next()

    # with tf.Session() as sess:
    #     for i in range(5):
    #         # 如果一個dataset中的元素被讀取完了,再嘗試執行sess.run(one_element),會報tf.errors.OutOfRangeError的異常
    #         print(sess.run(one_element))

    with tf.Session() as sess:
        try:
            while True:
                print(sess.run(one_element))
        except tf.errors.OutOfRangeError:
            print('End')


def non_eager_dataset_v2():
    dataset = tf.data.Dataset.from_tensor_slices(np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
    iterator = dataset.make_one_shot_iterator()
    one_element = iterator.get_next()

    with tf.Session() as sess:
        try:
            while True:
                print(sess.run(one_element))
        except tf.errors.OutOfRangeError:
            print('End')


def non_eager_dataset_dict_classical():
    """
    經典的影象處理類問題中,image 和 label 的組織形式: 
                    {'image': image_tensor, 'label': label_tensor}
    :return: 
    """
    # from_tensor_slices 會分別切分'a','b'中的數值,最終dataset中的一個元素類似於{'a': 1.0, 'b': dog}的形式
    dataset = tf.data.Dataset.from_tensor_slices(
        {'a': np.array([1.0, 2.0, 3.0, 4.0, 5.0]), 'b': ['dog', 'cat', 'pig', 'monkey', 'bear']})
    iterator = dataset.make_one_shot_iterator()
    one_element = iterator.get_next()
    with tf.Session() as sess:
        try:
            while True:
                print(sess.run(one_element))
        except tf.errors.OutOfRangeError:
            print('End')


"""Transformation 相關操作"""
def map_fun():
    dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
    dataset = dataset.map(lambda x: x + 1)
    iterator = dataset.make_one_shot_iterator()
    one_element = iterator.get_next()

    with tf.Session() as sess:
        try:
            while True:
                print(sess.run(one_element))
        except tf.errors.OutOfRangeError:
            print('End')


def batch_fun():
    dataset = tf.data.Dataset.from_tensor_slices(np.array(range(32)))
    # 注: batch 也支援不整除的操作
    dataset = dataset.batch(5)
    dataset = dataset.shuffle(1000)
    iterator = dataset.make_one_shot_iterator()
    one_element = iterator.get_next()
    cnt = 0
    with tf.Session() as sess:
        try:
            while True:
                print('batch: {}, {}'.format(cnt, sess.run(one_element)))
                cnt += 1
        except tf.errors.OutOfRangeError:
            print('End')


def repeat_fun():
    dataset = tf.data.Dataset.from_tensor_slices(np.array(range(10)))
    dataset = dataset.shuffle(1000)
    # repeat 的功能就是將整個資料集重複多次,主要用來處理機器學習中的epoch.
    dataset = dataset.repeat(3)
    iterator = dataset.make_one_shot_iterator()
    one_element = iterator.get_next()
    with tf.Session() as sess:
        try:
            while True:
                print(sess.run(one_element))
        except tf.errors.OutOfRangeError:
            print('End')


"""一個經典的讀取image和label的列子"""
def parse_function(filename, label):
    image_string = tf.read_file(filename)
    # image_decoded = tf.image.decode_image(image_string, channels=3)
    image_decoded = tf.image.decode_jpeg(image_string)
    image_resized = tf.image.resize_images(image_decoded, size=(100, 100))

    return image_resized, label


def dataset_classical_example():
    batch_size = 4

    filenames_tmp = glob.glob(os.path.join('./data_samples', '*.{}'.format('jpg')))
    filenames = tf.constant(filenames_tmp)
    labels = tf.constant(range(len(filenames_tmp)))

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.map(parse_function)
    dataset = dataset.shuffle(buffer_size=1000).batch(batch_size).repeat(3)

    iterator = dataset.make_one_shot_iterator()
    one_batch = iterator.get_next()

    with tf.Session() as sess:
        try:
            while True:
                batch_images, batch_labels = sess.run(one_batch)
        except tf.errors.OutOfRangeError:
            print('End')


if __name__ == '__main__':
    # non_eager_dataset_dict_classical()
    # map_fun()
    # batch_fun()

    # repeat_fun()
    dataset_classical_example()




參考連結: