1. 程式人生 > >pytorch載入pascal&&coco資料集

pytorch載入pascal&&coco資料集

上一篇部落格https://blog.csdn.net/goodxin_ie/article/details/84315458我們詳細介紹了pascal&&coco資料集,本篇我們將介紹pytorch如何載入

一、目標

pascal資料集的資料來源是jpg圖片,便籤是xml檔案,而pytorch運算使用的資料是Tensor。因此我們的目標是將jpg和xml檔案轉化為可供程式運算使用的Tensor或者numpy型別(Tesnor和numpy可以相互轉化)。

回憶一下目標檢測演算法需要的標籤資訊,有類別和bbox框。在pascal資料集中,每張圖片中的物件由xml中的objec標定,每個物件存在類別名name,位置框('ymin', 'xmin', 'ymax', 'xmax'),是否為困難樣本的標記difficult。

二、解析xml檔案

呼叫ElementTree元素樹可以很方便的解析出xml檔案的各種資訊。我們主要使用其中的find方法查詢對應屬性的資訊

ET.findall('object')   #查詢物件
ET.findall('bndbox')   #查詢位置框

完整的解析pasacal中xml檔案程式碼如下:

輸入引數:路徑,檔名,是否使用困難樣本

輸出: bbox,label,difficult   (型別np.float32)

def parseXml(data_dir,id,use_difficult=False):
        anno = ET.parse(
            os.path.join(data_dir, 'Annotations', id + '.xml'))
        bbox = list()
        label = list()
        difficult = list()
        for obj in anno.findall('object'):
            if not use_difficult and int(obj.find('difficult').text) == 1:
                continue
            difficult.append(int(obj.find('difficult').text))
            bndbox_anno = obj.find('bndbox')

            bbox.append([
                int(bndbox_anno.find(tag).text) - 1
                for tag in ('ymin', 'xmin', 'ymax', 'xmax')])
            name = obj.find('name').text.lower().strip()
            label.append(VOC_BBOX_LABEL_NAMES.index(name))
        bbox = np.stack(bbox).astype(np.float32)     #from list to array
        label = np.stack(label).astype(np.int32)

        difficult = np.array(difficult, dtype=np.bool).astype(np.uint8)  # PyTorch don't support np.bool
        return  bbox, label, difficult