使用tensorflow訓練自己的資料集(一)——製作資料集
阿新 • • 發佈:2018-11-12
使用tensorflow訓練自己的資料集—製作資料集
想記錄一下自己製作訓練集並訓練的過程,希望踩過的坑能幫助後面入坑的人。
本次使用的訓練集是kaggle中經典的貓狗大戰資料集(提取碼:ufz5)。因為本人筆記本配置很差還不是N卡,所以把train的資料分成了訓練集和測試集,並沒有使用原資料集中的test。在tensorflow中使用TFRecord格式餵給神經網路,但是現在官方推薦使用tf.data
但這個API還沒看所以還是使用了TFRecord。
程式碼註釋還挺清楚就直接上程式碼了。
import os
import tensorflow as tf
from PIL import Image
# Root directory of the source images (one sub-folder per class).
cwd = 'C:/Users/Qigq/Desktop/P_Data/kaggle/train'
# Output paths for the generated TFRecord files.
train_record_path = "C:/Users/Qigq/Desktop/P_Data/kaggle/ouputdata/train.tfrecords"
test_record_path = "C:/Users/Qigq/Desktop/P_Data/kaggle/ouputdata/test.tfrecords"
# Class names, in label order: index 0 = 'cat', index 1 = 'dog'.
# NOTE: this was originally a set literal; set iteration order depends on
# string hashing (randomized per process), so the label assignment could
# flip between runs. A tuple makes the label mapping deterministic.
classes = ('cat', 'dog')
def _byteslist(value):
    """Wrap a bytes value in a TF Example bytes-list Feature."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
def _int64list(value):
    """Wrap an integer value in a TF Example int64-list Feature."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)
def create_train_record():
    """Write the training TFRecord from the first 70% of each class folder.

    For every class in `classes`, lists the class directory once (sorted, so
    the 70/30 split is reproducible and cannot overlap with the test split),
    resizes each image to 128x128, forces RGB, and serializes it together
    with its integer label into `train_record_path`.
    """
    writer = tf.python_io.TFRecordWriter(train_record_path)  # single writer for the file
    num = 1  # progress counter for console output
    for index, name in enumerate(classes):
        class_path = os.path.join(cwd, name)
        # List once and sort: os.listdir order is filesystem-dependent, and the
        # original code listed the directory separately here and in
        # create_test_record, which risks a nondeterministic / overlapping split.
        file_names = sorted(os.listdir(class_path))
        split = int(len(file_names) * 0.7)  # first 70% -> training set
        for img_name in file_names[:split]:
            img_path = os.path.join(class_path, img_name)
            img = Image.open(img_path)
            # Force 3 channels so the reader's reshape to [128, 128, 3] holds
            # even for grayscale or RGBA inputs.
            img = img.convert('RGB').resize((128, 128))
            img_raw = img.tobytes()  # raw pixel bytes for the Example proto
            example = tf.train.Example(  # wrap label + image into an Example
                features=tf.train.Features(feature={
                    "label": _int64list(index),     # label must be an int64 feature
                    'img_raw': _byteslist(img_raw)  # image must be a bytes feature
                }))
            writer.write(example.SerializeToString())
            print('Creating train record in ', num)
            num += 1
    writer.close()  # flush and close the TFRecord file
    print("Create train_record successful!")
def create_test_record():
    """Write the test TFRecord from the last 30% of each class folder.

    Mirrors `create_train_record`: lists each class directory once (sorted,
    so the split is reproducible and disjoint from the training 70%),
    resizes to 128x128, forces RGB, and serializes image + label into
    `test_record_path`.
    """
    writer = tf.python_io.TFRecordWriter(test_record_path)
    num = 1  # progress counter for console output
    for index, name in enumerate(classes):
        class_path = os.path.join(cwd, name)
        # Sorted listing keeps the 70/30 boundary identical to the one used
        # by create_train_record, so no file lands in both sets.
        file_names = sorted(os.listdir(class_path))
        split = int(len(file_names) * 0.7)
        for img_name in file_names[split:]:  # remaining 30% -> test set
            img_path = os.path.join(class_path, img_name)
            img = Image.open(img_path)
            # Force 3 channels so the reader's reshape to [128, 128, 3] holds.
            img = img.convert('RGB').resize((128, 128))
            img_raw = img.tobytes()  # raw pixel bytes for the Example proto
            example = tf.train.Example(
                features=tf.train.Features(feature={
                    "label": _int64list(index),
                    'img_raw': _byteslist(img_raw)
                }))
            writer.write(example.SerializeToString())
            print('Creating test record in ', num)
            num += 1
    writer.close()
    print("Create test_record successful!")
def read_record(filename):
    """Read one serialized example from a TFRecord file.

    Returns a (image, label) tensor pair: the image decoded to shape
    [128, 128, 3], cast to float32 and shifted to roughly [-0.5, 0.5];
    the label cast to int32.
    """
    queue = tf.train.string_input_producer([filename])  # filename queue
    _, serialized = tf.TFRecordReader().read(queue)     # pull one record
    parsed = tf.parse_single_example(
        serialized,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'img_raw': tf.FixedLenFeature([], tf.string),
        },
    )
    image = tf.decode_raw(parsed['img_raw'], tf.uint8)
    image = tf.reshape(image, [128, 128, 3])
    # Normalize: scale pixels to [0, 1] then center around zero.
    image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(parsed['label'], tf.int32)
    return image, label
def get_batch_record(filename, batch_size):
    """Build shuffled (image, label) batch tensors from a TFRecord file."""
    img, lbl = read_record(filename)
    # Randomly sample batch_size (image, label) pairs from the shuffle queue.
    img_batch, lbl_batch = tf.train.shuffle_batch(
        [img, lbl],
        batch_size=batch_size,
        capacity=2000,           # total queue capacity
        min_after_dequeue=1000,  # minimum kept in queue for good mixing
    )
    return img_batch, lbl_batch
def main():
    """Build the train and test TFRecord files in sequence."""
    for build in (create_train_record, create_test_record):
        build()


if __name__ == '__main__':
    main()
### 呼叫示例 ###
# create_train_record(cwd,classes)
# create_test_record(cwd,classes)
# image_batch,label_batch = get_batch_record(filename,32)
# init = tf.initialize_all_variables()
#
# with tf.Session() as sess:
# sess.run(init)
#
# coord = tf.train.Coordinator()
# threads = tf.train.start_queue_runners(sess=sess,coord=coord)
#
# for i in range(1):
# image,label = sess.run([image_batch,label_batch])
# print(image.shape,1)
#
#
# coord.request_stop()
# coord.join(threads)
下一篇將介紹定義神經網路
如有錯誤望多多指教~~