1. 程式人生 > >使用model.fit_generator方法進行訓練(自己的訓練集-多分類)

使用model.fit_generator方法進行訓練(自己的訓練集-多分類)

我們在使用model.fit()進行訓練的時候, 在這之前你肯定會有訓練集的x_img_train,y_label_train兩個引數。

fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)

但是當我們使用model.fit_generator()的時候,它的方法是這樣的:

fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)

可以看到它要求傳入的引數是一個generator.官網說的很清楚,(

不清楚的可以看官網)這裡的generator是一個生成器,主要是訓練自己的資料,並且資料非常多的時候可以不用把資料全部載入進記憶體,而是用生成器自己一點點讀取。大大提高的執行效率。

下面是這個生成器的生成方法:

#這是訓練集的生成器
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

## 訓練圖片生成器
train_generator = train_datagen.flow_from_directory(
    train_data_dir,#訓練樣本地址
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical') #多分類

test_datagen = ImageDataGenerator(rescale=1. / 255)

##驗證集的生成器
validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,#驗證樣本地址
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False) #多分類

好了,有了這個train_generator生成器我們就可以入入fit_generator(...)裡面進行訓練了。

對了,這裡說明下train_data_dir / validation_data_dir 是我本機的訓練集與驗證集的地址。

目錄結構形似:

'''
data/train/
          1/
             001.jpg
             002.jpg
             ...
          2/
            001.jpg
            002.jpg
            ...

data/validation/
                1/
                    001.jpg
                    002.jpg
                    ...
                2/
                    001.jpg
                    002.jpg
                    ...
            
'''