1. 程式人生 > >PyTorch(一)——資料處理

PyTorch(一)——資料處理

PyTorch學習和使用(一)

要使學會用一個框架,只會執行其測試實驗是不行的,所以現在打算把caffe中的Siamese模型使用PyTorch實現,來鞏固自己對PyTroch的熟練使用。

資料預處理

首先是資料處理這一塊,PyTorch使用了torchvision來完成資料的處理,其只實現了一些資料集的處理,如果處理自己的工程則需要修改增加內容。

把原始資料處理為模型使用的資料需要3步:transforms.Compose() torchvision.datasets torch.utils.data.DataLoader()分別可以理解為資料處理格式的定義、資料處理和資料載入。

Compose()

程式碼中給出的解釋是Composes several transforms together. 就是通過Compose把一些對影象處理的方法集中起來。比如先中心化,然後轉換為張量(PyTorch的資料結構),其程式碼為:transforms.Compose([transform.CenterCrop(10), transofrms.ToTensor()])又比如先轉換為張量,然後正則化,程式碼為:`transforms.Compose([transofrms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]), 其具體的引數呼叫在原始碼中可以看到,在此不多說了。還要注意的是Compose的程式碼是:

def __call_(self, img):
    for t in self.transforms:
        img = t(img)
    return img

這就是把輸入到Compose的操作按順序進行執行。先執行第一個,然後第二個……。如果需要處理自己的資料,可以把具體的操縱放在這個類中實現。

torchvision.datasets ;裡實現了不同針對資料集的處理方法,主要用來載入資料和處理資料。比如在mnist.py 和cifar.py 中用來處理mnist和cifar資料集。類的實現需要繼承父類data.Dataset,其主要方法有2個:

  • __init__(self, root, train=Ture, transform=None, traget_transform=None, download=False):

    該方法用來初始化類和對資料進行載入(有時需要定義一些開關來防止重複處理)。資料的載入就是針對不同的資料,把其data和label(分為訓練資料和測試資料)讀入到記憶體中。

  • --getitem__(self, index):該方法是把讀入的輸出傳給PyTorch(迭代器的方式)。**注意:**上面定義的transform.Compose在次數進行呼叫,通過index確定需要訪問的資料,然後對其格式進行轉換,最後返回處理後的資料。也就是說資料在定義時只是定義了一個類,其具體的資料傳出在需要使用時使用該方法完成。

至此,對資料進行載入,然後處理傳給PyTorch已經完成,如果需要對自己的資料進行處理,也是通過修改和增加此部分完成。接下來需要對訓練的資料進行處理,比如分批次的大小,十分隨機處理等等。

torch.utils.data.DataLoader() Data loder, Combines a dataset and and a sampler, and provides single, or multi-process iterators over the dataset. 就是把合成數據並且提供迭代訪問。輸入引數有:

  • dataset(Dataset)。輸入載入的資料,就是上面的torchvision.datasets.myData()的實現,所以需要繼承data.Dataset,滿足此介面。

  • **batch-size, shuffle, sampler, num_workers, collate_fn, pin_memory, drop_last.**這些引數比較好理解,看名字就知道其作用了。分別為:

  1. batch-size。樣本每個batch的大小,預設為1。
  2. shuffle。是否打亂資料,預設為False。
  3. sampler。定義一個方法來繪製樣本資料,如果定義該方法,則不能使用shuffle。
  4. num_workers。資料分為幾批處理(對於大資料)。
  5. collate_fn。整理資料,把每個batch資料整理為tensor。(一般使用預設呼叫default_collate(batch))。
  6. pin_memory。針對不同型別的batch進行處理。比如為Map或者Squence等型別,需要處理為tensor型別。
  7. drop_last。用於處理最後一個batch的資料。因為最後一個可能不能夠被整除,如果設定為True,則捨棄最後一個,為False則保留最後一個,但是最後一個可能很小。

迭代器(DataLoaderIter)的具體處理就是根據這些引數的設定,分別進行不同的處理。

補充2017/8/10:

torch.utils.data.DataLoader類主要使用torch.utils.data.sampler實現,sampler是所有采樣器的基礎類,提供了迭代器的迭代(__iter__)和長度(__len__)介面實現,同時sampler也是通過索引對資料進行洗牌(shuffle)等操作。因此,如果DataLoader不適用於你的資料,需要重新設計資料的分批次,可以充分使用所提供的smapler