1. 程式人生 > >TensorFlow之TFRecords檔案的儲存與讀取講解及程式碼實現

TensorFlow之TFRecords檔案的儲存與讀取講解及程式碼實現

先聊一下tfrecord, 這是一種將影象資料和標籤放在一起的二進位制檔案,能更好的利用記憶體,TFRecords是TensorFlow中的設計的一種內建的檔案格式,優點有如下幾種:

  • 統一不同輸入檔案的框架
  • 它是更好的利用記憶體,更方便複製和移動(TFRecord壓縮的二進位制檔案, protocal buffer序列化)
  • 是用於將二進位制資料和標籤(訓練的類別標籤)資料儲存在同一個檔案中

一、TFRecords儲存

在將其他資料儲存為TFRecords檔案的時候,需要經過兩個步驟:

  • 建立TFRecord儲存器
  • 構造每個樣本的Example模組

1、建立TFRecord儲存器

tf.python_io.TFRecordWriter(path)

  • 寫入tfrecords檔案
  • path : TFRecords檔案的路徑
  • return : 寫檔案
  • 方法: 
    • write(record):向檔案中寫入一個字串記錄(即一個樣本)
    • close() : 關閉檔案寫入器

注:此處的字串為一個序列化的Example,通過Example.SerializeToString()來實現,它的作用是將Example中的map壓縮為二進位制,節約大量空間。

2、構造每個樣本的Example協議塊

message Example {
  Features features = 1;
};

message Features {
  map<string, Feature>
feature = 1; }; message Feature { oneof kind { BytesList bytes_list = 1; FloatList float_list = 2; Int64List int64_list = 3; } };

上面這段程式碼即為Example協議塊的規則,詳解如下: 
(1)tf.train.Example(features = None)

  • 寫入tfrecords檔案
  • features : tf.train.Features型別的特徵例項
  • return : example協議格式塊

(2)tf.train.Features(feature = None)

  • 構造每個樣本的資訊鍵值對
  • feature : 字典資料,key為要儲存的名字,value為tf.train.Feature例項
  • return : Features型別

(3)tf.train.Feature(**options) 
options可以選擇如下三種格式資料:

  • bytes_list = tf.train.BytesList(value = [Bytes])
  • int64_list = tf.train.Int64List(value = [Value])
  • float_list = tf.trian.FloatList(value = [Value])

(4)將圖片資料轉化為TFRecords的例子: 
對每一個樣本,都做如下的處理:

example = tf.train.Example(feature = tf.train.Features(feature = {
                            "image":tf.train.Feature(bytes_list=tf.train.BytesList(value=[image(bytes)]))
                             "label":tf.train.Feature(int64_list=tf.train.Int64List(value=[label(int)]))
    }))
  • 二、TFRecords讀取方法

1.流程:

和檔案閱讀器的流程基本相同,只是中間多了一步解析過程

2.解析TFRecords的example協議記憶體塊:

(1)tf.parse_single_example(serialized,features=None,name= None

  • 解析一個單一的Example原型
  • serialized : 標量字串的Tensor,一個序列化的Example,檔案經過檔案閱讀器之後的value
  • features :字典資料,key為讀取的名字,value為FixedLenFeature
  • return : 一個鍵值對組成的字典,鍵為讀取的名字

(2)tf.FixedLenFeature(shape,dtype)

  • shape : 輸入資料的形狀,一般不指定,為空列表
  • dtype : 輸入資料型別,與儲存進檔案的型別要一致,型別只能是float32,int 64, string
  • return : Tensor (即使有零的部分也儲存)

(3)上面(1)中features中的value還可以為tf.VarLenFeature(),但是這種方式用的比較少,它返回的是SparseTensor資料,這是一種只儲存非零部分的資料格式,瞭解即可。

三、程式碼實現

1.將CSV檔案轉化為TFRecords檔案

import tensorflow as tf
import numpy as np
import pandas as pd

train_frame = pd.read_csv("train.csv")
print(train_frame.head())
train_labels_frame = train_frame.pop(item="label")
train_values = train_frame.values
train_labels = train_labels_frame.values
print("values shape: ", train_values.shape)
print("labels shape:", train_labels.shape)

writer = tf.python_io.TFRecordWriter("csv_train.tfrecords")

for i in range(train_values.shape[0]):
    image_raw = train_values[i].tostring()
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[train_labels[i]]))
            }
        )
    )
    writer.write(record=example.SerializeToString())

writer.close()

2.將圖片檔案轉化為TFRecords檔案

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
import tensorflow as tf
import pandas as pd

def get_label_from_filename(filename):
    return 1

