
# PyTorch Custom Datasets

## Prepare the Data

Prepare the [COCO128](https://www.kaggle.com/ultralytics/coco128) dataset, which consists of the first 128 images of [COCO](https://cocodataset.org) train2017. Organize the directories the way YOLOv5 expects:

```bash
$ tree ~/datasets/coco128 -L 2
/home/john/datasets/coco128
├── images
│   └── train2017
│       ├── ...
│       └── 000000000650.jpg
├── labels
│   └── train2017
│       ├── ...
│       └── 000000000650.txt
├── LICENSE
└── README.txt
```

See [Train Custom Data](https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data) for details.

## Define the Dataset

`torch.utils.data.Dataset` is an abstract class representing a dataset. To define a custom dataset, subclass `Dataset` and override the following methods:

- `__len__`: `len(dataset)` returns the size of the dataset.
- `__getitem__`: `dataset[i]` returns the `i`-th sample.

For details, see:

- [torch.utils.data.Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset)
- [torchvision.datasets.vision.VisionDataset](https://github.com/pytorch/vision/blob/master/torchvision/datasets/vision.py)

An example implementation of a YOLOv5-style dataset:

```python
import os
from pathlib import Path
from typing import Any, Callable, Optional, Tuple

import numpy as np
import torch
import torchvision
from PIL import Image


class YOLOv5(torchvision.datasets.vision.VisionDataset):

    def __init__(
        self,
        root: str,
        name: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super(YOLOv5, self).__init__(root, transforms, transform, target_transform)
        images_dir = Path(root) / 'images' / name
        labels_dir = Path(root) / 'labels' / name
        self.images = [n for n in images_dir.iterdir()]
        # Pair each image with its label file; images without labels get None.
        self.labels = []
        for image in self.images:
            base, _ = os.path.splitext(os.path.basename(image))
            label = labels_dir / f'{base}.txt'
            self.labels.append(label if label.exists() else None)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        img = Image.open(self.images[idx]).convert('RGB')

        label_file = self.labels[idx]
        if label_file is not None:  # found
            # Each line: class x_center y_center width height (normalized).
            with open(label_file, 'r') as f:
                labels = [x.split() for x in f.read().strip().splitlines()]
                labels = np.array(labels, dtype=np.float32)
        else:  # missing
            labels = np.zeros((0, 5), dtype=np.float32)

        # Convert normalized xywh to pixel-space [x1, y1, x2, y2].
        boxes = []
        classes = []
        for label in labels:
            x, y, w, h = label[1:]
            boxes.append([
                (x - w/2) * img.width,
                (y - h/2) * img.height,
                (x + w/2) * img.width,
                (y + h/2) * img.height])
            classes.append(label[0])

        target = {}
        target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
        target["labels"] = torch.as_tensor(classes, dtype=torch.int64)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self) -> int:
        return len(self.images)
```

The implementation above subclasses `VisionDataset`. Its `__getitem__` returns:

- image: a PIL Image of size `(H, W)`
- target: a `dict` with the following fields:
  - `boxes` (`FloatTensor[N, 4]`): the ground-truth boxes `[x1, y1, x2, y2]`, with `x` in `[0, W]` and `y` in `[0, H]`
  - `labels` (`Int64Tensor[N]`): the class label of each box

## Read the Dataset

```python
dataset = YOLOv5(Path.home() / 'datasets/coco128', 'train2017')

print(f'dataset: {len(dataset)}')
print(f'dataset[0]: {dataset[0]}')
```

Output:

```bash
dataset: 128
datas
```
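Each line of a YOLO label file stores `class x_center y_center width height` in coordinates normalized to `[0, 1]`, and `__getitem__` above scales them into pixel-space `[x1, y1, x2, y2]` boxes. A minimal sketch of that conversion, using illustrative values not taken from the dataset:

```python
# Illustrative YOLO label row: class 0, a centered box spanning 20% of
# each side of a 100x100 image (all coordinates normalized to [0, 1]).
cls, x, y, w, h = 0, 0.5, 0.5, 0.2, 0.2
width, height = 100, 100

# Same arithmetic as __getitem__: center/size -> corner coordinates.
box = [(x - w / 2) * width, (y - h / 2) * height,
       (x + w / 2) * width, (y + h / 2) * height]
print(box)
```

The resulting box is `[40, 40, 60, 60]` up to floating-point rounding.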
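To iterate over the dataset in batches, note that each sample's `target` holds a different number of boxes, so the default `DataLoader` collation cannot stack targets into one tensor; detection pipelines typically pass a custom `collate_fn` that keeps samples as tuples. A self-contained sketch — the toy dataset below merely mimics the `(img, target)` structure returned above and is not part of the original post:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDetection(Dataset):
    """Stand-in for the YOLOv5 dataset: image tensor + variable-size target."""

    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx):
        img = torch.zeros(3, 64, 64)
        n = idx + 1  # sample idx has idx + 1 boxes
        target = {
            "boxes": torch.zeros(n, 4),
            "labels": torch.zeros(n, dtype=torch.int64),
        }
        return img, target


def collate_fn(batch):
    # Keep images and targets as tuples instead of stacking the targets,
    # since each target holds a different number of boxes.
    return tuple(zip(*batch))


loader = DataLoader(ToyDetection(), batch_size=2, collate_fn=collate_fn)
images, targets = next(iter(loader))
print(len(images), targets[0]["boxes"].shape, targets[1]["boxes"].shape)
```

The same `collate_fn` works unchanged with the `YOLOv5` class defined earlier.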