PyTorch Study Notes (6): Custom Datasets
阿新 • Published: 2019-01-23
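What is a Dataset:
In the input pipeline, we saw that the data was prepared with code like this: `data = datasets.CIFAR10("./data/", transform=transform, train=True, download=True)`. Here `datasets.CIFAR10` is a subclass of `torch.utils.data.Dataset`, and `data` is an instance of that class.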
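To see what "an instance of a `Dataset` subclass" means without downloading CIFAR-10, the sketch below uses `torch.utils.data.TensorDataset`, a built-in `Dataset` subclass, as a stand-in. The random tensors are made-up data for illustration only. Like `data` above, the instance supports `len()` and integer indexing, and indexing returns one (sample, label) pair.

```python
import torch
from torch.utils.data import Dataset, TensorDataset

# Hypothetical stand-in data: 10 fake 3x32x32 "images" with labels 0..9
images = torch.randn(10, 3, 32, 32)
labels = torch.arange(10)

ds = TensorDataset(images, labels)

print(isinstance(ds, Dataset))  # True: TensorDataset subclasses Dataset
print(len(ds))                  # 10: number of samples
img, label = ds[3]              # indexing returns one (sample, label) pair
print(img.shape, label.item())  # torch.Size([3, 32, 32]) 3
```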
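Why define a Dataset:
PyTorch provides the utility class `torch.utils.data.DataLoader`. With it, mini-batches can be prepared in parallel by multiple worker processes (controlled by its `num_workers` argument), which speeds up data loading. A `Dataset` is one of the arguments used to construct a `DataLoader`.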
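A minimal sketch of handing a `Dataset` to `DataLoader`, again using a `TensorDataset` of made-up data. `batch_size`, `shuffle`, and `num_workers` are real `DataLoader` arguments. Here `num_workers=0` loads data in the main process so the snippet runs anywhere. Setting it to a positive number enables the parallel worker processes described above; on platforms that spawn workers, such as Windows, the loading code should then sit under `if __name__ == "__main__":`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical data: 10 fake samples with 5 features each, plus binary labels
features = torch.randn(10, 5)
labels = torch.randint(0, 2, (10,))
dataset = TensorDataset(features, labels)

# DataLoader calls dataset[i] for each index and stacks the results into batches.
# num_workers=0 loads in the main process; e.g. num_workers=4 uses 4 worker processes.
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

batch_sizes = []
for x, y in loader:
    batch_sizes.append(x.shape[0])
    print(x.shape, y.shape)

print(batch_sizes)  # [4, 4, 2]: the last batch is smaller unless drop_last=True
```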
How to define a custom Dataset:
Below is a skeleton for a custom Dataset:
```python
from torch.utils import data


class CustomDataset(data.Dataset):  # must subclass data.Dataset
    def __init__(self):
        # TODO
        # 1. Initialize file path or list of file names.
        pass

    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.transforms).
        # 3. Return a data pair (e.g. image and label).
        # Note: step 1 reads a SINGLE sample (the one at `index`), not the whole dataset.
        pass

    def __len__(self):
        # You should change 0 to the total size of your dataset.
        return 0
```
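One way to fill in the skeleton, as a sketch under stated assumptions: each sample is stored in its own `.npy` file, and the labels are kept in a Python list. The `NpyFileDataset` class, the file names, and the tiny generated data are all hypothetical. The key point, matching the comment in the skeleton, is that `__init__` only records file paths, while `__getitem__` reads exactly one file.

```python
import os
import tempfile

import numpy as np
import torch
from torch.utils import data


class NpyFileDataset(data.Dataset):  # hypothetical example class
    def __init__(self, file_paths, labels, transform=None):
        # 1. Only store the file paths and labels; no sample data is loaded here.
        assert len(file_paths) == len(labels)
        self.file_paths = file_paths
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        # 1. Read ONE sample from disk.
        array = np.load(self.file_paths[index])
        x = torch.from_numpy(array).float()
        # 2. Preprocess it (optional transform).
        if self.transform is not None:
            x = self.transform(x)
        # 3. Return a (data, label) pair.
        return x, self.labels[index]

    def __len__(self):
        # Total number of samples in the dataset.
        return len(self.file_paths)


# Usage: write 6 made-up samples to a temporary directory, one file each.
tmp_dir = tempfile.mkdtemp()
paths, labels = [], []
for i in range(6):
    path = os.path.join(tmp_dir, f"sample_{i}.npy")
    np.save(path, np.full((2, 2), i, dtype=np.float32))
    paths.append(path)
    labels.append(i % 2)

ds = NpyFileDataset(paths, labels, transform=lambda t: t * 10)
x, y = ds[3]
print(len(ds), x.tolist(), y)  # 6 [[30.0, 30.0], [30.0, 30.0]] 1
```

Because samples are read lazily in `__getitem__`, the dataset can be larger than memory, and each `DataLoader` worker only opens the files it is asked for.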
Next, let's look at the official MNIST example (the code has been trimmed to keep only the important parts):
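```python
import os

import torch
from PIL import Image
from torch.utils import data


class MNIST(data.Dataset):
    # Location of the preprocessed files (names as used by older torchvision releases)
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()  # download() is omitted from this trimmed excerpt

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        # Load the whole split: a (uint8 image tensor, label tensor) pair saved with torch.save
        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(root, self.processed_folder, self.test_file))

    def _check_exists(self):
        # Both processed files must be present under root/processed/
        return (os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and
                os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file)))

    def __getitem__(self, index):
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        # Fixed sizes of the standard MNIST training and test splits
        if self.train:
            return 60000
        else:
            return 10000
```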
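The MNIST class keeps separate `train_data`/`test_data` attributes and hard-codes the lengths 60000 and 10000. Below is a sketch of the same pattern with a single `data`/`targets` pair and a `__len__` derived from the data itself, so it stays correct for any split size. The `InMemoryImageDataset` name and the random uint8 tensors are hypothetical. To keep the only dependency `torch`, it skips the PIL conversion and passes the raw tensor to `transform`.

```python
import torch
from torch.utils import data


class InMemoryImageDataset(data.Dataset):  # hypothetical MNIST-style dataset
    def __init__(self, images, targets, transform=None, target_transform=None):
        # images: uint8 tensor of shape (N, H, W); targets: int64 tensor of shape (N,)
        assert images.shape[0] == targets.shape[0]
        self.data = images
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        img, target = self.data[index], int(self.targets[index])
        # Same hooks as the MNIST class: transform the image and the label separately
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        # Derived from the data instead of hard-coding 60000 / 10000
        return self.data.shape[0]


# Made-up stand-ins for a training split and a test split of 28x28 grayscale images
train_set = InMemoryImageDataset(
    torch.randint(0, 256, (100, 28, 28), dtype=torch.uint8),
    torch.randint(0, 10, (100,)),
    transform=lambda t: t.float().div(255).unsqueeze(0),  # (28, 28) -> (1, 28, 28) in [0, 1]
    target_transform=lambda y: y + 100,  # hypothetical label remapping, for illustration only
)
test_set = InMemoryImageDataset(
    torch.randint(0, 256, (20, 28, 28), dtype=torch.uint8),
    torch.randint(0, 10, (20,)),
)

img, target = train_set[0]
print(len(train_set), len(test_set), img.shape)  # 100 20 torch.Size([1, 28, 28])
```

Choosing the split when the object is constructed (two instances) keeps `__getitem__` free of `if self.train` branches.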