1. 程式人生 > >【python】pytorch中如何使用DataLoader對資料集進行批處理

【python】pytorch中如何使用DataLoader對資料集進行批處理

第一步:

我們要建立torch能夠識別的資料集型別(pytorch中也有很多現成的資料集型別,以後再說)。

首先我們建立兩個向量X和Y,一個作為輸入的資料,一個作為正確的結果:

    

隨後我們需要把X和Y組成一個完整的資料集,並轉化為pytorch能識別的資料集型別:

    

我們來看一下這些資料的資料型別:

     

可以看出我們把X和Y通過Data.TensorDataset() 這個函式拼裝成了一個數據集,資料集的型別是【TensorDataset】。

好了,第一步結束了

 


 

第二步:

就是把上一步做成的資料集放入Data.DataLoader中,可以生成一個迭代器,從而我們可以方便的進行批處理。

     

DataLoader中也有很多其他引數:

複製程式碼

dataset:Dataset型別,從其中載入資料 
batch_size:int,可選。每個batch載入多少樣本 
shuffle:bool,可選。為True時表示每個epoch都對資料進行洗牌 
sampler:Sampler,可選。從資料集中取樣樣本的方法。 
num_workers:int,可選。載入資料時使用多少子程序。預設值為0,表示在主程序中載入資料。 
collate_fn:callable,可選。 
pin_memory:bool,可選 
drop_last:bool,可選。True表示如果最後剩下不完全的batch,丟棄。False表示不丟棄。

複製程式碼

好了,第二步結束了,

 


 

第三步:

好啦,現在我們就可以愉快的用我們上面定義好的迭代器進行訓練啦。

在這裡我們利用print來模擬我們的訓練過程,即我們在這裡對搭建好的網路進行喂入。

     

輸出的結果是:

      

可以看到,我們一共訓練了所有的資料訓練了5次。資料中一共10組,我們設定的mini-batch是3,即每一次我們訓練網路的時候喂入3組資料,到了最後一次我們只有1組資料了,比mini-batch小,我們就僅輸出這一個。

此外,還可以利用python中的enumerate(),是對所有可以迭代的資料型別(含有很多東西的list等等)進行取操作的函式,用法如下:

      

 

好啦,結束。

轉載自:https://www.cnblogs.com/JeasonIsCoding/p/10168753.html