1. 程式人生 > 關於PyTorch影像處理模組的資料處理

關於PyTorch影像處理模組的資料處理

文章參考:chsasank.github.io

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
# Image preprocessing pipelines, keyed by dataset split.
# 'train' applies random augmentation (random resized crop to 224 + random
# horizontal flip); 'val' applies a deterministic resize to 256 followed by a
# 224 centre crop. Both splits then convert the PIL image to a tensor and
# normalize it per channel with the mean/std values commonly used for
# ImageNet-pretrained torchvision models.
data_transforms = {
    split: transforms.Compose(
        head + [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    for split, head in (
        ('train', [transforms.RandomResizedCrop(224),
                   transforms.RandomHorizontalFlip()]),
        ('val', [transforms.Resize(256),
                 transforms.CenterCrop(224)]),
    )
}
# Root directory of the image data. It holds one sub-folder per split
# ('train', 'val'), each containing one sub-folder per class, e.g.
# data/train/ants/0013035.jpg -- the layout ImageFolder expects.
data_dir = 'data'

# One ImageFolder dataset per split, each using that split's transform pipeline.
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}

# Mapping from class (folder) name to integer label, e.g. {'ants': 0, 'bees': 1}.
# Reuse the dataset built above rather than re-scanning a hard-coded
# 'data/train' path, so this stays correct if data_dir changes.
print(image_datasets['train'].class_to_idx)
# `.imgs` is a list of (path, label) pairs, e.g.
# ('data/train/ants/0013035.jpg', 0):
# print(image_datasets['train'].imgs)

# Class names indexed by label, e.g. ['ants', 'bees'].
class_names = image_datasets['train'].classes

# Number of images in each split: {'train': <count>, 'val': <count>}.
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# Wrap each dataset in a DataLoader that yields shuffled mini-batches of 4
# (image tensor, label) samples, loaded by 4 worker processes. The tensor
# conversion itself is done by the transforms above; the loader only batches.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}

if __name__ == '__main__':
    # Iterating a DataLoader with num_workers > 0 starts worker processes.
    # On platforms whose multiprocessing start method is "spawn"
    # (e.g. Windows), each worker re-imports this module. The iteration must
    # therefore sit behind this guard, or the program fails while starting
    # the workers.

    # A DataLoader is iterable but not indexable (`dataloaders['train'][0]`
    # raises TypeError). Take a single batch with next(iter(...)) instead of
    # dumping every batch of the training set.
    inputs, labels = next(iter(dataloaders['train']))
    # With the 'train' transforms (RandomResizedCrop(224)) and batch_size=4,
    # `inputs` holds up to 4 images of 224x224 -- shape (4, 3, 224, 224) for
    # 3-channel (RGB) images. `labels` holds their integer class indices.
    print(inputs)
    print(labels)

    # Show the first 10 (path, label) entries of the training set together
    # with each image's class name (class_names[label] maps label -> folder).
    for path, label in image_datasets['train'].imgs[:10]:
        print((path, label), class_names[label])

    # Total number of training images.
    print(len(image_datasets['train'].imgs))