1. 程式人生 > >使用tensorflow訓練自己的資料集(一)——製作資料集

使用tensorflow訓練自己的資料集(一)——製作資料集

使用tensorflow訓練自己的資料集—製作資料集

想記錄一下自己製作訓練集並訓練的過、希望踩過的坑能幫助後面入坑的人。
本次使用的訓練集的是kaggle中經典的貓狗大戰資料集(提取碼:ufz5)。因為本人筆記本配置很差還不是N卡所以把train的資料分成了訓練集和測試集並沒有使用原資料集中的test。在tensorflow中使用TFRecord格式餵給神經網路但是現在官方推薦使用tf.data但這個API還沒看所以還是使用了TFRecord。
檔案目錄
程式碼註釋還挺清楚就直接上程式碼了。

import os
import tensorflow as tf
from PIL import
Image # 源資料地址 cwd = 'C:/Users/Qigq/Desktop/P_Data/kaggle/train' # 生成record路徑及檔名 train_record_path = "C:/Users/Qigq/Desktop/P_Data/kaggle/ouputdata/train.tfrecords" test_record_path = "C:/Users/Qigq/Desktop/P_Data/kaggle/ouputdata/test.tfrecords" # 分類 classes = {'cat','dog'} def _byteslist(value): """二進位制屬性""" return
tf.train.Feature(bytes_list = tf.train.BytesList(value = [value])) def _int64list(value): """整數屬性""" return tf.train.Feature(int64_list = tf.train.Int64List(value = [value])) def create_train_record(): """建立訓練集tfrecord""" writer = tf.python_io.TFRecordWriter(train_record_path) # 建立一個writer
NUM = 1 # 顯示建立過程(計數) for index, name in enumerate(classes): class_path = cwd + "/" + name + '/' l = int(len(os.listdir(class_path)) * 0.7) # 取前70%建立訓練集 for img_name in os.listdir(class_path)[:l]: img_path = class_path + img_name img = Image.open(img_path) img = img.resize((128, 128)) # resize圖片大小 img_raw = img.tobytes() # 將圖片轉化為原生bytes example = tf.train.Example( # 封裝到Example中 features=tf.train.Features(feature={ "label":_int64list(index), # label必須為整數型別屬性 'img_raw':_byteslist(img_raw) # 圖片必須為二進位制屬性 })) writer.write(example.SerializeToString()) print('Creating train record in ',NUM) NUM += 1 writer.close() # 關閉writer print("Create train_record successful!") def create_test_record(): """建立測試tfrecord""" writer = tf.python_io.TFRecordWriter(test_record_path) NUM = 1 for index, name in enumerate(classes): class_path = cwd + '/' + name + '/' l = int(len(os.listdir(class_path)) * 0.7) for img_name in os.listdir(class_path)[l:]: # 剩餘30%作為測試集 img_path = class_path + img_name img = Image.open(img_path) img = img.resize((128, 128)) img_raw = img.tobytes() # 將圖片轉化為原生bytes # print(index,img_raw) example = tf.train.Example( features=tf.train.Features(feature={ "label":_int64list(index), 'img_raw':_byteslist(img_raw) })) writer.write(example.SerializeToString()) print('Creating test record in ',NUM) NUM += 1 writer.close() print("Create test_record successful!") def read_record(filename): """讀取tfrecord""" filename_queue = tf.train.string_input_producer([filename]) # 建立檔案佇列 reader = tf.TFRecordReader() # 建立reader _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example( serialized_example, features={ 'label': tf.FixedLenFeature([], tf.int64), 'img_raw': tf.FixedLenFeature([], tf.string) } ) label = features['label'] img = features['img_raw'] img = tf.decode_raw(img, tf.uint8) img = tf.reshape(img, [128, 128, 3]) img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 # 歸一化 label = tf.cast(label, tf.int32) return img, label def get_batch_record(filename,batch_size): """獲取batch""" image,label = read_record(filename) image_batch,label_batch = tf.train.shuffle_batch([image,label], # 隨機抽取batch size個image、label batch_size=batch_size, capacity=2000, min_after_dequeue=1000) return image_batch,label_batch def main(): create_train_record() create_test_record() if __name__ == '__main__': main() ### 呼叫示例 ### # create_train_record(cwd,classes) # create_test_record(cwd,classes) # image_batch,label_batch = get_batch_record(filename,32) # init = tf.initialize_all_variables() # # with tf.Session() as sess: # sess.run(init) # # coord = tf.train.Coordinator() # threads = tf.train.start_queue_runners(sess=sess,coord=coord) # # for i in range(1): # image,label = sess.run([image_batch,label_batch]) # print(image.shape,1) # # # coord.request_stop() # coord.join(threads)

下一篇將介紹定義神經網路
如有錯誤望多多指教~~