1. 程式人生 > >TensorFlow數據集(一)——數據集的基本使用方法

TensorFlow數據集(一)——數據集的基本使用方法

cor mark get-data 註意 輸入輸出 create amp lose urn

參考書

《TensorFlow:實戰Google深度學習框架》(第2版)

例子:從一個張量創建一個數據集,遍歷這個數據集,並對每個輸入輸出y = x^2 的值。

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# coding=utf-8 

"""
@author: Li Tian
@contact: [email protected]
@software: pycharm
@file: dataset_test1.py
@time: 2019/2/10 10:52
@desc: 例子:從一個張量創建一個數據集,遍歷這個數據集,並對每個輸入輸出y = x^2 的值。
""" import tensorflow as tf # 從一個數組創建數據集。 input_data = [1, 2, 3, 5, 8] dataset = tf.data.Dataset.from_tensor_slices(input_data) # 定義一個叠代器用於遍歷數據集。因為上面定義的數據集沒有用placeholder作為輸入參數 # 所以這裏可以使用最簡單的one_shot_iterator iterator = dataset.make_one_shot_iterator() # get_next() 返回代表一個輸入數據的張量,類似於隊列的dequeue()。 x = iterator.get_next() y
= x * x with tf.Session() as sess: for i in range(len(input_data)): print(sess.run(y))

運行結果:

技術分享圖片技術分享圖片?

數據是文本文件:創建數據集。

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# coding=utf-8 

"""
@author: Li Tian
@contact: [email protected]
@software: pycharm
@file: dataset_test2.py
@time: 2019/2/10 11:03
@desc: 數據是文本文件
""" import tensorflow as tf # 從文本文件創建數據集。假定每行文字是一個訓練例子。註意這裏可以提供多個文件。 input_files = [./input_file11, ./input_file22] dataset = tf.data.TextLineDataset(input_files) # 定義叠代器用於遍歷數據集 iterator = dataset.make_one_shot_iterator() # 這裏get_next()返回一個字符串類型的張量,代表文件中的一行。 x = iterator.get_next() with tf.Session() as sess: for i in range(4): print(sess.run(x))

運行結果:

技術分享圖片技術分享圖片?

數據是TFRecord文件:創建TFRecord測試文件。

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# coding=utf-8 

"""
@author: Li Tian
@contact: [email protected]
@software: pycharm
@file: dataset_createdata.py
@time: 2019/2/10 13:59
@desc: 創建樣例文件
"""

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import time


# 生成整數型的屬性。
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


# 生成字符串型的屬性。
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


a = [11, 21, 31, 41, 51]
b = [22, 33, 44, 55, 66]


# 輸出TFRecord文件的地址
filename = ./input_file2
# 創建一個writer來寫TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
for index in range(len(a)):
    aa = a[index]
    bb = b[index]
    # 將一個樣例轉化為Example Protocol Buffer,並將所有的信息寫入這個數據結構。
    example = tf.train.Example(features=tf.train.Features(feature={
        feat1: _int64_feature(aa),
        feat2: _int64_feature(bb)
    }))

    # 將一個Example寫入TFRecord文件中。
    writer.write(example.SerializeToString())
writer.close()

運行結果:

技術分享圖片技術分享圖片?

數據是TFRecord文件:創建數據集。(使用最簡單的one_hot_iterator來遍歷數據集)

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# coding=utf-8 

"""
@author: Li Tian
@contact: [email protected]
@software: pycharm
@file: dataset_test3.py
@time: 2019/2/10 13:16
@desc: 數據是TFRecord文件
"""

import tensorflow as tf


# 解析一個TFRecord的方法。record是從文件中讀取的一個樣例。前面介紹了如何解析TFRecord樣例。
def parser(record):
    # 解析讀入的一個樣例
    features = tf.parse_single_example(
        record,
        features={
            feat1: tf.FixedLenFeature([], tf.int64),
            feat2: tf.FixedLenFeature([], tf.int64),
        }
    )
    return features[feat1], features[feat2]


# 從TFRecord文件創建數據集。
input_files = [./input_file1, ./input_file2]
dataset = tf.data.TFRecordDataset(input_files)

# map()函數表示對數據集中的每一條數據進行調用相應方法。使用TFRecordDataset讀出的是二進制的數據。
# 這裏需要通過map()函數來調用parser()對二進制數據進行解析。類似的,map()函數也可以用來完成其他的數據預處理工作。
dataset = dataset.map(parser)

# 定義遍歷數據集的叠代器
iterator = dataset.make_one_shot_iterator()

# feat1, feat2是parser()返回的一維int64型張量,可以作為輸入用於進一步的計算。
feat1, feat2 = iterator.get_next()

with tf.Session() as sess:
    for i in range(10):
        f1, f2 = sess.run([feat1, feat2])
        print(f1, f2)

運行結果:

技術分享圖片技術分享圖片?

數據是TFRecord文件:創建數據集。(使用placeholder和initializable_iterator來動態初始化數據集)

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# coding=utf-8 

"""
@author: Li Tian
@contact: [email protected]
@software: pycharm
@file: dataset_test4.py
@time: 2019/2/10 13:44
@desc: 用initializable_iterator來動態初始化數據集的例子
"""

import tensorflow as tf
from figuredata_deal.dataset_test3 import parser


# 解析一個TFRecord的方法。與上面的例子相同不再重復。
# 從TFRecord文件創建數據集,具體文件路徑是一個placeholder,稍後再提供具體路徑。
input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(parser)

# 定義遍歷dataset的initializable_iterator
iterator = dataset.make_initializable_iterator()
feat1, feat2 = iterator.get_next()

with tf.Session() as sess:
    # 首先初始化iterator,並給出input_files的值。
    sess.run(iterator.initializer, feed_dict={input_files: [./input_file1, ./input_file2]})

    # 遍歷所有數據一個epoch,當遍歷結束時,程序會拋出OutOfRangeError
    while True:
        try:
            sess.run([feat1, feat2])
        except tf.errors.OutOfRangeError:
            break

運行結果:

技術分享圖片技術分享圖片?

TensorFlow數據集(一)——數據集的基本使用方法