1. 程式人生 > >PyTorch學習之路(level2)——自定義資料讀取

PyTorch學習之路(level2)——自定義資料讀取

在上一篇部落格PyTorch學習之路(level1)——訓練一個影象分類模型中介紹瞭如何用PyTorch訓練一個影象分類模型,建議先看懂那篇部落格後再看這篇部落格。在那份程式碼中,採用torchvision.datasets.ImageFolder這個介面來讀取影象資料,該介面預設你的訓練資料是按照一個類別存放在一個資料夾下。但是有些情況下你的影象資料不是這樣維護的,比如一個資料夾下面各個類別的影象資料都有,同時用一個對應的標籤檔案,比如txt檔案來維護影象和標籤的對應關係,在這種情況下就不能用torchvision.datasets.ImageFolder來讀取資料了,需要自定義一個數據讀取介面。

另外這篇部落格最後還順帶介紹如何儲存模型和多GPU訓練。

怎麼做呢?

先來看看torchvision.datasets.ImageFolder這個類是怎麼寫的,主要程式碼如下,想詳細瞭解的可以看:官方github程式碼

看起來很複雜,其實非常簡單。繼承的類是torch.utils.data.Dataset,主要包含三個方法:初始化__init__,獲取影象__getitem__,資料集數量 __len____init__方法中先通過find_classes函式得到分類的類別名(classes)和類別名與數字類別的對映關係字典(class_to_idx)。然後通過make_dataset函式得到imags,這個imags是一個列表,其中每個值是一個tuple,每個tuple包含兩個元素:影象路徑和標籤。剩下的就是一些賦值操作了。在__getitem__

方法中最重要的就是 img = self.loader(path)這行,表示資料讀取,可以從__init__方法中看出self.loader採用的是default_loader,這個default_loader的核心就是用python的PIL庫的Image模組來讀取影象資料。

class ImageFolder(data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Args:
        root (string): Root directory path.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        loader (callable, optional): A function to load an image given its path.

     Attributes:
        classes (list): List of the class names.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples
    """
def __init__(self, root, transform=None, target_transform=None, loader=default_loader): classes, class_to_idx = find_classes(root) imgs = make_dataset(root, class_to_idx) if len(imgs) == 0: raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n" "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) self.root = root self.imgs = imgs self.classes = classes self.class_to_idx = class_to_idx self.transform = transform self.target_transform = target_transform self.loader = loader def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target) where target is class_index of the target class. """ path, target = self.imgs[index] img = self.loader(path) if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target def __len__(self): return len(self.imgs)

稍微看下default_loader函式,該函式主要分兩種情況呼叫兩個函式,一般採用pil_loader函式。

def pil_loader(path):
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')

def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)

def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)

看懂了ImageFolder這個類,就可以自定義一個你自己的資料讀取介面了。

首先在PyTorch中和資料讀取相關的類基本都要繼承一個基類:torch.utils.data.Dataset。然後再改寫其中的__init____len____getitem__等方法即可

下面假設img_path是你的影象資料夾,該資料夾下面放了所有影象資料(包括訓練和測試),然後txt_path下面放了train.txt和val.txt兩個檔案,txt檔案中每行都是影象路徑,tab鍵,標籤。所以下面程式碼的__init__方法中self.img_name和self.img_label的讀取方式就跟你資料的存放方式有關,你可以根據你實際資料的維護方式做調整。__getitem__方法沒有做太大改動,依然採用default_loader方法來讀取影象。最後在Transform中將每張影象都封裝成Tensor。

class customData(Dataset):
    def __init__(self, img_path, txt_path, dataset = '', data_transforms=None, loader = default_loader):
        with open(txt_path) as input_file:
            lines = input_file.readlines()
            self.img_name = [os.path.join(img_path, line.strip().split('\t')[0]) for line in lines]
            self.img_label = [int(line.strip().split('\t')[-1]) for line in lines]
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader

    def __len__(self):
        return len(self.img_name)

    def __getitem__(self, item):
        img_name = self.img_name[item]
        label = self.img_label[item]
        img = self.loader(img_name)

        if self.data_transforms is not None:
            try:
                img = self.data_transforms[self.dataset](img)
            except:
                print("Cannot transform image: {}".format(img_name))
        return img, label

定義好了資料讀取介面後,怎麼用呢?

在程式碼中可以這樣呼叫。

 image_datasets = {x: customData(img_path='/ImagePath',
                                    txt_path=('/TxtFile/' + x + '.txt'),
                                    data_transforms=data_transforms,
                                    dataset=x) for x in ['train', 'val']}

這樣返回的image_datasets就和用torchvision.datasets.ImageFolder類返回的資料型別一樣,有點狸貓換太子的感覺,這就是在第一篇部落格中說的寫程式碼類似搭積木的感覺。

有了image_datasets,然後依然用torch.utils.data.DataLoader類來做進一步封裝,將這個batch的影象資料和標籤都分別封裝成Tensor。

 dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                 batch_size=batch_size,
                                                 shuffle=True) for x in ['train', 'val']}

另外,每次迭代生成的模型要怎麼儲存呢?非常簡單,那就是用torch.save。輸入就是你的模型和要儲存的路徑及模型名稱,如果這個output資料夾沒有,可以手動新建一個或者在程式碼裡面新建。

torch.save(model, 'output/resnet_epoch{}.pkl'.format(epoch))

最後,關於多GPU的使用,PyTorch支援多GPU訓練模型,假設你的網路是model,那麼只需要下面一行程式碼(呼叫 torch.nn.DataParallel介面)就可以讓後續的模型訓練在0和1兩塊GPU上訓練,加快訓練速度。

 model = torch.nn.DataParallel(model, device_ids=[0,1])