1. 程式人生 > >Pytorch打怪路(三)Pytorch建立自己的資料集1

Pytorch打怪路(三)Pytorch建立自己的資料集1

之前講的例子,程式都是呼叫的datasets方法,下載的torchvision本身就提供的資料,那麼如果想匯入自己的資料應該怎麼辦呢?

本篇就講解一下如何建立自己的資料集。

1.用於分類的資料集

以mnist資料集為例

這裡的mnist資料集並不是torchvision裡面的,而是我自己的以圖片格式儲存的資料集,因為我在測試STN時,希望自己再把這些手寫體做一些形變,

所以就先把MNIST資料集轉化成了jpg圖片格式,然後做了一些形變,當然這不是重點。首先我們看一下我的資料集的情況:

如圖所示,我的圖片資料集確實是jpg圖片

再看我的儲存圖片名和label資訊的文字:

如圖所示,我的mnist.txt文字每一行分為兩部分,第一部分是具體路徑+圖片名.jpg

第二部分就是label資訊,因為前面這部分圖片都是0 ,所以他們的分類的label資訊就是0

要建立你自己的 用於分類的 資料集,也要包含上述兩個部分,1.圖片資料集,2.文字資訊(這個txt檔案可以用python或者C++輕易建立,再此不詳述)

2.程式碼

主要程式碼

from PIL import Image
import torch

class MyDataset(torch.utils.data.Dataset): #建立自己的類:MyDataset,這個類是繼承的torch.utils.data.Dataset
    def __init__(self,root, datatxt, transform=None, target_transform=None): #初始化一些需要傳入的引數
        fh = open(root + datatxt, 'r') #按照傳入的路徑和txt文字引數,開啟這個文字,並讀取內容
        imgs = []                      #建立一個名為img的空列表,一會兒用來裝東西
        for line in fh:                #按行迴圈txt文字中的內容
            line = line.rstrip()       # 刪除 本行string 字串末尾的指定字元,這個方法的詳細介紹自己查詢python
            words = line.split()   #通過指定分隔符對字串進行切片,預設為所有的空字元,包括空格、換行、製表符等
            imgs.append((words[0],int(words[1]))) #把txt裡的內容讀入imgs列表儲存,具體是words幾要看txt內容而定
                                        # 很顯然,根據我剛才截圖所示txt的內容,words[0]是圖片資訊,words[1]是lable
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):    #這個方法是必須要有的,用於按照索引讀取每個元素的具體內容
        fn, label = self.imgs[index] #fn是圖片path #fn和label分別獲得imgs[index]也即是剛才每行中word[0]和word[1]的資訊
        img = Image.open(root+fn).convert('RGB') #按照path讀入圖片from PIL import Image # 按照路徑讀取圖片

        if self.transform is not None:
            img = self.transform(img) #是否進行transform
        return img,label  #return很關鍵,return回哪些內容,那麼我們在訓練時迴圈讀取每個batch時,就能獲得哪些內容

    def __len__(self): #這個函式也必須要寫,它返回的是資料集的長度,也就是多少張圖片,要和loader的長度作區分
        return len(self.imgs)

#根據自己定義的那個勒MyDataset來建立資料集!注意是資料集!而不是loader迭代器
train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
#然後就是呼叫DataLoader和剛剛建立的資料集,來建立dataloader,這裡提一句,loader的長度是有多少個batch,所以和batch_size有關
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)

再補充一點程式碼,以便更好的理解 __getitem__這個方法

for batch_index, data, target in test_loader:
        if use_cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)

這段程式碼是我從測試的部分中截取出來的,為什麼直接能用for data, target In test_loader這樣的語句呢?

其實這個語句還可以這麼寫:

for batch_index, batch in train_loader

        data, target = batch

這樣就好理解了,因為這個迭代器每一次迴圈所得的batch裡面裝的東西,就是我在__getitem__方法最後return回來的

所以你想在訓練或者測試的時候還得到其他資訊的話,就去增加一些返回值即可,只要是能return出來的,就能在每個batch中讀取到!

###############################################################################

有朋友可能想問,如果我的label資訊不是數字而是影象呢?比如分割任務,它的label就是影象,這樣的資料集的建立,也參考我的下一篇博文: