1. 程式人生 > >keras中的model.fit和model.fit_generator

keras中的model.fit和model.fit_generator

fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)

本函式用以訓練模型,引數有:

  • x:輸入資料。如果模型只有一個輸入,那麼x的型別是numpy array,如果模型有多個輸入,那麼x的型別應當為list,list的元素是對應於各個輸入的numpy array。如果模型的每個輸入都有名字,則可以傳入一個字典,將輸入名與其輸入資料對應起來。

  • y:標籤,numpy array。如果模型有多個輸出,可以傳入一個numpy array的list。如果模型的輸出擁有名字,則可以傳入一個字典,將輸出名與其標籤對應起來。

  • batch_size:整數,指定進行梯度下降時每個batch包含的樣本數。訓練時一個batch的樣本會被計算一次梯度下降,使目標函式優化一步。

  • epochs:整數,訓練終止時的epoch值,訓練將在達到該epoch值時停止,當沒有設定initial_epoch時,它就是訓練的總輪數,否則訓練的總輪數為epochs - inital_epoch

  • verbose:日誌顯示,0為不在標準輸出流輸出日誌資訊,1為輸出進度條記錄,2為每個epoch輸出一行記錄

  • callbacks:list,其中的元素是keras.callbacks.Callback的物件。這個list中的回撥函式將會在訓練過程中的適當時機被呼叫,參考回撥函式

  • validation_split:0~1之間的浮點數,用來指定訓練集的一定比例資料作為驗證集。驗證集將不參與訓練,並在每個epoch結束後測試的模型的指標,如損失函式、精確度等。注意,validation_split的劃分在shuffle之後,因此如果你的資料本身是有序的,需要先手工打亂再指定validation_split,否則可能會出現驗證集樣本不均勻。

  • validation_data:形式為(X,y)或(X,y,sample_weights)的tuple,是指定的驗證集。此引數將覆蓋validation_spilt。

  • shuffle:布林值,表示是否在訓練過程中每個epoch前隨機打亂輸入樣本的順序。

  • class_weight:字典,將不同的類別對映為不同的權值,該引數用來在訓練過程中調整損失函式(只能用於訓練)。該引數在處理非平衡的訓練資料(某些類的訓練樣本數很少)時,可以使得損失函式對樣本數不足的資料更加關注。

  • sample_weight:權值的numpy array,用於在訓練時調整損失函式(僅用於訓練)。可以傳遞一個1D的與樣本等長的向量用於對樣本進行1對1的加權,或者在面對時序資料時,傳遞一個的形式為(samples,sequence_length)的矩陣來為每個時間步上的樣本賦不同的權。這種情況下請確定在編譯模型時添加了sample_weight_mode='temporal'

  • initial_epoch: 從該引數指定的epoch開始訓練,在繼續之前的訓練時有用。

  • steps_per_epoch: 一個epoch包含的步數(每一步是一個batch的資料送入),當使用如TensorFlow資料Tensor之類的輸入張量進行訓練時,預設的None代表自動分割,即資料集樣本數/batch樣本數。

  • validation_steps: 僅當steps_per_epoch被指定時有用,在驗證集上的step總數。

輸入資料與規定資料不匹配時會丟擲錯誤

fit函式返回一個History的物件,其History.history屬性記錄了損失函式和其他指標的數值隨epoch變化的情況,如果有驗證集的話,也包含了驗證集的這些指標變化情況。

 

 

fit_generator(self, generator, steps_per_epoch, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_q_size=10, workers=1, pickle_safe=False, initial_epoch=0)

 

函式的引數是:

  • generator:生成器函式,生成器的輸出應該為:

    • 一個形如(inputs,targets)的tuple

    • 一個形如(inputs, targets,sample_weight)的tuple。所有的返回值都應該包含相同數目的樣本。生成器將無限在資料集上迴圈。每個epoch以經過模型的樣本數達到samples_per_epoch時,記一個epoch結束

  • steps_per_epoch:整數,當生成器返回steps_per_epoch次資料時計一個epoch結束,執行下一個epoch

  • epochs:整數,資料迭代的輪數

  • verbose:日誌顯示,0為不在標準輸出流輸出日誌資訊,1為輸出進度條記錄,2為每個epoch輸出一行記錄

  • validation_data:具有以下三種形式之一

    • 生成驗證集的生成器

    • 一個形如(inputs,targets)的tuple

    • 一個形如(inputs,targets,sample_weights)的tuple

  • validation_steps: 當validation_data為生成器時,本引數指定驗證集的生成器返回次數

  • class_weight:規定類別權重的字典,將類別對映為權重,常用於處理樣本不均衡問題。

  • sample_weight:權值的numpy array,用於在訓練時調整損失函式(僅用於訓練)。可以傳遞一個1D的與樣本等長的向量用於對樣本進行1對1的加權,或者在面對時序資料時,傳遞一個的形式為(samples,sequence_length)的矩陣來為每個時間步上的樣本賦不同的權。這種情況下請確定在編譯模型時添加了sample_weight_mode='temporal'

  • workers:最大程序數

  • max_q_size:生成器佇列的最大容量

  • pickle_safe: 若為真,則使用基於程序的執行緒。由於該實現依賴多程序,不能傳遞non picklable(無法被pickle序列化)的引數到生成器中,因為無法輕易將它們傳入子程序中。

  • initial_epoch: 從該引數指定的epoch開始訓練,在繼續之前的訓練時有用。

函式返回一個History物件。

參考文獻: Keras中文文件