Pytorch打怪路(三)Pytorch建立自己的資料集1
之前講的例子,程式都是呼叫的datasets方法,下載的torchvision本身就提供的資料,那麼如果想匯入自己的資料應該怎麼辦呢?
本篇就講解一下如何建立自己的資料集。
1.用於分類的資料集
以mnist資料集為例
這裡的mnist資料集並不是torchvision裡面的,而是我自己的以圖片格式儲存的資料集,因為我在測試STN時,希望自己再把這些手寫體做一些形變,
所以就先把MNIST資料集轉化成了jpg圖片格式,然後做了一些形變,當然這不是重點。首先我們看一下我的資料集的情況:
如圖所示,我的圖片資料集確實是jpg圖片
再看我的儲存圖片名和label資訊的文字:
如圖所示,我的mnist.txt文字每一行分為兩部分,第一部分是具體路徑+圖片名.jpg
第二部分就是label資訊,因為前面這部分圖片都是0 ,所以他們的分類的label資訊就是0
要建立你自己的 用於分類的 資料集,也要包含上述兩個部分,1.圖片資料集,2.文字資訊(這個txt檔案可以用python或者C++輕易建立,再此不詳述)
2.程式碼
主要程式碼
from PIL import Image import torch class MyDataset(torch.utils.data.Dataset): #建立自己的類:MyDataset,這個類是繼承的torch.utils.data.Dataset def __init__(self,root, datatxt, transform=None, target_transform=None): #初始化一些需要傳入的引數 fh = open(root + datatxt, 'r') #按照傳入的路徑和txt文字引數,開啟這個文字,並讀取內容 imgs = [] #建立一個名為img的空列表,一會兒用來裝東西 for line in fh: #按行迴圈txt文字中的內容 line = line.rstrip() # 刪除 本行string 字串末尾的指定字元,這個方法的詳細介紹自己查詢python words = line.split() #通過指定分隔符對字串進行切片,預設為所有的空字元,包括空格、換行、製表符等 imgs.append((words[0],int(words[1]))) #把txt裡的內容讀入imgs列表儲存,具體是words幾要看txt內容而定 # 很顯然,根據我剛才截圖所示txt的內容,words[0]是圖片資訊,words[1]是lable self.imgs = imgs self.transform = transform self.target_transform = target_transform def __getitem__(self, index): #這個方法是必須要有的,用於按照索引讀取每個元素的具體內容 fn, label = self.imgs[index] #fn是圖片path #fn和label分別獲得imgs[index]也即是剛才每行中word[0]和word[1]的資訊 img = Image.open(root+fn).convert('RGB') #按照path讀入圖片from PIL import Image # 按照路徑讀取圖片 if self.transform is not None: img = self.transform(img) #是否進行transform return img,label #return很關鍵,return回哪些內容,那麼我們在訓練時迴圈讀取每個batch時,就能獲得哪些內容 def __len__(self): #這個函式也必須要寫,它返回的是資料集的長度,也就是多少張圖片,要和loader的長度作區分 return len(self.imgs) #根據自己定義的那個勒MyDataset來建立資料集!注意是資料集!而不是loader迭代器 train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor()) test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
#然後就是呼叫DataLoader和剛剛建立的資料集,來建立dataloader,這裡提一句,loader的長度是有多少個batch,所以和batch_size有關
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)
再補充一點程式碼,以便更好的理解 __getitem__這個方法
for batch_index, data, target in test_loader: if use_cuda: data, target = data.cuda(), target.cuda() data, target = Variable(data, volatile=True), Variable(target)
這段程式碼是我從測試的部分中截取出來的,為什麼直接能用for data, target In test_loader這樣的語句呢?
其實這個語句還可以這麼寫:
for batch_index, batch in train_loader
data, target = batch
這樣就好理解了,因為這個迭代器每一次迴圈所得的batch裡面裝的東西,就是我在__getitem__方法最後return回來的,
所以你想在訓練或者測試的時候還得到其他資訊的話,就去增加一些返回值即可,只要是能return出來的,就能在每個batch中讀取到!
###############################################################################