TensorFlow——訓練自己的資料——CIFAR10(一)資料準備

TensorFlow——訓練自己的資料——CIFAR10(一)資料準備

Reading Data
所用函式

def read_cifar10(data_dir, is_train, batch_size, shuffle):
    Args:
        data_dir: the directory of CIFAR10
        is_train: boolean — True for training batches, False for the test batch
        batch_size: number of examples per batch
        shuffle: whether to shuffle the example order
    Returns:
        label: 2D one-hot tensor, [batch_size, 10], tf.float32
        image: 4D tensor, [batch_size, height, width, 3], tf.float32

變數宣告

# CIFAR-10 binary record layout constants.
img_width = 32   # image width in pixels
img_height = 32  # image height in pixels
img_depth = 3    # RGB channels
label_bytes = 1  # each record starts with a single label byte
image_bytes = img_width*img_height*img_depth #32x32x3 = 3072 image bytes per record

讀取資料

def read_cifar10(data_dir, is_train, batch_size, shuffle):
    """Read CIFAR-10 binary files and return batched image/label tensors.

    Args:
        data_dir: directory containing the CIFAR-10 .bin files.
        is_train: if True, read data_batch_1..5.bin; else test_batch.bin.
        batch_size: number of examples per batch.
        shuffle: whether to shuffle examples via a shuffle queue.

    Returns:
        images: 4D tensor [batch_size, img_height, img_width, 3], tf.float32,
            per-image standardized.
        label_batch: 2D one-hot tensor [batch_size, 10], tf.float32.
    """
    # Group the input ops under one name scope so TensorBoard stays tidy.
    with tf.name_scope('input'):
        # os.path.join selects the correct separator per OS ('/' or '\').
        # NOTE: the second argument must NOT start with '/', otherwise
        # os.path.join treats it as absolute and DISCARDS data_dir.
        if is_train:
            # BUG FIX: CIFAR-10 has five training batches; np.arange(1, 5)
            # only covered data_batch_1..4, silently dropping 20% of the
            # training data. Use (1, 6) to include data_batch_5.bin.
            filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % ii)
                         for ii in np.arange(1, 6)]
        else:
            filenames = [os.path.join(data_dir, 'test_batch.bin')]

        # The files are fixed-length binary records, so feed the filenames
        # through string_input_producer and read them with a
        # FixedLengthRecordReader. (A (label, path) list would instead use
        # slice_input_producer, as in the cats-vs-dogs example.)
        filename_queue = tf.train.string_input_producer(filenames)
        # label_bytes=1, image_bytes=32*32*3=3072 per record.
        reader = tf.FixedLengthRecordReader(label_bytes + image_bytes)
        key, value = reader.read(filename_queue)

        # Images are stored as raw uint8 bytes (not JPEG), so decode_raw
        # rather than tf.image.decode_jpeg.
        record_bytes = tf.decode_raw(value, tf.uint8)

        # First byte is the label; the remaining 3072 bytes are the image.
        label = tf.slice(record_bytes, [0], [label_bytes])
        label = tf.cast(label, tf.int32)
        image_raw = tf.slice(record_bytes, [label_bytes], [image_bytes])

        # Stored layout is [depth, height, width]; transpose to [h, w, d].
        image_raw = tf.reshape(image_raw, [img_depth, img_height, img_width])
        image = tf.transpose(image_raw, (1, 2, 0))  # D/H/W -> H/W/D
        image = tf.cast(image, tf.float32)

        # # Optional data augmentation (crop / flip / brightness / contrast);
        # # reportedly of limited benefit on this task.
        # image = tf.random_crop(image, [24, 24, 3])
        # image = tf.image.random_flip_left_right(image)
        # image = tf.image.random_brightness(image, max_delta=63)
        # image = tf.image.random_contrast(image, lower=0.2, upper=1.8)

        # Per-image standardization: subtract the mean, divide by the stddev
        # (maps roughly from [0, 255] into a zero-centered float range).
        image = tf.image.per_image_standardization(image)

        if shuffle:
            images, label_batch = tf.train.shuffle_batch(
                [image, label],
                batch_size=batch_size,
                num_threads=16,
                capacity=2000,            # queue capacity
                min_after_dequeue=1500)   # min examples left after a dequeue
        else:
            images, label_batch = tf.train.batch(
                [image, label],
                batch_size=batch_size,
                num_threads=16,
                capacity=2000)

        # return images, tf.reshape(label_batch, [batch_size])
        # ONE-HOT: class 0 becomes [1,0,0,0,0,0,0,0,0,0], etc.
        n_classes = 10
        label_batch = tf.one_hot(label_batch, depth=n_classes)
        return images, tf.reshape(label_batch, [batch_size, n_classes])

測試資料
把一個Batch顯示出來

import matplotlib.pyplot as plt

# Point this at your local copy of the CIFAR-10 binary data.
data_dir = 'D:/Study/Python/Projects/CIFAR10/data'
BATCH_SIZE = 2  # two images per batch
image_batch, label_batch = read_cifar10(data_dir,
                                        is_train=True,
                                        batch_size=BATCH_SIZE,
                                        shuffle=True)

with tf.Session() as sess:
    i = 0
    # Coordinator + queue-runner threads drive the input queues.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        # Only pull one batch (i < 1) for this visual check.
        while not coord.should_stop() and i < 1:

            img, label = sess.run([image_batch, label_batch])

            # just test one batch
            for j in np.arange(BATCH_SIZE):
                # BUG FIX: read_cifar10 returns ONE-HOT labels, so label[j]
                # is a length-10 vector and '%d' % label[j] raises a
                # TypeError. Recover the class index with argmax.
                print('label: %d' % np.argmax(label[j]))
                # Per-image standardization produces negative floats, so the
                # imshow output looks distorted — that is expected here.
                plt.imshow(img[j, :, :, :])
                plt.show()
            i += 1

    except tf.errors.OutOfRangeError:
        print('done!')
    finally:
        # Always stop the queue runners, even on error.
        coord.request_stop()
    coord.join(threads)

結果
歸一化和float型別導致圖片顯示失真
(圖:顯示出的兩張圖片截圖——因標準化後為 float 值,imshow 顯示失真)