程式人生 > 使用tensorflow訓練自己的資料集(一)

使用tensorflow訓練自己的資料集(一)

使用tensorflow訓練自己的資料集

想記錄一下自己製作訓練集並訓練的過程，希望踩過的坑能幫助後面入坑的人。本次使用的訓練集是kaggle中經典的貓狗大戰資料集(提取碼:ufz5)。因為本人筆記本配置很差，也不是N卡，所以把原資料集中train的資料按7:3分成了訓練集和測試集，並沒有使用原資料集中的test。檔案目錄：train下每個類別(cat、dog)各一個子資料夾。程式碼註釋還挺清楚，就直接上程式碼了。

import os
import tensorflow as tf
from PIL import Image
# Source data directory: one sub-directory per class name (see ``classes``).
cwd = 'C:/Users/Qigq/Desktop/P_Data/kaggle/train'
# Output paths of the generated TFRecord files.
train_record_path = "C:/Users/Qigq/Desktop/P_Data/kaggle/ouputdata/train.tfrecords"
test_record_path = "C:/Users/Qigq/Desktop/P_Data/kaggle/ouputdata/test.tfrecords"
# Class names in label order: a name's position is the integer label written
# to the records (cat -> 0, dog -> 1). This must be an ordered sequence, not a
# set -- set iteration order of strings changes between Python processes
# (hash randomization), which would silently swap the labels between runs.
classes = ('cat', 'dog')

def _byteslist(value):
    """Wrap a single bytes value as a ``tf.train.Feature`` with a BytesList."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)

def _int64list(value):
    """Wrap a single integer as a ``tf.train.Feature`` with an Int64List."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)

def _split_files(class_path, train_ratio, train):
    """Return the file names of one class directory that belong to a split.

    Names are sorted so the train/test boundary does not depend on the
    arbitrary order returned by ``os.listdir``; the first ``train_ratio``
    fraction is the training split and the remainder is the test split.
    """
    names = sorted(os.listdir(class_path))
    cut = int(len(names) * train_ratio)
    return names[:cut] if train else names[cut:]


def _write_record(record_path, src_dir, class_names, train_ratio, train):
    """Write one split (train or test) of the image folders to a TFRecord file.

    Each image is converted to RGB, resized to 128x128 and stored as raw
    bytes under ``img_raw``; its index in ``class_names`` is stored under
    ``label``.
    """
    kind = 'train' if train else 'test'
    num = 1  # running counter, used only for progress output
    # The writer is closed even if an image fails to load.
    with tf.python_io.TFRecordWriter(record_path) as writer:
        for index, name in enumerate(class_names):
            class_path = os.path.join(src_dir, name)
            for img_name in _split_files(class_path, train_ratio, train):
                with Image.open(os.path.join(class_path, img_name)) as img:
                    # read_record reshapes to [128, 128, 3], so force exactly
                    # 3 channels (grayscale/RGBA/palette images would not fit).
                    img_raw = img.convert('RGB').resize((128, 128)).tobytes()
                example = tf.train.Example(
                    features=tf.train.Features(feature={
                        "label": _int64list(index),     # label must be an int64 feature
                        'img_raw': _byteslist(img_raw)  # image must be a bytes feature
                    }))
                writer.write(example.SerializeToString())
                print('Creating %s record in ' % kind, num)
                num += 1
    print("Create %s_record successful!" % kind)


def create_train_record(src_dir=None, class_names=None, record_path=None,
                        train_ratio=0.7):
    """Create the training TFRecord from the first ``train_ratio`` of each class.

    Args:
        src_dir: directory holding one sub-directory per class; defaults to
            the module-level ``cwd``.
        class_names: ordered class names (position = label); defaults to
            the module-level ``classes``.
        record_path: output file; defaults to ``train_record_path``.
        train_ratio: fraction of each class (in sorted file-name order) used
            for training. Use the same value for ``create_test_record`` so
            the two splits are disjoint and complementary.
    """
    _write_record(
        train_record_path if record_path is None else record_path,
        cwd if src_dir is None else src_dir,
        classes if class_names is None else class_names,
        train_ratio,
        train=True,
    )


def create_test_record(src_dir=None, class_names=None, record_path=None,
                       train_ratio=0.7):
    """Create the test TFRecord from the remaining ``1 - train_ratio`` of each class.

    Args:
        src_dir: directory holding one sub-directory per class; defaults to
            the module-level ``cwd``.
        class_names: ordered class names (position = label); defaults to
            the module-level ``classes``.
        record_path: output file; defaults to ``test_record_path``.
        train_ratio: fraction reserved for training; must match the value
            passed to ``create_train_record``.
    """
    _write_record(
        test_record_path if record_path is None else record_path,
        cwd if src_dir is None else src_dir,
        classes if class_names is None else class_names,
        train_ratio,
        train=False,
    )

def read_record(filename):
    """Build graph ops that read one ``(image, label)`` example from a TFRecord.

    Uses the TF1 queue-based input pipeline (a filename queue feeding a
    ``TFRecordReader``), so queue runners must be started before evaluating
    the returned tensors. The raw image bytes are decoded as uint8, reshaped
    to 128x128x3 and scaled into [-0.5, 0.5]; the label is cast to int32.

    Args:
        filename: path of the ``.tfrecords`` file to read.

    Returns:
        ``(img, label)``: a float32 tensor of shape [128, 128, 3] and an
        int32 scalar tensor.
    """
    spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string)
    }
    queue = tf.train.string_input_producer([filename])
    _, serialized = tf.TFRecordReader().read(queue)
    parsed = tf.parse_single_example(serialized, features=spec)

    pixels = tf.reshape(tf.decode_raw(parsed['img_raw'], tf.uint8), [128, 128, 3])
    # Map uint8 pixels [0, 255] onto [-0.5, 0.5].
    img = tf.cast(pixels, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(parsed['label'], tf.int32)
    return img, label

def get_batch_record(filename, batch_size):
    """Return shuffled ``(image_batch, label_batch)`` tensors from a TFRecord file.

    Feeds the single-example tensors from ``read_record`` into
    ``tf.train.shuffle_batch`` with a queue capacity of 2000 examples, at
    least 1000 of which remain queued after each dequeue.
    """
    single = list(read_record(filename))
    batched = tf.train.shuffle_batch(
        single,
        batch_size=batch_size,
        capacity=2000,
        min_after_dequeue=1000,
    )
    return tuple(batched)

def main():
    """Build both TFRecord files: the training split first, then the test split."""
    for build in (create_train_record, create_test_record):
        build()


if __name__ == '__main__':
    main()

### Usage example ###
# create_train_record()
# create_test_record()
# image_batch,label_batch = get_batch_record(train_record_path,32)
# init = tf.global_variables_initializer()
#
# with tf.Session() as sess:
#     sess.run(init)
#
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(sess=sess,coord=coord)
#
#     for i in range(1):
#         image,label = sess.run([image_batch,label_batch])
#         print(image.shape,label)
#
#
#     coord.request_stop()
#     coord.join(threads)

如有錯誤望多多指教~~