1. 程式人生 > >損失函式震盪不收斂可能原因:tf.train.shuffle_batch

損失函式震盪不收斂可能原因:tf.train.shuffle_batch

​ 在製作tfrecords資料集的時候,比如說將cifar資料轉換成tfrecords資料集,一般會用到tf.train.shuffle_batch函式,而損失函式震盪不收斂的原因就可能就是資料集製作出現問題。

​ Cifar-10資料集包含了airlane、automobile、bird、cat、deer、dog、frog、horse、ship、truck,10種分類 ,分別放在十個資料夾中。共60000張圖片,其中訓練集50000張,測試集10000張。

在這裡插入圖片描述

​ 開始在製作資料集的時候,我是先將一個資料夾中的所有圖片寫入tfrecords,這樣製作的問題就是:將同一類的圖片按照順序寫入到了tfrecords中,然而後面再讀取tfrecords時,使用到了tf.train.shuffle_batch,此函式只是在batch_size中reshuffle,總體的順序並沒有改變,所以喂入網路的資料都是同一類的圖片,並不能起到訓練網路的效果。

​ 正確的做法是:應該將這些圖片打亂之後寫入到tfrecords中。我採取的方法是:因為每個資料夾中圖片數量是固定的,所以將這些圖片名稱全部讀取出來,儲存到字典中,因為batch_size為200,所以依次從十個資料夾中讀取20張圖片寫入到tfrecords中,這樣再訓練的時候,取出的資料就不再會是同一類的圖片。

def write_tfRecord(tfRecordName, image_path, label_path):
    writer = tf.python_io.TFRecordWriter(tfRecordName)  
    num_pic = 0 
    dirs =
os.listdir(image_path) # print(dirs) contents = {} for _dir in dirs: temp_path = os.path.join(image_path, _dir) temp = os.listdir(temp_path) contents[_dir] = temp # print(len(contents[_dir])) # print(contents) for i in range(int(len(contents[dirs[
0]]) / 20)): for index in range(len(dirs)): for j in range(i*20, i*20+20): ima_path = os.path.join(image_path, dirs[index], contents[dirs[index]][j]) img = Image.open(ima_path) img_raw = img.tobytes() labels = [0] * 10 labels[index] = 1 example = tf.train.Example(features=tf.train.Features(feature={ 'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])), 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=labels)) })) writer.write(example.SerializeToString()) num_pic += 1 print ("the number of picture:", num_pic) writer.close() print("write tfrecord successful")