filenames = tf.train.match_filenames_once('.\data\*.png')

writer = tf.python_io.TFRecordWriter('png_train.tfrecords')

for filename in filenames:
    img=mpimg.imread(filename)
    print("{} shape is {}".format(filename, img.shape))
    img_raw = img.tostring()
    label = get_label_from_filename(filename)
    example = tf.train.Example(
        features=tf.train.Features(
            feature={
                "image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_raw])),
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
            }
        )
    )
    writer.write(record=example.SerializeToString())

writer.close()

3.將二進位制檔案轉化為TFRecords檔案

"""
讀取二進位制檔案轉換成張量,寫進TFRecords,同時讀取TFRcords
"""

#命令列引數
FLAGS = tf.app.flags.FLAGS       #獲取值
tf.app.flags.DEFINE_string("tfrecord_dir","./tmp/cifar10.tfrecords","寫入圖片資料檔案的檔名")


#讀取二進位制轉換檔案
class CifarRead(object):
    """
    讀取二進位制檔案轉換成張量,寫進TFRecords,同時讀取TFRcords
    """
    def __init__(self,file_list):
        """
        初始化圖片引數
        :param file_list:圖片的路徑名稱列表
        """

        #檔案列表
        self.file_list = file_list

        #圖片大小,二進位制檔案位元組數
        self.height = 32
        self.width = 32
        self.channel = 3
        self.label_bytes = 1
        self.image_bytes = self.height * self.width * self.channel
        self.bytes = self.label_bytes + self.image_bytes


    def read_and_decode(self):
        """
        解析二進位制檔案到張量
        :return: 批處理的image,label張量
        """
        #1.構造檔案佇列
        file_queue = tf.train.string_input_producer(self.file_list)

        #2.閱讀器讀取內容
        reader = tf.FixedLengthRecordReader(self.bytes)

        key ,value = reader.read(file_queue)    #key為檔名,value為元組

        print(value)

        #3.進行解碼,處理格式
        label_image = tf.decode_raw(value,tf.uint8)
        print(label_image)

        #處理格式,image,label
        #進行切片處理,標籤值
        #tf.cast()函式是轉換資料格式,此處是將label二進位制資料轉換成int32格式
        label = tf.cast(tf.slice(label_image,[0],[self.label_bytes]),tf.int32)

        #處理圖片資料
        image = tf.slice(label_image,[self.label_bytes],[self.image_bytes])
        print(image)

        #處理圖片的形狀,提供給批處理
        #因為image的形狀已經固定,此處形狀用動態形狀來改變
        image_tensor = tf.reshape(image,[self.height,self.width,self.channel])
        print(image_tensor)

        #批處理圖片資料
        image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10)

        return image_batch,label_batch

    def write_to_tfrecords(self,image_batch,label_batch):
        """
        將檔案寫入到TFRecords檔案中
        :param image_batch:
        :param label_batch:
        :return:
        """

        #建立TFRecords檔案儲存器
        writer = tf.python_io.TFRecordWriter(FLAGS.tfrecord_dir)      #傳進去命令列引數

        #迴圈取出每個樣本的值,構造example協議塊
        for i in range(10):

            #取出圖片的值,  #寫進去的是值,而不是tensor型別,
            # 寫入example需要bytes檔案格式,將tensor轉化為bytes用tostring()來轉化
            image = image_batch[i].eval().tostring()

            #取出標籤值,寫入example中需要使用int形式,所以需要強制轉換int
            label = int(label_batch[i].eval()[0])

            #構造每個樣本的example協議塊
            example = tf.train.Example(features = tf.train.Features(feature = {
                "image":tf.train.Feature(bytes_list = tf.train.BytesList(value = [image])),
                "label":tf.train.Feature(int64_list = tf.train.Int64List(value = [label]))
            }))

            #寫進去序列化後的值
            writer.write(example.SerializeToString())     #此處其實是將其壓縮成一個二進位制資料

        writer.close()

        return None



    def read_from_tfrecords(self):
        """
        從TFRecords檔案當中讀取圖片資料(解析example)
        :param self:
        :return: image_batch,label_batch
        """

        #1.構造檔案佇列
        file_queue = tf.train.string_input_producer([FLAGS.tfrecord_dir])    #引數為檔名列表

        #2.構造閱讀器
        reader = tf.TFRecordReader()

        key,value = reader.read(file_queue)

        #3.解析協議塊,返回的值是字典
        feature = tf.parse_single_example(value,features={
            "image":tf.FixedLenFeature([],tf.string),
            "label":tf.FixedLenFeature([],tf.int64)
        })

        #feature["image"],feature["label"]
        #處理標籤資料    ,cast()只能在int和float之間進行轉換
        label = tf.cast(feature["label"],tf.int32)    #將資料型別int64 轉換為int32

        #處理圖片資料,由於是一個string,要進行解碼,  #將位元組轉換為數字向量表示,位元組為一字串型別的張量
        #如果之前用了tostring(),那麼必須要用decode_raw()轉換為最初的int型別
        # decode_raw()可以將資料從string,bytes轉換為int,float型別的
        image = tf.decode_raw(feature["image"],tf.uint8)

        #轉換圖片的形狀,此處需要用動態形狀進行轉換
        image_tensor = tf.reshape(image,[self.height,self.width,self.channel])

        #4.批處理
        image_batch,label_batch = tf.train.batch([image_tensor,label],batch_size=10,num_threads=1,capacity=10)

        return image_batch,label_batch


