1. 程式人生 > >PyTorch學習系列(一)——載入資料並生成batch資料

PyTorch學習系列(一)——載入資料並生成batch資料

開始學習PyTorch,在此記錄學習過程。準備按順序寫以下系列:

讀取資料生成並構建Dataset子類

假設現在已經實現從資料檔案中讀取輸入images和標記labels(列表),那麼怎麼根據images和labels定義自己的資料集類?答案是作為torch.utils.data.Dataset的子類。

torchvision.datasets中有幾個已經定義好的資料集類,這些類都是torch.utils.data.Dataset抽象類的子類:

在定義torch.utils.data.Dataset的子類時,必須過載的兩個函式是__len__和__getitem__。__len__返回資料集的大小,__getitem__實現資料集的下標索引,返回對應的影象和標記(不一定非得返回影象和標記,返回元組的長度可以是任意長,這由網路需要的資料決定)。
在建立DataLoader時會判斷__getitem__返回值的資料型別,然後用不同的if/else分支把資料轉換成tensor,所以,_getitem_返回值的資料型別可選擇範圍很多,一種可以選擇的資料型別是:影象為numpy.array,標記為int資料型別。
這裡寫圖片描述


示例:

from __future__ import print_function
import torch.utils.data as data
import torch

class MyDataset(data.Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __getitem__(self, index):#返回的是tensor
        img, target = self.images[index], self.labels[index]
        return
img, target def __len__(self): return len(self.images) dataset = MyDataset(images, labels)

生成batch資料

現在有了由資料檔案生成的結構資料MyDataset,那麼怎麼在訓練時提供batch資料呢?PyTorch提供了生成batch資料的類。

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

引數
dataset:Dataset型別,從其中載入資料
batch_size:int,可選。每個batch載入多少樣本
shuffle:bool,可選。為True時表示每個epoch都對資料進行洗牌
sampler:Sampler,可選。從資料集中取樣樣本的方法。
num_workers:int,可選。載入資料時使用多少子程序。預設值為0,表示在主程序中載入資料。
collate_fn:callable,可選。
pin_memory:bool,可選
drop_last:bool,可選。True表示如果最後剩下不完全的batch,丟棄。False表示不丟棄。

示例

kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    MyDataset(images, labels), batch_size=args.batch_size, shuffle=True, **kwargs)

其他用法
len(train_loader) :返回的是len(dataset)/batch_size