1. 程式人生 > >pytorch使用(一)處理並載入自己的資料

pytorch使用(一)處理並載入自己的資料

pytorch使用(一)資料處理

個人認為,資料處理或許是在完成一篇論文中最耗費時間的,特別是大多情況下,需要在很多個庫上做實驗。

pytorch官方支援很多庫,使用torchvision來完成資料的處理,點這裡可以看到支援的庫並不是很多。在這裡,我將結合一個例項說明如何使用pytorch來處理自己的資料,任務是一個分析雙臂運動的,檢測6個關節點的運動。輸入是連續三幀的檢測結果以及計算的光流,也就是$3*6+2*2=22$張heatmap,輸出是中間幀的檢測結果,也就是6張heatmap。

把原始資料處理為模型使用的資料需要3步:transforms.Compose() torchvision.datasets torch.utils.data.DataLoader()分別可以理解為資料處理格式的定義、資料處理和資料載入。

1. 資料預處理torchvision.transforms

pytorch使用torchvision.transforms實現資料的預處理,包括中心化(torchvision.transforms.CenterCrop)、隨機剪下(torchvision.transforms.RandomCrop)、正則化、圖片變為Tensor、tensor變為圖片等,建議整體瀏覽一下這一部分的官方手冊,非常有用,資料處理很方便。

先轉換為張量,然後正則化:

import torchvision.transforms as transforms
transform = transforms.Compose
([transforms.ToTensor(), transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))]) #img = transform(img)

2. 資料讀取,構建Dataset子類

如果想要使用自己的資料,則必須自己構建一個torch.utils.data.Dataset的子類去讀取資料。我們的將資料列表放在train.txttest.txt中,將不同型別的資料的路徑放在path.txt中,所以在類的init函式中有path_file和 list_file兩個變數

在定義torch.utils.data.Dataset的子類時,必須過載的兩個函式是lengetitem:
- len返回資料集的大小
- getitem實現資料集的下標索引,返回對應的影象和標記(不一定非得返回影象和標記,返回元組的長度可以是任意長,這由網路需要的資料決定)。

末尾有自己寫的一個Dataset子類的定義檔案。

3. 資料載入

torch.utils.data.DataLoader()函式,合成數據並且提供迭代訪問。主要由兩部分組成:
- dataset(Dataset)。輸入載入的資料,就是上面的MyDataset的實現。
- batch_size, shuffle, sampler, batch_sampler, num_worker, collate_fn, pin_memory, drop_last, timeout等引數,介紹幾個比較常用的,這些在官方網站都有:

- batch-size。樣本每個batch的大小,預設為1。
- shuffle。是否打亂資料,預設為False。
- num_workers。資料分為幾個執行緒處理預設為0。
- sampler。定義一個方法來繪製樣本資料,如果定義該方法,則不能使用shuffle。預設為False

使用:

import torch
from datagen import MyDataset

trainset = MyDataset(path_file=pathFile,list_file=trainList,numJoints = 6,type=False)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=8)
testset = MyDataset(path_file=pathFile,list_file=testList,numJoints = 6,type=False)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False, num_workers=8)
以下是定義class MyDataset檔案datagen.py, 其中有__init__(self, path_file, list_file,numJoints,type)__getitem__(self, idx)__len__(self)三個函式,__getitem__返回一個(22,256,256)的輸入和一個(6,256,256)的標籤。
'''
Load data
'''

import numpy as np
from PIL import Image
#import cv2

import torch
import torch.utils.data as data
import torchvision.transforms as transforms

class MyDataset(data.Dataset):

    def __init__(self, path_file, list_file,numJoints,type):
        '''
        Args:
          path_file: (str) heatmap and optical file location
          list_file: (str) path to index file.
          numJoints: (int) number of joints
          type: (boolean) use pose flow(true) or optical flow(false)
        '''

        self.numJoints = numJoints

        # read heatmap and optical path
        with open(path_file) as f:
            paths = f.readlines()

        for path in paths:
            splited = path.strip().split()
            if splited[0]=='resPath':
                self.resPath = splited[1]
            elif splited[0]=='gtPath':
                self.gtPath = splited[1]
            elif splited[0]=='opticalFlowPath':
                self.opticalFlowPath = splited[1]
            elif splited[0]=='poseFlowPath':
                self.poseFlowPath = splited[1]
        if type:
            self.flowPath = self.poseFlowPath
        else:
            self.flowPath = self.opticalFlowPath


        #read list
        with open(list_file) as f:
            self.list = f.readlines()
            self.num_samples = len(self.list)

def __getitem__(self, idx):
    '''
    load heatmaps and optical flow and encode it to a 22 channels input and 6 channels output
    :param idx: (int) image index
    :return:
        input: a 22 channel input which integrate 2 optical flow and heatmaps of 3 image
        output: the ground truth

    '''

    input = []
    output = []
    # load heatmaps of 3 image
    for im in range(3):
        for map in range(6):
            curResPath = self.resPath + self.list[idx].rstrip('\n') + str(im + 1) + '/' + str(map + 1) + '.bmp'
            heatmap = Image.open(curResPath)
            heatmap.load()
            heatmap = np.asarray(heatmap, dtype='float') / 255
            input.append(heatmap)
    # load 2 flow
    for flow in range(2):
        curFlowXPath = self.flowPath + self.list[idx].rstrip('\n') + 'flowx/' + str(flow + 1) + '.jpg'
        flowX = Image.open(curFlowXPath)
        flowX.load()
        flowX = np.asarray(flowX, dtype='float')
        curFlowYPath = self.flowPath + self.list[idx].rstrip('\n') + 'flowy/' + str(flow + 1) + '.jpg'
        flowY = Image.open(curFlowYPath)
        flowY.load()
        flowY = np.asarray(flowY, dtype='float')
        input.append(flowX)
        input.append(flowY)
    # load groundtruth
    for map in range(6):
        curgtPath = self.resPath + self.list[idx].rstrip('\n') + str(2) + '/' + str(map + 1) + '.bmp'
        heatmap = Image.open(curResPath)
        heatmap.load()
        heatmap = np.asarray(heatmap, dtype='float') / 255
        output.append(heatmap)

    input = torch.Tensor(input)
    output = torch.Tensor(output)

    return input,output



def __len__(self):
    return self.num_samples