1. 程式人生 > >pytorch 資料處理:定義自己的資料集合

pytorch 資料處理:定義自己的資料集合

資料處理

版本1

#資料處理
import os
import torch
from torch.utils import data
from PIL import Image
import numpy as np

#定義自己的資料集合
class DogCat(data.Dataset):

    def __init__(self,root):
        #所有圖片的絕對路徑
        imgs=os.listdir(root)

        self.imgs=[os.path.join(root,k) for k in imgs]

    def __getitem__
(self, index):
img_path=self.imgs[index] #dog-> 1 cat ->0 label=1 if 'dog' in img_path.split('/')[-1] else 0 pil_img=Image.open(img_path) array=np.asarray(pil_img) data=torch.from_numpy(array) return data,label def __len__(self): return
len(self.imgs) dataSet=DogCat('./data/dogcat') print(dataSet[0])

輸出:
(
( 0 ,.,.) =
215 203 191
206 194 182
211 199 187

200 191 186
201 192 187
201 192 187

( 1 ,.,.) =
215 203 191
208 196 184
213 201 189

198 189 184
200 191 186
201 192 187

( 2 ,.,.) =
215 201 188
209 195 182
214 200 187

200 191 186
202 193 188
204 195 190

(399,.,.) =
72 90 32
88 106 48
38 56 0

158 161 106
87 85 36
105 98 52
[torch.ByteTensor of size 400x300x3]
, 1)

上面的資料處理有下面的問題:
1.返回的樣本的形狀大小不一致,每一張圖片的大小不一樣。這對於需要batch訓練的神經網路來說很不友好。
2. 返回的資料樣本數值很大,沒有歸一化【-1,1】

對於上面的問題,pytorch torchvision 是一個視覺化的工具包,提供了很多的影象處理的工具,其中transforms模組提供了對PIL image物件和Tensor物件的常用操作。
對PIL Image常見的操作如下;

  • Resize 調整圖片的尺寸,長寬比保持不變

  • CentorCrop ,RandomCrop,RandomSizeCrop 裁剪圖片

  • Pad 填充

  • ToTensor 將PIL Image 轉換為Tensor,會自動將[0,255] 歸一化至[0,1]

對Tensor 的操作如下:

  • Normalize 標準化,即減均值,除以標準差

  • ToPILImage 將Tensor轉換為 PIL Image物件

版本2

#資料處理
import os
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms

transform=transforms.Compose([
    transforms.Resize(224), #縮放圖片,保持長寬比不變,最短邊的長為224畫素,
    transforms.CenterCrop(224), #從中間切出 224*224的圖片
    transforms.ToTensor(), #將圖片轉換為Tensor,歸一化至[0,1]
    transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]) #標準化至[-1,1]
])

#定義自己的資料集合
class DogCat(data.Dataset):

    def __init__(self,root):
        #所有圖片的絕對路徑
        imgs=os.listdir(root)

        self.imgs=[os.path.join(root,k) for k in imgs]
        self.transforms=transform

    def __getitem__(self, index):
        img_path=self.imgs[index]
        #dog-> 1 cat ->0
        label=1 if 'dog' in img_path.split('/')[-1] else 0
        pil_img=Image.open(img_path)
        if self.transforms:
            data=self.transforms(pil_img)
        else:
            pil_img=np.asarray(pil_img)
            data=torch.from_numpy(pil_img)
        return data,label

    def __len__(self):
        return len(self.imgs)

dataSet=DogCat('./data/dogcat')

print(dataSet[0])

輸出:
(
( 0 ,.,.) =
-0.1765 -0.2627 -0.1686 … -0.0824 -0.2000 -0.2627
-0.2392 -0.3098 -0.3176 … -0.2863 -0.2078 -0.1765
-0.3176 -0.2392 -0.2784 … -0.2941 -0.1137 -0.0118
… ⋱ …
-0.7569 -0.5922 -0.1529 … -0.8510 -0.8196 -0.8353
-0.8353 -0.7255 -0.3255 … -0.8275 -0.8196 -0.8588
-0.9373 -0.7647 -0.4510 … -0.8196 -0.8353 -0.8824

( 1 ,.,.) =
-0.0431 -0.1373 -0.0431 … 0.0118 -0.0980 -0.1529
-0.0980 -0.1686 -0.1765 … -0.1608 -0.0745 -0.0431
-0.1686 -0.0902 -0.1373 … -0.1451 0.0431 0.1529
… ⋱ …
-0.5529 -0.3804 0.0667 … -0.7961 -0.7725 -0.7961
-0.6314 -0.5137 -0.1137 … -0.7804 -0.7882 -0.8275
-0.7490 -0.5608 -0.2392 … -0.7725 -0.8039 -0.8588

[torch.FloatTensor of size 3x224x224]
, 1)