1. 程式人生 > >tf.data.Dataset的一些小坑

tf.data.Dataset的一些小坑

我們使用資料的時候都是用batch來做輸入,使用tf.data.Dataset的時候,一般會這樣寫:

dataset = dataset.batch(batch_size).repeat(epochs)

用來說明我們需要對整個資料集進行多少個epochs,每次的輸入大小是多少個batch.
注意:
如果我們的資料集的數量為N,而N%batch_size剛好能整除的話,上述程式碼是沒有任何bug的,但如果整除不了,那麼在每個epoch的最後一個batch,其資料不再是batch_size個,而是N%batch_size個數據. 程式中如果有設定好tf.placeholder來修飾input,那麼程式在執行到最後一個batch的時候就會報錯,因為batch數量對不上. 所以,無論怎樣,一個比較好的程式碼習慣是這麼寫:

dataset = dataset.batch(batch_size, drop_remainder=True).repeat(epochs)

這樣程式會自動把最後一個不足batch_size的batch給忽略掉.