1. 程式人生 > >tensorflow學習(9):TFRecord介紹和樣例程式(附詳細解讀)

tensorflow學習(9):TFRecord介紹和樣例程式(附詳細解讀)

由於影象的亮度、對比度等屬性對影象的影響是非常大的,相同物體在不同亮度、對比度下差別非常大,然而在很多影象識別問題中,這些因素都不應該影響最後的識別結果。因此,本文將介紹如何對影象資料處理進行預處理使訓練得到的神經網路模型儘可能小的被無關因素影響。

由於來自實際問題的資料往往有很多格式和屬性,我們將使用TFRecord格式來統一不同的原始資料格式,並更加有效的管理不同的屬性。

一、TFRecord格式

TFRecord檔案中的資料都是通過tf.train.Example Protocol Buffer格式儲存的,以下程式碼給出了tf.train.Example的定義

message Example{
	Features features = 1
};

message Features{
	map<string, Feature> feature = 1
};

message Feature{
	oneof kind{
		ByteList bytes_list = 1;
		FloatList float_list = 1;
		Int64List int64_lsit = 1;
		}
};

從以上程式碼中可以看出tf.train.Example的資料結構是比較簡潔的。tf.train.Example中包含了一個從屬性名稱到取值的字典。其中屬性名稱為一個字串,屬性的取值可以是字串、實數列表、整數列表。比如將一張解碼前的影象儲存為一個字串,影象所對應的類別編號儲存為整數列表。

二、將資料寫入TFRecord

import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

#生成整數型的屬性
def _int64_feature(value):
    return tf.train.Feature(int64_list = tf.train.Int64List(value=[value]))

#生成字串型的屬性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value=[value]))
#path是儲存已經下載好的mnist資料集的路徑
mnist = input_data.read_data_sets("path",dtype=tf.uint8,one_hot=True)
images = mnist.train.images
#訓練資料所對應的正確答案,可以作為一個屬性儲存在TFRecord
labels = mnist.train.labels
#訓練資料的影象的解析度,可以作為Example中的一個屬性
pixels = images.shape[1]
num_examples = mnist.train.num_examples

#輸出TFRecord檔案的地址
filename = '/path/to/output.tfrecords'
#建立一個writer來寫TFRecord檔案
writer = tf.python_io.TFRecordWriter(filename)

for index in range(num_examples):
    #將影象矩陣轉化為一個字串
    image_raw = images[index].tostring()
    #將一個樣例轉化為example protocol buffer,並將所有資訊寫入這個資料結構
    example= tf.train.Example(features = tf.train.Features(feature={
            'pixels':_int64_feature(pixels),
            'label':_int64_feature(np.argmax(labels[index])),
            'image_raw':_bytes_feature(image_raw)
    }))
    #將一個example寫入TFRecord檔案
    writer.write(example.SerializeToString())
writer.close()

三、從TFRecord讀出資料

import tensorflow as tf

#建立一個reader來讀取TFRecord檔案中的樣例
reader = tf.TFRecordReader()
#建立一個佇列來維護檔案列表
filename_queue = tf.train.string_input_producer(['output.tfrecords'])#儲存路徑

#從一個檔案中讀出一個樣例,也可以使用read_up_to函式一次性讀取多個樣例
_, serialized_example = reader.read(filename_queue)
#解析讀入的一個樣例,如果需要解析多個樣例,可以用parse_example函式
features = tf.parse_single_example(serialized_example,
            features = {
                'image_raw':tf.FixedLenFeature([],tf.string),
                'pixels':tf.FixedLenFeature([],tf.int64),
                'label':tf.FixedLenFeature([],tf.int64)
            })

#tf.decode_raw可以將字串解析成影象對應的畫素陣列
image = tf.decode_raw(features['image_raw'],tf.uint8)
label = tf.cast(features['label'],tf.int32)
pixels = tf.cast(features['pixels'],tf.int32)

sess=tf.Session()
#啟動多執行緒處理資料
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess = sess,coord=coord)

#每次執行可以讀取TFRecord檔案中的一個樣例,當所有樣例都讀完之後,在此樣例中程式會再從頭讀取
for i in range(10):
    print(sess.run([image,label,pixels]))