1. 程式人生 > >生成tfrecords格式數據和使用dataset API使用tfrecords數據

生成tfrecords格式數據和使用dataset API使用tfrecords數據

fixed 三種 res write pan set writer cor 合成

TFRecords是TensorFlow中的設計的一種內置的文件格式,它是一種二進制文件,優點有如下幾種:

  • 統一不同輸入文件的框架
  • 它是更好的利用內存,更方便復制和移動(TFRecord壓縮的二進制文件, protocal buffer序列化)
  • 是用於將二進制數據和標簽(訓練的類別標簽)數據存儲在同一個文件中

一、將其他數據存儲為TFRecords文件的時候,需要經過兩個步驟:

建立TFRecord存儲器

  在tensorflow中使用下面語句來簡歷tfrecord存儲器:

tf.python_io.TFRecordWriter(path)

path : 創建的TFRecords文件的路徑

方法:

  • write(record):向文件中寫入一個字符串記錄(即一個樣本)
  • close() : 在寫入所有文件後,關閉文件寫入器。

註:此處的字符串為一個序列化的Example,通過Example.SerializeToString()來實現,它的作用是將Example中的map壓縮為二進制,節約大量空間。

構造每個樣本的Example模塊

Example模塊的定義如下:

message Example {
  Features features = 1;
};

message Features {
  map<string, Feature> feature = 1;
};

message Feature {
  oneof kind {
    BytesList bytes_list 
= 1; FloatList float_list = 2; Int64List int64_list = 3; } };

可以看到,Example中可以包括三種格式的數據:tf.int64,tf.float32和二進制類型。

features是以鍵值對的形式保存的。示例代碼如下:

example = tf.train.Example(
            features=tf.train.Features(feature={
                "label": tf.train.Feature(float_list=tf.train.FloatList(value=[string[1]])),
                
img_raw: tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), x1_offset:tf.train.Feature(float_list=tf.train.FloatList(value=[string[2]])), y1_offset: tf.train.Feature(float_list=tf.train.FloatList(value=[string[3]])), x2_offset: tf.train.Feature(float_list=tf.train.FloatList(value=[string[4]])), y2_offset: tf.train.Feature(float_list=tf.train.FloatList(value=[string[5]])), beta_det:tf.train.Feature(float_list=tf.train.FloatList(value=[string[6]])), beta_bbox:tf.train.Feature(float_list=tf.train.FloatList(value=[string[7]])) }))

構造好了Example模塊後,我們就可以將樣本寫入文件了:

writer.write(example.SerializeToString())

文件全部寫入後不要忘記關閉文件寫入器。

二、創建好我們自己的tfrecords文件後,我們就可以在訓練的時候使用它啦。tensorflow為我們提供了Dataset這個API以方便地使用tfrecords文件。

首先,我們要定義一個解析tfrecords的函數,它用來將二進制文件解析為張量。示例代碼如下:

def pares_tf(example_proto):
    #定義解析的字典
    dics = {
        label: tf.FixedLenFeature([], tf.float32),
        img_raw: tf.FixedLenFeature([], tf.string),
        x1_offset: tf.FixedLenFeature([], tf.float32),
        y1_offset: tf.FixedLenFeature([], tf.float32),
        x2_offset: tf.FixedLenFeature([], tf.float32),
        y2_offset: tf.FixedLenFeature([], tf.float32),
        beta_det: tf.FixedLenFeature([], tf.float32),
        beta_bbox: tf.FixedLenFeature([], tf.float32)}
    #調用接口解析一行樣本
    parsed_example = tf.parse_single_example(serialized=example_proto,features=dics)
    image = tf.decode_raw(parsed_example[img_raw],out_type=tf.uint8)
    image = tf.reshape(image,shape=[12,12,3])
    #這裏對圖像數據做歸一化
    image = (tf.cast(image,tf.float32)/255.0)
    label = parsed_example[label]
    label=tf.reshape(label,shape=[1])
    label = tf.cast(label,tf.float32)
    x1_offset=parsed_example[x1_offset]
    x1_offset = tf.reshape(x1_offset, shape=[1])
    y1_offset=parsed_example[y1_offset]
    y1_offset = tf.reshape(y1_offset, shape=[1])
    x2_offset=parsed_example[x2_offset]
    x2_offset = tf.reshape(x2_offset, shape=[1])
    y2_offset=parsed_example[y2_offset]
    y2_offset = tf.reshape(y2_offset, shape=[1])
    beta_det=parsed_example[beta_det]
    beta_det=tf.reshape(beta_det,shape=[1])
    beta_bbox=parsed_example[beta_bbox]
    beta_bbox=tf.reshape(beta_bbox,shape=[1])

    return image,label,x1_offset,y1_offset,x2_offset,y2_offset,beta_det,beta_bbox

