Pytorch Dataset & Dataloader

Pytorch框架下的工具包中,提供了資料處理的兩個重要介面,Dataset 和 Dataloader,能夠方便的使用和載入自己的資料集。

  1. 資料的預處理,載入資料並轉化為tensor格式

  2. 使用Dataset構建自己的資料

  3. 使用Dataloader裝載資料

【資料】連結:https://pan.baidu.com/s/1gdWFuUakuslj-EKyfyQYLA

提取碼:10d4

複製這段內容後開啟百度網盤手機App,操作更方便哦

資料的預處理與載入

import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset ## 1. 資料的處理,載入轉化為tensor
x_data = 'X.csv'
y_data = 'y.csv'
x = np.loadtxt(x_data, delimiter=' ', dtype=np.float32)
y = np.loadtxt(y_data, delimiter=' ', dtype=np.float32).reshape(-1, 1)
x = torch.from_numpy(x[:, :])
y = torch.from_numpy(y[:, :])

torch.utils.data.Dataset

Dataset抽象類,用於包裝構建自己的資料集,該類包括三個基本的方法:

  • __init__ 進行資料的讀取操作
  • __getitem__ 資料集需支援索引訪問
  • __len__ 返回資料集的長度
## 2. 構建自己的資料集
class Mydataset(Dataset):
def __init__(self, train_data, label_data):
self.train = train_data
self.label = label_data
self.len = len(train_data) def __getitem__(self, item):
return self.train[item], self.label[item] def __len__(self):
return self.len dataset = Mydataset(x, y)
samples = dataset.__len__()
print("總樣本數:",samples)

torch.utils.data.Dataloader

Dataloader抽象類,構建可迭代的資料集裝載器,從Dataset例項物件中按batch_size裝載資料以送入訓練。包含以下幾個引數:

  • batch_size 批大小
  • shuffle 裝載的batch是否亂序
  • drop_last 不足batch大小的最後部分是否捨去
  • num_workers 是否多程序讀取資料
## 3. 建立資料集裝載器
train_loader = DataLoader(dataset=dataset,
batch_size=64,
shuffle=True,
drop_last=True,
num_workers=4)

測試

if __name__ == "__main__":
iteration = 0
for train_data, train_label in train_loader:
print("x: ", train_data, "\ny: ", train_label)
iteration += 1
### 這裡dataloader中drop_last為True,所以迭代次數應為 samples/batch_size = 6
print("每個epoch迭代次數:",iteration)

完整程式碼

import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset ## 1. 資料的處理,載入轉化為tensor
x_data = 'X.csv'
y_data = 'y.csv'
x = np.loadtxt(x_data, delimiter=' ', dtype=np.float32)
y = np.loadtxt(y_data, delimiter=' ', dtype=np.float32).reshape(-1, 1)
x = torch.from_numpy(x[:, :])
y = torch.from_numpy(y[:, :]) ## 2. 構建自己的資料集
class Mydataset(Dataset):
def __init__(self, train_data, label_data):
self.train = train_data
self.label = label_data
self.len = len(train_data) def __getitem__(self, item):
return self.train[item], self.label[item] def __len__(self):
return self.len dataset = Mydataset(x, y) ## 3. 建立資料集裝載器
train_loader = DataLoader(dataset=dataset,
batch_size=64,
shuffle=True,
drop_last=True,
num_workers=4) if __name__ == "__main__":
iteration = 0
samples = dataset.__len__()
print("總樣本數:", samples)
for train_data, train_label in train_loader:
print("x: ", train_data, "\ny: ", train_label)
iteration += 1
### 這裡dataloader中drop_last為True,所以迭代次數應為 samples/batch_size = 6
print("每個epoch迭代次數:",iteration)