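Learning PyTorch: Data Loading and Processing
Introduction
This post draws on the official tutorials, the source code, and a few blog posts.
Loading and processing data in PyTorch is relatively easy. Datasets commonly come in two layouts:
- In the first, the whole dataset sits in a single folder, which also contains a separate label file describing the data, as in this dataset. This layout is probably better suited to non-classification problems.
- The second is better suited to classification problems: data of different classes is stored in different folders, in the following form: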
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
...
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
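This post first follows the official tutorials to introduce the first approach, in order to understand how data loading works, and then briefly shows the second approach in code. The second approach works on the same principle as the first; the difference is that it uses ImageFolder, a ready-made tool provided by torchvision, so it is simpler to implement.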
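The first approach
Dataset class
torch.utils.data.Dataset is an abstract class. To load custom data, you only need to subclass it and override two of its methods:
- __len__: override this so that len(dataset) returns the size of the whole dataset
- __getitem__: override this so that dataset[i] returns the i-th sample of the dataset
If these two methods are not overridden, calling them raises NotImplementedError. In the version of the source quoted here they are defined as: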
def __getitem__(self, index):
    raise NotImplementedError

def __len__(self):
    raise NotImplementedError
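As a minimal sketch of this two-method contract, the toy class below (SquaresDataset is a hypothetical name, not part of PyTorch) implements only __len__ and __getitem__ in plain Python. For real use with PyTorch it would subclass torch.utils.data.Dataset; the base class is left out here so that the snippet runs without any dependencies.

```python
class SquaresDataset:
    """Hypothetical toy dataset: sample i is the pair (i, i * i)."""

    def __init__(self, size):
        self.size = size

    def __len__(self):
        # Lets len(dataset) report the number of samples.
        return self.size

    def __getitem__(self, idx):
        # Lets dataset[idx] return one sample; reject out-of-range indices.
        if not 0 <= idx < self.size:
            raise IndexError(idx)
        return idx, idx * idx


dataset = SquaresDataset(5)
print(len(dataset))
print(dataset[3])
```

Everything else a dataset does (reading files, applying transforms) happens inside __getitem__, which is why the custom class in this section only needs these two methods plus a constructor.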
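Here I downloaded 20 images from the web: 10 kittens and 10 puppies. To save effort (I only wanted to check that subclassing Dataset works), I did not add a label file to the dataset. Instead I simply defined images 1-10 as cats and images 11-20 as dogs, which keeps __len__ and __getitem__ simple. The directory structure is as follows: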

Figure: directory structure
The custom class is defined as follows:
from torch.utils.data import DataLoader, Dataset
from skimage import io, transform
import matplotlib.pyplot as plt
import os
import torch
from torchvision import transforms
import numpy as np


class AnimalData(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        # The dataset size is hard-coded: 10 cats + 10 dogs.
        return 20

    def __getitem__(self, idx):
        # sorted() keeps the idx -> file mapping deterministic;
        # os.listdir() alone returns entries in arbitrary order.
        filenames = sorted(os.listdir(self.root_dir))
        filename = filenames[idx]
        img = io.imread(os.path.join(self.root_dir, filename))
        # filename[:-5] strips a 5-character suffix such as ".jpeg",
        # leaving the image number: 1-10 are cats, 11-20 are dogs.
        if int(filename[:-5]) > 10:
            lable = np.array([0])  # dog
        else:
            lable = np.array([1])  # cat
        # Note: the key is spelled 'lable' throughout this post.
        sample = {'image': img, 'lable': lable}
        if self.transform:
            sample = self.transform(sample)
        return sample
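The label rule in __getitem__ depends on the file name alone, so it can be checked without any images. The sketch below assumes files named like 1.jpeg through 20.jpeg; the names and the helper functions image_number and lable_for are illustrative and not part of the original code. It uses os.path.splitext instead of the fixed [:-5] slice so that other extension lengths also work, and it sorts numerically, since a plain string sort would put 10.jpeg before 2.jpeg.

```python
import os


def image_number(filename):
    """Return the integer stem of a name such as '12.jpeg' (assumed layout)."""
    stem, _ext = os.path.splitext(filename)
    return int(stem)


def lable_for(filename):
    """Mirror the rule above: numbers 1-10 are cats (1), 11-20 are dogs (0)."""
    return 0 if image_number(filename) > 10 else 1


# Hypothetical directory listing in arbitrary order.
names = ['10.jpeg', '2.jpeg', '11.jpeg', '1.jpeg']
ordered = sorted(names, key=image_number)
lables = [lable_for(name) for name in ordered]
print(ordered)
print(lables)
```

Because the label is derived from the name rather than from the position in the listing, the result does not depend on the order in which the files are returned.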
Transforms & Compose transforms
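You may have noticed that __init__ in the AnimalData class from the previous section takes a transform parameter. That parameter is what this section explains.
Images downloaded at random from the web inevitably differ in size, whereas a CNN requires inputs of a fixed size. numpy stores images as H, W, C, while PyTorch expects C, H, W. PyTorch inputs must be converted from numpy arrays to tensors. Input data usually also needs to be normalized, and so on.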
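The layout and type conversions mentioned here can be sketched with numpy alone. The array below is a made-up 4x6 RGB image (H=4, W=6, C=3), and dividing by 255 is just one common normalization convention, assumed here for illustration.

```python
import numpy as np

# Made-up uint8 image in numpy layout: height 4, width 6, 3 channels.
hwc = np.zeros((4, 6, 3), dtype=np.uint8)
hwc[..., 0] = 255  # fill the first (red) channel

# Reorder axes from (H, W, C) to (C, H, W), the layout PyTorch expects.
chw = hwc.transpose((2, 0, 1))

# Scale pixel values from [0, 255] to [0.0, 1.0].
normalized = chw.astype(np.float32) / 255.0

print(hwc.shape, chw.shape)
print(normalized[0].max(), normalized[1].max())
```

In a PyTorch pipeline the resulting array would then be wrapped with torch.from_numpy to obtain a tensor.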
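With this in mind, we can define a few callable classes and pass them as the transform argument to the dataset class defined in the previous section. For convenience, torchvision.transforms.Compose lets us chain our callable classes and hand them to the dataset class in one go, so that the dataset directly returns the transformed data.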
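Conceptually, Compose just calls each transform in order and feeds the output of one into the next. The class below is a simplified stand-in written for illustration (SimpleCompose, AddOne and Double are hypothetical names), not torchvision's actual implementation.

```python
class SimpleCompose:
    """Simplified stand-in for a Compose-style transform chain."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, sample):
        # Apply every transform in order, passing the output along.
        for t in self.transforms:
            sample = t(sample)
        return sample


class AddOne:
    def __call__(self, sample):
        return {'value': sample['value'] + 1}


class Double:
    def __call__(self, sample):
        return {'value': sample['value'] * 2}


pipeline = SimpleCompose([AddOne(), Double()])
result = pipeline({'value': 3})
print(result)
```

The order of the list matters: here the value is incremented first and doubled second. For the same reason, a resizing step is normally placed before a cropping step.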
Here I copied three classes straight from the tutorial (Rescale, RandomCrop and ToTensor) and changed them slightly to fit my dataset.
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        image, lable = sample['image'], sample['lable']

        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)

        img = transform.resize(image, (new_h, new_w))

        # Unlike the landmarks in the tutorial, a class lable does not
        # need to be rescaled, so it is passed through unchanged.
        return {'image': img, 'lable': lable}


class RandomCrop(object):
    """Crop randomly the image in a sample.

    Args:
        output_size (tuple or int): Desired output size. If int, square crop
            is made.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        image, lable = sample['image'], sample['lable']

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # randint's upper bound is exclusive; the + 1 also avoids an error
        # when the crop is exactly as large as the image.
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h,
                      left: left + new_w]

        # The class lable is unaffected by cropping.
        return {'image': image, 'lable': lable}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        image, lable = sample['image'], sample['lable']

        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'lable': torch.from_numpy(lable)}
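To make the size arithmetic in Rescale concrete, the helper below repeats just that calculation for an integer output_size (rescaled_size is a hypothetical name introduced here). The shorter edge is matched to output_size and the longer edge is scaled by the same factor.

```python
def rescaled_size(h, w, output_size):
    """Return (new_h, new_w) with the shorter edge equal to output_size."""
    if h > w:
        new_h, new_w = output_size * h / w, output_size
    else:
        new_h, new_w = output_size, output_size * w / h
    return int(new_h), int(new_w)


# A 300x400 (H x W) landscape image and a 500x250 portrait image.
print(rescaled_size(300, 400, 256))
print(rescaled_size(500, 250, 256))
```

After rescaling with 256 the shorter edge is 256 pixels, so a subsequent 224x224 random crop always has room to pick its window.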
Once the callable classes are defined, combine the three of them with torchvision.transforms.Compose and pass the result as the transform argument of the AnimalData class.
trsm = transforms.Compose([Rescale(256), RandomCrop(224), ToTensor()])
data = AnimalData('./all', transform=trsm)
Iterating through the dataset
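Once we have the data instance from the previous section, we could read the samples one at a time with a for loop, but that is inefficient. The torch.utils.data.DataLoader class solves this problem. Its main features include:
- multiprocessing: samples are loaded in parallel by worker processes
It is also simple to use: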
dataloader = DataLoader(data, batch_size=4, shuffle=True, num_workers=4)
for i_batch, batch_data in enumerate(dataloader):
    print(i_batch)
    print(batch_data['image'].size())
    print(batch_data['lable'])
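The loop above needs the image files on disk. As a self-contained sketch of the same batching behaviour, the stand-in dataset below (FakeAnimalData, a hypothetical class) returns random tensors in the same {'image', 'lable'} dictionary format. It assumes PyTorch is installed and uses num_workers=0 so that everything runs in the main process.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class FakeAnimalData(Dataset):
    """Hypothetical stand-in: 20 random 3x224x224 'images' with 0/1 lables."""

    def __len__(self):
        return 20

    def __getitem__(self, idx):
        image = torch.rand(3, 224, 224)
        lable = torch.tensor([0 if idx >= 10 else 1])
        return {'image': image, 'lable': lable}


loader = DataLoader(FakeAnimalData(), batch_size=4, shuffle=True, num_workers=0)
batch = next(iter(loader))

# The default collate function stacks each dictionary entry along a new
# leading batch dimension.
print(batch['image'].size())
print(batch['lable'].size())
print(len(loader))
```

Each batch is again a dictionary with the same keys, which is why the loop can index the batch by 'image' and 'lable'.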
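The second approach
torchvision
PyTorch wraps up almost all of the work above for us. One such tool is torchvision.datasets.ImageFolder, which loads user-defined data and requires the data to have the following structure: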
root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
...
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png
torchvision.transforms also bundles all kinds of data-processing tools, such as Resize, ToTensor and so on.
I rearranged my downloaded dataset as follows:

Figure: method2_tree
The code for loading the data is as follows:
from torchvision import transforms, utils
from torchvision import datasets
import torch
import matplotlib.pyplot as plt

train_data = datasets.ImageFolder('./data1', transform=transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor()
]))

train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=4,
                                           shuffle=True,
                                           )

# Number of batches, not number of images.
print(len(train_loader))
for i_batch, img in enumerate(train_loader):
    if i_batch == 0:
        # img[0] holds the batch of images, img[1] the class indices.
        print(img[1])
        fig = plt.figure()
        grid = utils.make_grid(img[0])
        # make_grid returns C x H x W; imshow expects H x W x C.
        plt.imshow(grid.numpy().transpose((1, 2, 0)))
        plt.show()
    break
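The len(train_loader) call above reports the number of batches, not the number of images. With the default drop_last=False, that is the dataset size divided by batch_size, rounded up. A small sketch of the arithmetic follows; num_batches is a hypothetical helper, and the 20 images are assumed from the dataset described earlier.

```python
import math


def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields per epoch for the given settings."""
    if drop_last:
        return num_samples // batch_size  # discard the incomplete batch
    return math.ceil(num_samples / batch_size)


print(num_batches(20, 4))
print(num_batches(22, 4))
print(num_batches(22, 4, drop_last=True))
```

When the dataset size is not a multiple of batch_size, the last batch is simply smaller unless drop_last=True is passed to the loader.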
Result:

Figure: make_grid
Appendix
Finally, here is a piece of the torchvision source code to enjoy:
# vision/torchvision/datasets/folder.py
import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']


def is_image_file(filename):
    """Checks if a file is an image.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)


def find_classes(dir):
    classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    return classes, class_to_idx


def make_dataset(dir, class_to_idx):
    images = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue

        for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if is_image_file(fname):
                    path = os.path.join(root, fname)
                    item = (path, class_to_idx[target])
                    images.append(item)

    return images


def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

    Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """

    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is class_index of the target class.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.imgs)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += 'Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += 'Root Location: {}\n'.format(self.root)
        tmp = 'Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = 'Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str
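The way class indices are assigned in the listing above can be tried out without torchvision. The sketch below re-implements the find_classes logic with the standard library only and runs it on a temporary directory; the folder names ants and bees and the stray readme.txt file are made up for the demonstration.

```python
import os
import tempfile


def find_classes(root):
    """Same logic as find_classes in the listing: sorted subfolder names."""
    classes = sorted(
        d for d in os.listdir(root)
        if os.path.isdir(os.path.join(root, d))
    )
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


with tempfile.TemporaryDirectory() as root:
    # Create the class folders in non-alphabetical order on purpose.
    for name in ('bees', 'ants'):
        os.mkdir(os.path.join(root, name))
    # A stray file at the top level is ignored because it is not a directory.
    open(os.path.join(root, 'readme.txt'), 'w').close()
    classes, class_to_idx = find_classes(root)

print(classes)
print(class_to_idx)
```

The class index therefore depends only on the alphabetical order of the folder names, not on the order in which the folders were created.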
References
[1]. Data Loading and Processing Tutorial
[2]. github: pytorch/torch/utils/data/dataset.py
[3]. github: vision/torchvision/datasets/folder.py
[4]. csdn