
PyTorch Study Notes (6): Custom Datasets

What is a Dataset:

In the input pipeline, the code that prepares the data was written like this:

    data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)

datasets.CIFAR10 is a Dataset subclass, and data is an instance of that class.

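Why define a Dataset:

PyTorch provides a utility class, torch.utils.data.DataLoader. With it, mini-batches can be prepared in parallel by several worker processes (controlled by the num_workers argument), which speeds up data preparation. A Dataset is one of the arguments used to construct a DataLoader instance.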
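As a minimal sketch of how a Dataset is handed to a DataLoader, the example below wraps toy tensors in the built-in TensorDataset. The toy data, the batch size, and the variable names are assumptions made for illustration only; they do not come from the original notes.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (an assumption for illustration): 10 samples with 3 features each.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)

# TensorDataset is a built-in Dataset that wraps tensors of equal length.
dataset = TensorDataset(features, labels)

# num_workers=0 loads data in the main process; a positive value starts
# that many worker processes to prepare batches in parallel.
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)

# Iterating over the loader yields (features, labels) mini-batches.
batch_shapes = [(tuple(x.shape), tuple(y.shape)) for x, y in loader]
for x_shape, y_shape in batch_shapes:
    print(x_shape, y_shape)
```

With 10 samples and a batch size of 4, the last batch is smaller than the others, because drop_last defaults to False.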

How to define a custom Dataset

Below is a skeleton for a custom Dataset:

from torch.utils import data


class CustomDataset(data.Dataset):  # must inherit from data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass

    def __getitem__(self, index):
        # TODO
        # 1. Read one data item from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.transforms).
        # 3. Return a data pair (e.g. image and label).
        # Note: step 1 reads one sample at a time; batching is handled by DataLoader.
        pass

    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0
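To make the skeleton concrete, here is one possible way to fill it in. SquaresDataset is a hypothetical toy class invented for this sketch (sample i is the pair (i, i**2)); it stands in for real file reading, which would normally happen in __getitem__.

```python
import torch
from torch.utils import data


class SquaresDataset(data.Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i**2)."""

    def __init__(self, size):
        # Nothing to read from disk here; only remember the dataset size.
        self.size = size

    def __getitem__(self, index):
        # Reject out-of-range indices so that plain iteration terminates.
        if not 0 <= index < self.size:
            raise IndexError(index)
        # Build exactly one sample; DataLoader stacks samples into batches.
        x = torch.tensor([float(index)])
        y = torch.tensor([float(index) ** 2])
        return x, y

    def __len__(self):
        return self.size


dataset = SquaresDataset(10)
loader = data.DataLoader(dataset, batch_size=4, shuffle=False)
first_x, first_y = next(iter(loader))
print(len(dataset), tuple(first_x.shape), first_y.flatten().tolist())
```

Each call to __getitem__ returns a single (input, target) pair, and the default collate function of DataLoader stacks four of them into tensors of shape (4, 1).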

Now let's look at the official MNIST example (the code has been trimmed so that only the important parts remain):

class MNIST(data.Dataset):
    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if self.train:
            return 60000
        else:
            return 10000
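The same pattern of optional transform and target_transform callables can be reused in a custom class. The sketch below is not the torchvision implementation: InMemoryDataset, its arguments, and the toy values are hypothetical, and the length is derived from the stored data rather than hard-coded as in the MNIST excerpt.

```python
import torch
from torch.utils import data


class InMemoryDataset(data.Dataset):
    """Hypothetical dataset that mirrors the transform pattern shown for MNIST."""

    def __init__(self, samples, targets, transform=None, target_transform=None):
        if len(samples) != len(targets):
            raise ValueError("samples and targets must have the same length")
        self.samples = samples
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        sample, target = self.samples[index], self.targets[index]
        # Apply the optional per-sample preprocessing, as in the MNIST class.
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        # Derive the length from the stored data instead of hard-coding it.
        return len(self.samples)


# Toy values (assumptions for illustration): two samples with two "pixels" each.
samples = torch.tensor([[0.0, 255.0], [51.0, 102.0]])
targets = torch.tensor([3, 7])
dataset = InMemoryDataset(
    samples,
    targets,
    transform=lambda x: x / 255.0,       # scale pixel values to [0, 1]
    target_transform=lambda t: int(t),   # convert the label tensor to a Python int
)
sample, target = dataset[1]
print(len(dataset), sample.tolist(), target)
```

Computing the length from the data keeps __len__ correct if the dataset is later subsampled or extended, which a fixed constant would not.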