1. 程式人生 > 使用tfrecord建立自己的數據集

使用tfrecord建立自己的數據集

解碼 res bytes slist 關於 error font 需要 orm

注意事項：

1.關於輸入圖像格式的問題

使用 io.imread() 讀圖時，要根據輸入圖像確定 as_grey 參數的值（新版 scikit-image 已將其改名為 as_gray）。把圖像轉換為位元組字串（image.tostring()，新版 NumPy 中為 image.tobytes()）之後，最好輸出 image_raw 的長度檢查一下，因為不同格式的圖像讀入後的數據類型不同，存儲所佔的位元組數也不同。

我讀入的是 JPEG 格式的灰度圖，數據類型是 int64，image_raw 的長度是圖像元素個數的 8 倍（注意：skimage 以 as_grey=True 把彩色圖轉成灰度時，得到的實際上是 float64，同樣每個元素佔 8 個位元組）。如果是 RGB 圖像，類型則通常是 uint8。確定了類型之後，在解碼（decode_raw）時需要把輸出類型設置成與存儲時相同的類型。

根據 image_raw 的長度與原圖像元素個數（高 × 寬 × 通道數）的比值推算所用的類型：比值為 1 對應 uint8，為 4 對應 int32（或 float32），為 8 對應 int64（或 float64）。

2. 轉換成 tfrecords 需要一段時間，請耐心等待。

import os
import tensorflow as tf
import numpy as np
import skimage.io as io
import matplotlib.pyplot as plt
import cv2
def get_data(file_path):
    """Collect image paths and class labels from a one-folder-per-class tree.

    Every immediate sub-directory of ``file_path`` is treated as one class and
    every entry inside it as an image of that class.  Class folders are visited
    in sorted (name) order and labelled 1, 2, 3, ... -- 1-based, so a dataset
    with a single class folder is labelled 1 throughout.  Plain files directly
    under ``file_path`` are ignored.  The (path, label) pairs are shuffled with
    ``np.random.shuffle`` before being returned.

    Args:
        file_path: root directory holding one sub-directory per class.

    Returns:
        A tuple ``(image_list, label_list)`` of parallel lists: image paths
        (str) and their integer labels.
    """
    class_dirs = sorted(
        name for name in os.listdir(file_path)
        if os.path.isdir(os.path.join(file_path, name))
    )
    samples = []
    # The label comes from the folder's position, so each class folder gets
    # its own label instead of every image sharing one constant label.
    for label, class_dir in enumerate(class_dirs, start=1):
        class_path = os.path.join(file_path, class_dir)
        samples.extend(
            (os.path.join(class_path, name), label)
            for name in sorted(os.listdir(class_path))
        )
    # Shuffle pairs (not two separate lists) so paths and labels stay aligned.
    np.random.shuffle(samples)
    image_list = [path for path, _ in samples]
    label_list = [label for _, label in samples]
    return image_list, label_list


def _int64_feature(value):
    """Wrap a single integer as a ``tf.train.Feature`` holding an Int64List."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    """Wrap a single bytes object as a ``tf.train.Feature`` (BytesList)."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def convert_tfrecord(images, labels, save_filename):
    """Write (image, label) pairs to a TFRecord file.

    Each image is read with ``skimage.io.imread`` and stored as the raw bytes
    of the resulting ndarray under the key ``'image_raw'``; its label is
    stored as an int64 under ``'label'``.  Neither the array's shape nor its
    dtype is recorded, so the reader must be told both (see
    ``read_and_decode``).  Images that raise ``IOError`` on reading are
    reported and skipped.

    Args:
        images: sequence of image file paths.
        labels: sequence of integer labels, same length as ``images``.
        save_filename: path of the TFRecord file to create.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length.
    """
    num_examples = len(labels)
    # Validate before opening the writer so a bad call leaves no file behind.
    if len(images) != num_examples:
        raise ValueError('Images size %d does not match label size %d.'
                         % (len(images), num_examples))
    print("Transform start....")
    # The context manager closes the writer even if an unexpected error escapes.
    with tf.python_io.TFRecordWriter(save_filename) as writer:
        for image_path, label in zip(images, labels):
            try:
                # Default imread options (no grayscale conversion).
                image = io.imread(image_path)
            except IOError as e:
                print('Could not read:', image_path)
                print('error :%s Skip it !\n' % e)
                continue
            # tobytes() dumps the array in its native dtype; the reader must
            # decode with that same dtype or the byte count will not match.
            image_raw = image.tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': _int64_feature(int(label)),
                'image_raw': _bytes_feature(image_raw),
            }))
            writer.write(example.SerializeToString())
    print("success!")


