1. 程式人生 > >【pytorch原始碼賞析】Dataset in pytorch

【pytorch原始碼賞析】Dataset in pytorch

1. 原始碼概覽

pytorch是眾多dl工具中,比較python風格化的一種,另一個完全python化的dl工具是chainer,它的構建語言中只有python,甚至cuda也是從python端呼叫的。python風格化的好處是,使用了很多python的語言特性,讓程式碼更加簡潔,更高效。《python高階程式設計》的第2、3章,描述了部分python的高階語言特性,比如:列表推導,迭代器和生成器,裝飾器等。這些trick讓程式碼更加python化,可讀性更強,也更健壯。

pytorch的資料集部分,從原始碼可以看出,提供了2個主要的類:Dataset,DataLoader。

Dataset為抽象類,定義了兩個行為:__getitem__和__len__。也就是任何資料集,都可以len(dataset)獲得樣本的數量,dataset[i]獲得其中第i個樣本。派生了兩個類:TensorDataset,當x和y是pytorch的tensor時,可以方便地匯入;另一個ConcatDataset,用於合併多個數據集(對於實際應用特別有用)。

DataLoader是更核心的類,使用者用它來獲得每次batch的訓練資料。

dataloader.py中有2個類,DataLoader和DataLoaderIter。

DataLoader提供如下功能:
1. 儲存了dataset
2. 具有sample行為
3. 提供單執行緒/多執行緒來獲取資料集中的資料(程式碼主要實現的功能)

DataLoader有2個行為:__iter__和__len__。而__iter__這個迭代器,程式碼如下:

def __iter__(self):
    return DataLoaderIter(self)

返回的正是DataLoaderIter。DataLoaderIter的功能是,根據sample指定的方法,獲取訓練樣本。sample方法有SequentialSampler, RandomSampler, BatchSampler這三種,其實是兩種:SequentialSampler和RandomSampler。如果指定了shuffle,則是隨機取樣,否則是序列取樣,然後都會使用BatchSample。

DataLoaderIter具有3個行為:__iter__,__len__和__next__。每次使用next(dataLoaderIter)來獲得一個batch。

__iter__總是和__next__一起使用,__iter__表明這個類是可以迭代的,__next__表明每次迭代的具體行為,一個例子如下:

class Testing:
    def __init__(self,a,b):
        self.a = a
        self.b = b
    def __iter__ (self):
        print('itering')
        return
self def next(self): print('nexting') if self.a <= self.b: self.a += 1 return self.a-1 else: raise StopIteration myObj = Testing(1,5) for i in myObj: print i
itering
nexting
1
nexting
2
nexting
3
nexting
4
nexting
5
nexting

2. 使用方法

使用pytorch提供的方法操作資料集,一般分兩步:
1. 繼承Dataset,實現__getitem____len__方法。
2. 例項化DataLoader,一般需要指定自己的collate_fn方法。

而這正是程式碼優美的地方,把“讀取資料集”這個任務完美地解耦和,使用者只需要針對不同的資料集派生Dataset類,實現2個方法。DataLoader負責瞭如何讀取訓練樣本的行為,只需要例項化即可,還可以通過設定collate_fn定製化自己的具體讀取行為。