if __name__ == '__main__':

    # 找到檔案路徑,名字,構造路徑+檔名的列表,"A.csv"...
    # os.listdir() 方法用於返回指定的資料夾包含的檔案或資料夾的名字的列表
    filename = os.listdir('./data/cifar10/cifar-10-batches-bin/')

    #加上路徑
    file_list = [os.path.join('./data/cifar10/cifar-10-batches-bin/', file) for file in filename if file[-3:] == "bin"]

    #初始化引數
    cr = CifarRead(file_list)

    #讀取二進位制檔案
    # image_batch,label_batch = cr.read_and_decode()

    #從已經儲存的TFRecords檔案中解析出原始資料
    image_batch, label_batch = cr.read_from_tfrecords()

    with tf.Session() as sess:
        #執行緒協調器
        coord = tf.train.Coordinator()

        #開啟執行緒
        threads = tf.train.start_queue_runners(sess,coord=coord)

        print(sess.run([image_batch,label_batch]))

        # print("存進TFRecords檔案")
        # cr.write_to_tfrecords(image_batch,label_batch)
        # print("存進檔案完畢")

        #回收執行緒
        coord.request_stop()
        coord.join(threads)
  • 注:

上段程式碼分為兩個部分:

  • 第一部分是被註釋掉的幾行程式碼,表示的是將二進位制檔案轉化為張量,並經過Example協議儲存到TFRecords檔案當中;
  • 第二部分是從已經儲存好資料資訊的TFRecords檔案中,經過解析,轉化為最初的二進位制檔案。

參考地址:https://blog.csdn.net/chengshuhao1991/article/details/78656724

相關推薦

TensorFlowTFRecords檔案儲存讀取講解程式碼實現

