TensorFlow——訓練自己的資料——CIFAR10(一)資料準備
阿新 • • 發佈:2019-01-24
Reading Data
所用函式
def read_cifar10(data_dir, is_train, batch_size, shuffle):
Args:
data_dir: the directory of CIFAR10
is_train: boolean
batch_size:
shuffle: #是否打亂順序
Returns:
label: 1D tensor, tf.int32
image: 4D tensor, [batch_size, height, width, 3 ], tf.float32
變數宣告
# CIFAR-10 binary layout: each record is 1 label byte followed by a
# 32x32 RGB image stored as 3072 (= 32 * 32 * 3) uint8 values.
img_width = 32
img_height = 32
img_depth = 3
label_bytes = 1
image_bytes = img_height * img_width * img_depth  # 3072 bytes per image
讀取資料
def read_cifar10(data_dir, is_train, batch_size, shuffle):
    """Read CIFAR-10 binary files and return batched image/label tensors.

    Args:
        data_dir: directory holding the CIFAR-10 .bin files.
        is_train: if True read data_batch_1..5.bin, else test_batch.bin.
        batch_size: number of examples per output batch.
        shuffle: if True use a shuffling batch queue.

    Returns:
        images: 4-D float32 tensor, [batch_size, height, width, 3].
        labels: 2-D float32 tensor, [batch_size, 10], one-hot encoded.
    """
    # Group the whole input pipeline under one name scope so it shows up
    # as a single node in TensorBoard.
    with tf.name_scope('input'):
        # BUG FIX: the original passed '/data_batch_%d.bin' with a leading
        # slash; os.path.join discards data_dir when a later component is
        # absolute (on POSIX), so the files were looked up at the root.
        # BUG FIX: the original used np.arange(1, 5), which yields 1..4 and
        # silently dropped data_batch_5.bin (CIFAR-10 has 5 training files).
        if is_train:
            filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % ii)
                         for ii in np.arange(1, 6)]
        else:
            filenames = [os.path.join(data_dir, 'test_batch.bin')]

        # The data is raw binary, so queue file names for a fixed-length
        # record reader (datasets of separate (label, image) files would
        # use slice_input_producer instead).
        filename_queue = tf.train.string_input_producer(filenames)

        # Each record: 1 label byte + 32*32*3 = 3072 image bytes.
        reader = tf.FixedLengthRecordReader(label_bytes + image_bytes)
        key, value = reader.read(filename_queue)

        # Decode the record string into a flat uint8 vector (JPEG-encoded
        # images would use tf.image.decode_jpeg instead).
        record_bytes = tf.decode_raw(value, tf.uint8)

        # First byte is the class label.
        label = tf.slice(record_bytes, [0], [label_bytes])
        label = tf.cast(label, tf.int32)

        # Remaining 3072 bytes are the image, stored depth-major as
        # [depth, height, width] in the CIFAR-10 format.
        image_raw = tf.slice(record_bytes, [label_bytes], [image_bytes])
        image_raw = tf.reshape(image_raw, [img_depth, img_height, img_width])
        # Convert D/H/W -> H/W/D for standard TF image ops.
        image = tf.transpose(image_raw, (1, 2, 0))
        image = tf.cast(image, tf.float32)

        # # Optional data augmentation (crop / flip / brightness / contrast):
        # image = tf.random_crop(image, [24, 24, 3])
        # image = tf.image.random_flip_left_right(image)
        # image = tf.image.random_brightness(image, max_delta=63)
        # image = tf.image.random_contrast(image, lower=0.2, upper=1.8)

        # Per-image standardization: subtract the mean and divide by the
        # standard deviation (maps raw [0, 255] values to a zero-mean range).
        image = tf.image.per_image_standardization(image)

        if shuffle:
            images, label_batch = tf.train.shuffle_batch(
                [image, label],
                batch_size=batch_size,
                num_threads=16,
                capacity=2000,             # max elements held in the queue
                min_after_dequeue=1500)    # kept back to guarantee mixing
        else:
            images, label_batch = tf.train.batch(
                [image, label],
                batch_size=batch_size,
                num_threads=16,
                capacity=2000)

        # One-hot encode: class 0 -> [1,0,0,0,0,0,0,0,0,0]; the reshape
        # flattens the [batch, 1, 10] result of one_hot to [batch, 10].
        n_classes = 10
        label_batch = tf.one_hot(label_batch, depth=n_classes)
        return images, tf.reshape(label_batch, [batch_size, n_classes])
測試資料
把一個Batch顯示出來
# Quick visual sanity check: pull one batch and display it with matplotlib.
import matplotlib.pyplot as plt

# Point this at your own CIFAR-10 binary data directory.
data_dir = 'D:/Study/Python/Projects/CIFAR10/data'
BATCH_SIZE = 2  # two images per batch

image_batch, label_batch = read_cifar10(data_dir,
                                        is_train=True,
                                        batch_size=BATCH_SIZE,
                                        shuffle=True)

with tf.Session() as sess:
    i = 0
    # Coordinator + queue runners drive the input queues built above.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        # Only fetch a single batch (i < 1) for this demo.
        while not coord.should_stop() and i < 1:
            img, label = sess.run([image_batch, label_batch])
            for j in np.arange(BATCH_SIZE):
                # BUG FIX: read_cifar10 returns ONE-HOT labels, so the
                # original '%d' % label[j] (a length-10 vector) raised a
                # TypeError; recover the class index with np.argmax.
                print('label: %d' % np.argmax(label[j]))
                plt.imshow(img[j, :, :, :])
                plt.show()
            i += 1
    except tf.errors.OutOfRangeError:
        print('done!')
    finally:
        # Always stop the runner threads, even on error.
        coord.request_stop()
        coord.join(threads)
結果
歸一化和float型別導致圖片顯示失真