接下來,我們需要使用tf.data.TFRecordDataset(filenames)讀入tfrecords文件。

一個Dataset通過Transformation變成一個新的Dataset。通常我們可以通過Transformation完成數據變換,打亂,組成batch,生成epoch等一系列操作。

常用的Transformation有:map、batch、shuffle、repeat。

map:

  map接收一個函數,Dataset中的每個元素都會被當作這個函數的輸入,並將函數返回值作為新的Dataset

batch:

  batch就是將多個元素組合成batch

repeat:

  repeat的功能就是將整個序列重復多次,主要用來處理機器學習中的epoch,假設原先的數據是一個epoch,使用repeat(5)就可以將之變成5個epoch

shuffle:

  shuffle的功能為打亂dataset中的元素,它有一個參數buffersize,表示打亂時使用的大小。

示例代碼:

dataset = tf.data.TFRecordDataset(filenames=[filename])
dataset = dataset.map(pares_tf)
dataset = dataset.batch(16).repeat(1)#整個序列只使用一次,每次使用16個樣本組成一個批次

現在這一個批次的樣本做好了,如何將它取出以用於訓練呢?答案是使用叠代器,在tensorflow中的語句如下:

iterator = dataset.make_one_shot_iterator()

所謂one_shot意味著只能從頭到尾讀取一次,那如何在每一個訓練輪次中取出不同的樣本呢?iterator的get_netxt()方法可以實現這一點。需要註意的是,這裏使用get_next()得到的只是一個tensor,並不是一個具體的值,在訓練的時候要使用這個值的話,我們需要在session裏面來取得。

使用dataset讀取tfrecords文件的完整代碼如下:

def pares_tf(example_proto):
    #定義解析的字典
    dics = {
        label: tf.FixedLenFeature([], tf.float32),
        img_raw: tf.FixedLenFeature([], tf.string),
        x1_offset: tf.FixedLenFeature([], tf.float32),
        y1_offset: tf.FixedLenFeature([], tf.float32),
        x2_offset: tf.FixedLenFeature([], tf.float32),
        y2_offset: tf.FixedLenFeature([], tf.float32),
        beta_det: tf.FixedLenFeature([], tf.float32),
        beta_bbox: tf.FixedLenFeature([], tf.float32)}
    #調用接口解析一行樣本
    parsed_example = tf.parse_single_example(serialized=example_proto,features=dics)
    image = tf.decode_raw(parsed_example[img_raw],out_type=tf.uint8)
    image = tf.reshape(image,shape=[12,12,3])
    #這裏對圖像數據做歸一化
    image = (tf.cast(image,tf.float32)/255.0)
    label = parsed_example[label]
    label=tf.reshape(label,shape=[1])
    label = tf.cast(label,tf.float32)
    x1_offset=parsed_example[x1_offset]
    x1_offset = tf.reshape(x1_offset, shape=[1])
    y1_offset=parsed_example[y1_offset]
    y1_offset = tf.reshape(y1_offset, shape=[1])
    x2_offset=parsed_example[x2_offset]
    x2_offset = tf.reshape(x2_offset, shape=[1])
    y2_offset=parsed_example[y2_offset]
    y2_offset = tf.reshape(y2_offset, shape=[1])
    beta_det=parsed_example[beta_det]
    beta_det=tf.reshape(beta_det,shape=[1])
    beta_bbox=parsed_example[beta_bbox]
    beta_bbox=tf.reshape(beta_bbox,shape=[1])

    return image,label,x1_offset,y1_offset,x2_offset,y2_offset,beta_det,beta_bbox

dataset = tf.data.TFRecordDataset(filenames=[filename])
dataset = dataset.map(pares_tf)
dataset = dataset.batch(16).repeat(1)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
        
    img, label, x1_offset, y1_offset, x2_offset, y2_offset, beta_det, beta_bbox = sess.run(fetches=next_element)

生成tfrecords格式數據和使用dataset API使用tfrecords數據