PyTorch 學習筆記(一):讓PyTorch讀取你的資料集
本文擷取自《PyTorch 模型訓練實用教程》,獲取全文pdf請點選:https://github.com/tensor-yu/PyTorch_Tutorial
文章目錄
想讓PyTorch能讀取我們自己的資料,首先要了解pytroch讀取圖片的機制和流程,然後按流程編寫程式碼。
Dataset類
PyTorch讀取圖片,主要是通過Dataset類,所以先簡單瞭解一下Dataset類。Dataset類作為所有的datasets的基類存在,所有的datasets都需要繼承它,類似於C++中的虛基類。
原始碼如下:
class Dataset(object):
"""An abstract class representing a Dataset.
All other datasets should subclass it. All subclasses should override
``__len__``, that provides the size of the dataset, and ``__getitem__``,
supporting integer indexing in range from 0 to len(self) exclusive.
"""
def __getitem__ (self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
這裡重點看 getitem函式,getitem接收一個index,然後返回圖片資料和標籤,這個index通常指的是一個list的index,這個list的每個元素就包含了圖片資料的路徑和標籤資訊。
然而,如何製作這個list呢,通常的方法是將圖片的路徑和標籤資訊儲存在一個txt中,然後從該txt中讀取。
那麼讀取自己資料的基本流程就是:
- 製作儲存了圖片的路徑和標籤資訊的txt
- 將這些資訊轉化為list,該list每一個元素對應一個樣本
- 通過getitem函式,讀取資料和標籤,並返回資料和標籤
在訓練程式碼裡是感覺不到這些操作的,只會看到通過DataLoader就可以獲取一個batch的資料,其實觸發去讀取圖片這些操作的是DataLoader裡的__iter__(self),後面會詳細講解讀取過程。在本小節,主要講Dataset子類。
因此,要讓PyTorch能讀取自己的資料集,只需要兩步:
-
製作圖片資料的索引
-
構建Dataset子類
-
製作圖片資料的索引
這個比較簡單,就是讀取圖片路徑,標籤,儲存到txt檔案中,這裡注意格式就好
特別注意的是,txt中的路徑,是以訓練時的那個py檔案所在的目錄為工作目錄,所以這裡需要提前算好相對路徑!
執行程式碼 Code/1_data_prepare/1_3_generate_txt.py,即會在/Data/資料夾下面看到 train.txt valid.txt
txt中是這樣的:
構建Dataset子類
下面是本實驗構建的Dataset子類——MyDataset類:
# coding: utf-8
from PIL import Image
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, txt_path, transform = None, target_transform = None):
fh = open(txt_path, 'r')
imgs = []
for line in fh:
line = line.rstrip()
words = line.split()
imgs.append((words[0], int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
fn, label = self.imgs[index]
img = Image.open(fn).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img, label
def __len__(self):
return len(self.imgs)
首先看看初始化,初始化中從我們準備好的txt裡獲取圖片的路徑和標籤,並且儲存在self.imgs,self.imgs就是上面提到的list,其一個元素對應一個樣本的路徑和標籤,其實就是txt中的一行。
初始化中還會初始化transform,transform是一個Compose型別,裡邊有一個list,list中就會定義了各種對影象進行處理的操作,可以設定減均值,除標準差,隨機裁剪,旋轉,翻轉,仿射變換等操作。
在這裡我們可以知道,一張圖片讀取進來之後,會經過資料處理(資料增強),最終變成輸入模型的資料。這裡就有一點需要注意,PyTorch的資料增強是將原始圖片進行了處理,並不會生成新的一份圖片,而是“覆蓋”原圖,當採用randomcrop之類的隨機操作時,每個epoch輸入進來的圖片幾乎不會是一模一樣的,這達到了樣本多樣性的功能。
然後看看核心的 getitem函式:
第一行:self.imgs 是一個list,也就是一開始提到的list,self.imgs的一個元素是一個str,包含圖片路徑,圖片標籤,這些資訊是從txt檔案中讀取
第二行:利用Image.open對圖片進行讀取,img型別為 Image ,mode=‘RGB’
第三行與第四行: 對圖片進行處理,這個transform裡邊可以實現 減均值,除標準差,隨機裁剪,旋轉,翻轉,放射變換,等等操作,這個放在後面會詳細講解。
當Mydataset構建好,剩下的操作就交給DataLoder,在DataLoder中,會觸發Mydataset中的getiterm函式讀取一張圖片的資料和標籤,並拼接成一個batch返回,作為模型真正的輸入。下一小節將會通過一個小例子,介紹DataLoder是如何獲取一個batch,以及一張圖片是如何被PyTorch讀取,最終變為模型的輸入的。
轉載請註明出處:https://blog.csdn.net/u011995719/article/details/85102770