Pytorch Dataset & Dataloader
Pytorch框架下的工具包中,提供了資料處理的兩個重要介面,Dataset 和 Dataloader,能夠方便的使用和載入自己的資料集。
資料的預處理,載入資料並轉化為tensor格式
使用Dataset構建自己的資料
使用Dataloader裝載資料
【資料】連結:https://pan.baidu.com/s/1gdWFuUakuslj-EKyfyQYLA
提取碼:10d4
複製這段內容後開啟百度網盤手機App,操作更方便哦
資料的預處理與載入
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
## 1. 資料的處理,載入轉化為tensor
x_data = 'X.csv'
y_data = 'y.csv'
x = np.loadtxt(x_data, delimiter=' ', dtype=np.float32)
y = np.loadtxt(y_data, delimiter=' ', dtype=np.float32).reshape(-1, 1)
x = torch.from_numpy(x[:, :])
y = torch.from_numpy(y[:, :])
torch.utils.data.Dataset
Dataset抽象類,用於包裝構建自己的資料集,該類包括三個基本的方法:
- __init__ 進行資料的讀取操作
- __getitem__ 資料集需支援索引訪問
- __len__ 返回資料集的長度
## 2. 構建自己的資料集
class Mydataset(Dataset):
def __init__(self, train_data, label_data):
self.train = train_data
self.label = label_data
self.len = len(train_data)
def __getitem__(self, item):
return self.train[item], self.label[item]
def __len__(self):
return self.len
dataset = Mydataset(x, y)
samples = dataset.__len__()
print("總樣本數:",samples)
torch.utils.data.Dataloader
Dataloader抽象類,構建可迭代的資料集裝載器,從Dataset例項物件中按batch_size裝載資料以送入訓練。包含以下幾個引數:
- batch_size 批大小
- shuffle 裝載的batch是否亂序
- drop_last 不足batch大小的最後部分是否捨去
- num_workers 是否多程序讀取資料
## 3. 建立資料集裝載器
train_loader = DataLoader(dataset=dataset,
batch_size=64,
shuffle=True,
drop_last=True,
num_workers=4)
測試
if __name__ == "__main__":
iteration = 0
for train_data, train_label in train_loader:
print("x: ", train_data, "\ny: ", train_label)
iteration += 1
### 這裡dataloader中drop_last為True,所以迭代次數應為 samples/batch_size = 6
print("每個epoch迭代次數:",iteration)
完整程式碼
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
## 1. 資料的處理,載入轉化為tensor
x_data = 'X.csv'
y_data = 'y.csv'
x = np.loadtxt(x_data, delimiter=' ', dtype=np.float32)
y = np.loadtxt(y_data, delimiter=' ', dtype=np.float32).reshape(-1, 1)
x = torch.from_numpy(x[:, :])
y = torch.from_numpy(y[:, :])
## 2. 構建自己的資料集
class Mydataset(Dataset):
def __init__(self, train_data, label_data):
self.train = train_data
self.label = label_data
self.len = len(train_data)
def __getitem__(self, item):
return self.train[item], self.label[item]
def __len__(self):
return self.len
dataset = Mydataset(x, y)
## 3. 建立資料集裝載器
train_loader = DataLoader(dataset=dataset,
batch_size=64,
shuffle=True,
drop_last=True,
num_workers=4)
if __name__ == "__main__":
iteration = 0
samples = dataset.__len__()
print("總樣本數:", samples)
for train_data, train_label in train_loader:
print("x: ", train_data, "\ny: ", train_label)
iteration += 1
### 這裡dataloader中drop_last為True,所以迭代次數應為 samples/batch_size = 6
print("每個epoch迭代次數:",iteration)