損失函式震盪不收斂可能原因:tf.train.shuffle_batch
阿新 • • 發佈:2018-11-08
在製作tfrecords資料集的時候,比如說將cifar資料轉換成tfrecords資料集,一般會用到tf.train.shuffle_batch函式,而損失函式震盪不收斂的原因就可能就是資料集製作出現問題。
Cifar-10資料集包含了airlane、automobile、bird、cat、deer、dog、frog、horse、ship、truck,10種分類 ,分別放在十個資料夾中。共60000張圖片,其中訓練集50000張,測試集10000張。
開始在製作資料集的時候,我是先將一個資料夾中的所有圖片寫入tfrecords,這樣製作的問題就是:將同一類的圖片按照順序寫入到了tfrecords中,然而後面再讀取tfrecords時,使用到了tf.train.shuffle_batch,此函式只是在batch_size中reshuffle,總體的順序並沒有改變,所以喂入網路的資料都是同一類的圖片,並不能起到訓練網路的效果。
正確的做法是:應該將這些圖片打亂之後寫入到tfrecords中。我採取的方法是:因為每個資料夾中圖片數量是固定的,所以將這些圖片名稱全部讀取出來,儲存到字典中,因為batch_size為200,所以依次從十個資料夾中讀取20張圖片寫入到tfrecords中,這樣再訓練的時候,取出的資料就不再會是同一類的圖片。
def write_tfRecord(tfRecordName, image_path, label_path):
writer = tf.python_io.TFRecordWriter(tfRecordName)
num_pic = 0
dirs = os.listdir(image_path)
# print(dirs)
contents = {}
for _dir in dirs:
temp_path = os.path.join(image_path, _dir)
temp = os.listdir(temp_path)
contents[_dir] = temp
# print(len(contents[_dir]))
# print(contents)
for i in range(int(len(contents[dirs[ 0]]) / 20)):
for index in range(len(dirs)):
for j in range(i*20, i*20+20):
ima_path = os.path.join(image_path, dirs[index], contents[dirs[index]][j])
img = Image.open(ima_path)
img_raw = img.tobytes()
labels = [0] * 10
labels[index] = 1
example = tf.train.Example(features=tf.train.Features(feature={
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels))
}))
writer.write(example.SerializeToString())
num_pic += 1
print ("the number of picture:", num_pic)
writer.close()
print("write tfrecord successful")