1. 程式人生 > >pytorch中的torch.utils.data.Dataset和torch.utils.data.DataLoader

pytorch中的torch.utils.data.Dataset和torch.utils.data.DataLoader

首先看torch.utils.data.Dataset這個抽象類。可以使用這個抽象類來構造pytorch資料集。要注意的是以這個類構造的子類,一定要定義兩個函式一個是__len__,另一個是__getitem__,前者提供資料集size,而後者通過給定索引獲取資料和標籤。__getitem__一次只能獲取一個數據(不知道是不是強制性的),所以通過torch.utils.data.DataLoader來定義一個新的迭代器,實現batch讀取。首先我們來定義一個j簡單的資料集:

from torch.utils.data.dataset import Dataset
import numpy as np
class TxtDataset(Dataset):#這是一個Dataset子類
    def __init__(self):
        self.Data=np.asarray([[1,2],[3,4],[2,1],[3,4],[4,5]])#特徵向量集合,特徵是2維表示一段文字
        Label=np.asarray([1, 2, 0, 1, 2])#標籤是1維,表示文字類別

    def __getitem__(self, index):
        txt=torch.LongTensor(self.Data[index])
        label=torch.LongTensor(self.Label[index])
        return txt, label #返回標籤

    def __len__(self):
        return len(self.Data)

我們建立一個TxtDataset物件,並呼叫函式,注意__getitem__的呼叫要通過: 物件[索引]呼叫

Txt=TxtDataset()
print(Txt[1])
print(Txt.__len__())


#輸出:
(array([3, 4]), 2)
5

看到輸出中特徵向量和標籤是以tuple返回的。而此處得到樣本是一個不是批量的所以我們使用了torch.utils.data.DataLoader引數有 資料集物件(Dataset)、batc_size、shuffle(設定為真每個epoch會進行重置資料順序,一般在訓練資料中使用)、num_workers(設定多少個子程序可以使用,設定0表示在主程序中使用)

test_loader = DataLoader(Txt,batch_size=2,shuffle=False,
                          num_workers=4)
for i,traindata in enumerate(test_loader):
    print('i:',i)
    Data,Label=traindata
    print('data:',Data)
    print('Label:',Label)

輸出:
i: 0
data: tensor([[ 1,  2],
        [ 3,  4]], dtype=torch.int32)
Label: tensor([ 1,  2], dtype=torch.int32)
i: 1
data: tensor([[ 2,  1],
        [ 3,  4]], dtype=torch.int32)
Label: tensor([ 0,  1], dtype=torch.int32)
i: 2
data: tensor([[ 4,  5]], dtype=torch.int32)
Label: tensor([ 2], dtype=torch.int32)

在這個例子中設定批量為2,因此每次去出兩個樣本。除了文字資料可以這樣設定,圖片資料集也是可以的。