1. 程式人生 > >Pytorch打怪路(三)Pytorch建立自己的資料集2

Pytorch打怪路(三)Pytorch建立自己的資料集2

前面一篇寫建立資料集的博文--- 是介紹的應用於影象分類任務的資料集,即輸入為一個影象和它的類別數字標籤本篇介紹輸入的標籤label亦為影象的資料集,幷包含一些常用的處理手段。比如做影象語義分割時就會用到這種資料輸入方式。

1、資料集簡介

以VOC2012資料集為例,影象是RGB3通道的,label是1通道的,(其實label原來是幾通道的無所謂,只要讀取的時候轉化成灰度圖就行)。

訓練資料:

語義label:

這裡我們看到label圖片都是黑色的,只有白色的輪廓而已

其實是因為label圖片裡的畫素值取值範圍是0 ~ 20,即畫素點可能的類別共有21類(對此資料集來說)

,詳情如下:

所以對於灰度值0---20來說,我們肉眼看上去就確實都是黑色的,因為灰度值太低了,而白色的輪廓的灰度值是255

但是這些邊界在計算損失值的時候是不作為有效值的,也就是對於灰度值=255的點是忽略的。

如果想看的話,可以用一些色彩變換,對0--20這每一個數字對應一個色彩,就能看出來了,示例如下

這不是重點,只是給大家看一下方便理解而已

2、文字資訊

同樣有一個文字來指導我對資料的讀取,我的資訊如下

這其實就是一個記載了影象ID的文字文件,連字尾都沒有,但我們依然可以根據這個去資料集中讀取相應的image和label

3、程式碼示例

這個程式碼是我自己在利用deeplabV2 跑semantic segmentation 任務時

寫的一個,也許寫的並不優美,但反正是可以用的,

可以做個拋磚引玉的目的,對於才入門的朋友,理解這個思路就可,不必照搬我的程式碼風格……

import os
import numpy as np
import random
import matplotlib.pyplot as plt
import collections
import torch
import torchvision
import cv2
from PIL import Image
import torchvision.transforms as transforms
from torch.utils import data

class VOCDataSet(data.Dataset):
    def __init__(self, root, list_path,  crop_size=(321, 321), mean=(104.008, 116.669, 122.675), mirror=True, scale=True, ignore_label=255):
        super(VOCDataSet,self).__init__()
        self.root = root
        self.list_path = list_path
        self.crop_h, self.crop_w = crop_size
        self.ignore_label = ignore_label
        self.mean = np.asarray(mean, np.float32)
        self.is_mirror = mirror
        self.is_scale = scale

        self.img_ids = [i_id.strip() for i_id in open(list_path)]

        self.files = []
        for name in self.img_ids:
            img_file = os.path.join(self.root, "JPEGImages/%s.jpg" % name)
            label_file = os.path.join(self.root, "SegmentationClassAug/%s.png" % name)
            self.files.append({
                "img": img_file,
                "label": label_file,
                "name": name
            })

    def __len__(self):
        return len(self.files)


    def __getitem__(self, index):
        datafiles = self.files[index]

        '''load the datas'''
        name = datafiles["name"]
        image = Image.open(datafiles["img"]).convert('RGB')
        label = Image.open(datafiles["label"]).convert('L')
        size_origin = image.size # W * H

        '''random scale the images and labels'''
        if self.is_scale: #如果我在定義dataset時選擇了scale=True,就執行本語句對尺度進行隨機變換
            ratio = 0.5 + random.randint(0, 11) // 10.0 #0.5~1.5
            out_h, out_w = int(size_origin[1]*ratio), int(size_origin[0]*ratio)
            # (H,W)for Resize
            image = transforms.Resize((out_h, out_w), Image.LANCZOS)(image)
            label = transforms.Resize((out_h, out_w), Image.NEAREST)(label)

        '''pad the inputs if their size is smaller than the crop_size'''
        pad_w = max(self.crop_w - out_w, 0)
        pad_h = max(self.crop_h - out_h, 0)
        img_pad = transforms.Pad( padding=(0,0,pad_w,pad_h), fill=0, padding_mode='constant')(image)
        label_pad = transforms.Pad( padding=(0,0,pad_w,pad_h), fill=self.ignore_label, padding_mode='constant')(label)
        out_size = img_pad.size

        '''random crop the inputs'''
        if (self.crop_h != 0 or self.crop_w != 0):
            #select a random start-point for croping operation
            h_off = random.randint(0, out_size[1] - self.crop_h)
            w_off = random.randint(0, out_size[0] - self.crop_w)
            #crop the image and the label
            image = img_pad.crop((w_off,h_off, w_off+self.crop_w, h_off+self.crop_h))
            label = label_pad.crop((w_off,h_off, w_off+self.crop_w, h_off+self.crop_h))

        '''mirror operation'''
        if self.is_mirror:
            if np.random.random() < 0.5:
                #0:FLIP_LEFT_RIGHT, 1:FLIP_TOP_BOTTOM, 2:ROTATE_90, 3:ROTATE_180, 4:or ROTATE_270.
                image = image.transpose(0)
                label = label.transpose(0)

        '''convert PIL Image to numpy array'''
        I = np.asarray(image,np.float32) - self.mean
        I = I.transpose((2,0,1))#transpose the  H*W*C to C*H*W
        L = np.asarray(np.array(label), np.int64)
        #print(I.shape,L.shape)
        return I.copy(), L.copy(), np.array(size_origin), name

#這是一個測試函式,也即我的程式碼寫好後,如果直接python運行當前py檔案,就會執行以下程式碼的內容,以檢測我上面的程式碼是否有問題,這其實就是方便我們除錯,而不是每次都去run整個網路再看哪裡報錯
if __name__ == '__main__':
    DATA_DIRECTORY = '/home/teeyo/STA/Data/voc_aug/'
    DATA_LIST_PATH = '../dataset/list/val.txt'
    Batch_size = 4
    MEAN = (104.008, 116.669, 122.675)
    dst = VOCDataSet(DATA_DIRECTORY,DATA_LIST_PATH, mean=(0,0,0))
    # just for test,  so the mean is (0,0,0) to show the original images.
    # But when we are training a model, the mean should have another value
    trainloader = data.DataLoader(dst, batch_size = Batch_size)
    plt.ion()
    for i, data in enumerate(trainloader):
        imgs, labels,_,_= data
        if i%1 == 0:
            img = torchvision.utils.make_grid(imgs).numpy()
            img = img.astype(np.uint8) #change the dtype from float32 to uint8, because the plt.imshow() need the uint8
            img = np.transpose(img, (1, 2, 0))#transpose the Channels*H*W to  H*W*Channels
            #img = img[:, :, ::-1]
            plt.imshow(img)
            plt.show()
            plt.pause(0.5)

            #input()

我個人覺得我應該註釋的地方都有相應的註釋,雖然有點長, 因為實現了crop和翻轉以及scale等功能,但是大家可以下去慢慢揣摩,理解其中的主要思路,與我前一篇的博文做對比,那篇博文相當於是提供了最基本的骨架,而這篇就在骨架上長肉生髮而已,有疑問的歡迎評論探討~~