1. 程式人生 > >PyTorch 學習筆記(一):讓PyTorch讀取你的資料集

PyTorch 學習筆記(一):讓PyTorch讀取你的資料集

本文擷取自《PyTorch 模型訓練實用教程》,獲取全文pdf請點選:https://github.com/tensor-yu/PyTorch_Tutorial

文章目錄

想讓PyTorch能讀取我們自己的資料,首先要了解pytroch讀取圖片的機制和流程,然後按流程編寫程式碼。

Dataset類

PyTorch讀取圖片,主要是通過Dataset類,所以先簡單瞭解一下Dataset類。Dataset類作為所有的datasets的基類存在,所有的datasets都需要繼承它,類似於C++中的虛基類。

原始碼如下:

class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__
(self, index): raise NotImplementedError def __len__(self): raise NotImplementedError def __add__(self, other): return ConcatDataset([self, other])

這裡重點看 getitem函式,getitem接收一個index,然後返回圖片資料和標籤,這個index通常指的是一個list的index,這個list的每個元素就包含了圖片資料的路徑和標籤資訊。

然而,如何製作這個list呢,通常的方法是將圖片的路徑和標籤資訊儲存在一個txt中,然後從該txt中讀取。
那麼讀取自己資料的基本流程就是:

  1. 製作儲存了圖片的路徑和標籤資訊的txt
  2. 將這些資訊轉化為list,該list每一個元素對應一個樣本
  3. 通過getitem函式,讀取資料和標籤,並返回資料和標籤

在訓練程式碼裡是感覺不到這些操作的,只會看到通過DataLoader就可以獲取一個batch的資料,其實觸發去讀取圖片這些操作的是DataLoader裡的__iter__(self),後面會詳細講解讀取過程。在本小節,主要講Dataset子類。
因此,要讓PyTorch能讀取自己的資料集,只需要兩步:

  1. 製作圖片資料的索引

  2. 構建Dataset子類

  3. 製作圖片資料的索引
    這個比較簡單,就是讀取圖片路徑,標籤,儲存到txt檔案中,這裡注意格式就好
    特別注意的是,txt中的路徑,是以訓練時的那個py檔案所在的目錄為工作目錄,所以這裡需要提前算好相對路徑!
    執行程式碼 Code/1_data_prepare/1_3_generate_txt.py,即會在/Data/資料夾下面看到 train.txt valid.txt
    txt中是這樣的:
    在這裡插入圖片描述

構建Dataset子類

下面是本實驗構建的Dataset子類——MyDataset類:

# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
	fh = open(txt_path, 'r')
	imgs = []
	for line in fh:
		line = line.rstrip()
		words = line.split()
		imgs.append((words[0], int(words[1])))
		self.imgs = imgs 
		self.transform = transform
		self.target_transform = target_transform
def __getitem__(self, index):
	fn, label = self.imgs[index]
	img = Image.open(fn).convert('RGB') 
	if self.transform is not None:
		img = self.transform(img) 
	return img, label
def __len__(self):
	return len(self.imgs)

首先看看初始化,初始化中從我們準備好的txt裡獲取圖片的路徑和標籤,並且儲存在self.imgs,self.imgs就是上面提到的list,其一個元素對應一個樣本的路徑和標籤,其實就是txt中的一行。

初始化中還會初始化transform,transform是一個Compose型別,裡邊有一個list,list中就會定義了各種對影象進行處理的操作,可以設定減均值,除標準差,隨機裁剪,旋轉,翻轉,仿射變換等操作。

在這裡我們可以知道,一張圖片讀取進來之後,會經過資料處理(資料增強),最終變成輸入模型的資料。這裡就有一點需要注意,PyTorch的資料增強是將原始圖片進行了處理,並不會生成新的一份圖片,而是“覆蓋”原圖,當採用randomcrop之類的隨機操作時,每個epoch輸入進來的圖片幾乎不會是一模一樣的,這達到了樣本多樣性的功能。

然後看看核心的 getitem函式:

第一行:self.imgs 是一個list,也就是一開始提到的list,self.imgs的一個元素是一個str,包含圖片路徑,圖片標籤,這些資訊是從txt檔案中讀取

第二行:利用Image.open對圖片進行讀取,img型別為 Image ,mode=‘RGB’

第三行與第四行: 對圖片進行處理,這個transform裡邊可以實現 減均值,除標準差,隨機裁剪,旋轉,翻轉,放射變換,等等操作,這個放在後面會詳細講解。

當Mydataset構建好,剩下的操作就交給DataLoder,在DataLoder中,會觸發Mydataset中的getiterm函式讀取一張圖片的資料和標籤,並拼接成一個batch返回,作為模型真正的輸入。下一小節將會通過一個小例子,介紹DataLoder是如何獲取一個batch,以及一張圖片是如何被PyTorch讀取,最終變為模型的輸入的。

轉載請註明出處:https://blog.csdn.net/u011995719/article/details/85102770