def read_and_decode(tfrecords_file, batch_size, image_shape=(200, 300, 3),
                    crop_size=(100, 100), dtype=None):
    """Build a TF1 queue-based input pipeline over a TFRecord file.

    Records are parsed with the keys written by ``convert_tfrecord``.  The raw
    image bytes carry no shape, so they are decoded as ``dtype`` and reshaped
    to ``image_shape``; every record must therefore have exactly that shape.
    Batches are then center-cropped/padded to ``crop_size`` and scaled by
    1/255.  The tensors are fed by TF1 queue runners, which the caller must
    start (as ``train`` does with ``tf.train.start_queue_runners``).

    Args:
        tfrecords_file: path of the TFRecord file to read.
        batch_size: number of examples per batch.
        image_shape: (height, width, channels) of each stored image.
        crop_size: (height, width) the batch is cropped or padded to.
        dtype: TF dtype the images were stored in; ``None`` means
            ``tf.uint8``.  The 1/255 scaling assumes 8-bit pixel values.

    Returns:
        ``(image_batch, label_batch)``: float32 images of shape
        ``(batch_size, crop_h, crop_w, channels)`` and int32 labels.
    """
    # Resolved here rather than as a default so ``tf`` is not touched at
    # definition time.
    if dtype is None:
        dtype = tf.uint8
    reader = tf.TFRecordReader()
    filename_queue = tf.train.string_input_producer([tfrecords_file])
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })
    capacity = 1000 + 3 * batch_size
    image = tf.decode_raw(features['image_raw'], dtype)
    label = tf.cast(features['label'], tf.int32)
    image = tf.reshape(image, list(image_shape))
    image_batch, label_batch = tf.train.batch(
        [image, label], batch_size=batch_size, capacity=capacity)
    image_batch = tf.image.resize_image_with_crop_or_pad(
        image_batch, crop_size[0], crop_size[1])
    image_batch = tf.cast(image_batch, tf.float32) * (1. / 255)
    return image_batch, label_batch


def plot_images(images, labels):
    """Plot one batch (at most 9 images, in a 3x3 grid) and print the labels.

    Args:
        images: array of images indexed along the first axis.
        labels: sequence of labels parallel to ``images``.
    """
    print(images.shape)
    # The grid has 9 cells, so never try to plot more than that.
    for i in range(min(len(images), 9)):
        plt.subplot(3, 3, i + 1)
        plt.axis('off')
        plt.subplots_adjust(top=1)
        print(labels[i])
        plt.imshow(images[i])
    plt.show()


def train(data_dir=r'E:\syn_data', tfrecords_file='1.tfrecords',
          batch_size=2, num_batches=3):
    """Convert ``data_dir`` to TFRecord, then read back and plot a few batches.

    Args:
        data_dir: root folder with one sub-directory per class.
        tfrecords_file: TFRecord file to write and then read.
        batch_size: examples per batch.
        num_batches: how many batches to fetch and plot.
    """
    image_list, label_list = get_data(data_dir)
    convert_tfrecord(image_list, label_list, tfrecords_file)
    x_batch, y_batch = read_and_decode(tfrecords_file, batch_size=batch_size)
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(num_batches):
                if coord.should_stop():
                    break
                images, labels = sess.run([x_batch, y_batch])
                plot_images(images, labels)
        except tf.errors.OutOfRangeError:
            print('done!')
        finally:
            # Always stop and join the queue-runner threads.
            coord.request_stop()
            coord.join(threads)

# Uncomment to run the whole pipeline:
# train()

使用tfrecord建立自己的數據集