先聊一下tfrecord, 這是一種將影象資料和標籤放在一起的二進位制檔案,能更好的利用記憶體,TFRecords是TensorFlow中的設計的一種內建的檔案格式,優點有如下幾種:統一不同輸入檔案的框架它是更好的利用記憶體,更方便複製和移動(TFRecord壓縮的二進位制檔

TensorFlow檔案儲存讀取——variables_to_restore函式

轉,原創詳見: http://blog.csdn.net/sinat_29957455/article/details/78508793variables_to_restore函式,是TensorFlow為滑動平均值提供。之前,也介紹過通過使用滑動平均值可以讓神經網路模型更加的健壯。我們也知道,其實在Tens

tensorflow 訓練模型的儲存 讀取儲存的模型進行測試

在實際中,通常需要將經過大量訓練的較好模型引數儲存起來,在實際應用以訓練好的模型進行預測。 TensorFlow中提供了模型儲存的模組 tensorflow.train.Saver() 1. 匯入tensorflow模組        

TensorFlowtfrecords檔案詳細教程

  歡迎關注“勇敢AI”公眾號,更多python學習、資料分析、機器學習、深度學習原創文章與大家分享,還有更多電子資源、教程、資料集下載。勇敢AI,一個專注於人工智慧AI的公眾號。 ==========================================

Tensorflow+MNIST+CNN+模型儲存讀取

# coding: utf-8 import tensorflow as tf import numpy as np from utils import * from tensorflow.examples.tutorials.mnist import inpu

sklearn、TensorFlow、keras模型儲存讀取

一、sklearn模型儲存與讀取 1、儲存 from sklearn.externals import joblib from sklearn import svm X = [[0, 0], [1, 1]] y = [0, 1] clf = svm.SVC(

csv檔案儲存讀取

csv是指逗號分隔值檔案格式,是以純文字形式來儲存表格資料。1.儲存為csv格式 把要儲存的表格中資料,按每一行存成一條記錄,然後將記錄新增到一個list裡面,寫入檔案的時候將每個記錄轉換成一串逗號分

json檔案儲存讀取

一、將資料儲存為.json檔案 1 model={} #資料 2 with open("./hmm.json",'w',encoding='utf-8') as json_file: 3 json.dump(model,json_file,ensure_a

演算法第二彈-線性規劃投資的收益和風險分析程式碼實現

嘻嘻,當你看到這篇文章的時候,是的,我又更新啦!初學小白的我對於如宇宙黑洞般一樣神祕的演算法,分享給大家一些小小的感受,第二彈! 演算法和數學息息相關,巧妙的邏輯往往建立在數學上,今天結合線性規劃,談談投資中收益與風險等因素的相互關係,為決策者在實際生活中提供參考依據,以及

javaee實戰專案--農貿產品開發(詳細講解程式碼實現

這個專案是我第一個與他人合作完成的專案,我用java寫後臺,綜合性較強,用到了hibernate,jsp,servlet,適合那些javaee已經入門,想要通過專案來提高自己的朋友。 這個專案原始碼免積分提供大家學習,還有一些模組的程式碼(如數字驗證碼),可以稍作修改就能

TensorFlow模型檔案儲存讀取

一、模型檔案的儲存 在訓練一個TensorFlow模型之後,我們可以將訓練好的模型儲存成檔案,這樣可以方便下一次對新的資料進行預測的時候直接載入訓練好的模型即可獲得結果,下面通過TensorFlow提供的tf.train.Saver函式,將一個模型儲存成檔案,一般習慣性的將TensorFlow的模

有關於tensorflow的.TFRecords 檔案怎麼樣來生成和讀取操作

下面將介紹如何生成和讀取tfrecords檔案: 首先介紹tfrecords檔案的生成,直接上程式碼: from random import shuffle import numpy as np import glob import tensorflow as tf

二進位制檔案操作-儲存讀取

目標:實現資料儲存為二進位制檔案,然後通過二進位制檔案解析資料。 目標分為三個階段:1、將資料結構轉化為二進位制(至於資料是怎樣讀取進來的這個就不說了,因為方式比較多,但是讀取進來一定都會以特定的資料結構形式來儲存)。2、載入二進位制文字。3、載入為對應的資料結構。 階段

[Deep Learning] TensorFlow模型、引數的儲存讀取

大部分情況,我們會把訓練的網路儲存下來,用於後面的使用。或者,在當前網路下對引數進行一定程度的微調。 儲存變數 import tensorflow as tf # The file path to save the data save_file = './mode

tensorflow中模型資料的儲存讀取

使用tensorflow過程中,訓練結束後我們需要用到模型檔案。有時候,我們可能也需要用到別人訓練好的模型,並在這個基礎上再次訓練。這時候我們需要掌握如何操作這些模型資料。看完本文,相信你一定會有收穫! 1 Tensorflow模型檔案 我們在checkpoint_dir

Android採用SharedPreferences方式進行檔案儲存讀取

工程目錄: 做軟體開發應該都知道,很多軟體會有配置檔案,裡面存放這程式運行當中的各個屬性值,由於其配置資訊並不多,如果採用資料庫來存放並不划算,因為資料庫連線跟操作等耗時大大影響了程式的效率,因此我們使用鍵值這種一一對應的關係來存放這些配置資訊。SharedPrefe

localStorage 儲存讀取

element ora onclick win nbsp n) cti item fun window.onload=function(){ /** *其實也沒那麽難, *localStorage.setItem(json) 向本地儲存值

tensorflowtf.nn.l2_normalizel2_loss的計算

1.tf.nn.l2_normalize   tf.nn.l2_normalize(x, dim, epsilon=1e-12, name=None)  上式:  x為輸入的向量;  dim為l2範化的維數,dim取值為0或0或1;  eps

預處理後資料的儲存讀取

在機器學習中,一般都需要先對資料進行資料預處理工作。模型一般需要反覆的調參,因此可能需要多次使用預處理之後的資料,但是反覆進行資料的預處理工作是多餘的,我們可以將其儲存下來。 #用pickle模組將處理好的資料儲存成pickle格式,方便以後呼叫,即建立一個checkpoint # 儲存資料方便呼叫

Android複習旅--檔案儲存

內部儲存 內部儲存是指將應用程式中的資料以檔案方式儲存到裝置的內部儲存空間中(該檔案位於 data/data// 目錄下)。 一般情況下應用儲存在記憶體下的資料其他應用是訪問不了的,當您希望確保使用者或其他應用均無法訪問您的檔案時,內部儲存是最佳選擇。使用者解除安裝