Person Re-Identification (ReID): Modifying the Workflow Based on Person_reID_baseline_pytorch

The Market-1501 dataset directory is organized as follows:

├── Market/
│   ├── bounding_box_test/          /* Files for testing (candidate images pool)
│   ├── bounding_box_train/         /* Files for training 
│   ├── gt_bbox/                    /* We do not use it 
│   ├── gt_query/                   /* Files for multiple query testing 
│   ├── query/                      /* Files for testing (query images)
│   ├── readme.txt

Dataset structure after modifying the --test_dir path and running python prepare.py:

├── Market/
│   ├── bounding_box_test/          /* Files for testing (candidate images pool)
│   ├── bounding_box_train/         /* Files for training 
│   ├── gt_bbox/                    /* We do not use it 
│   ├── gt_query/                   /* Files for multiple query testing 
│   ├── query/                      /* Files for testing (query images)
│   ├── readme.txt
│   ├── pytorch/
│       ├── train/                   /* train 
│           ├── 0002
│           ├── 0007
│           ...
│       ├── val/                     /* val
│       ├── train_all/               /* train+val      
│       ├── query/                   /* query files  
│       ├── gallery/                 /* gallery files
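
prepare.py essentially copies every image into a per-identity subfolder so that torchvision's ImageFolder can read the data. Below is a minimal sketch of that split logic for one source folder, assuming standard Market-1501 file names such as 0002_c1s1_000451_03.jpg (the first four characters are the person ID); the paths are the ones used throughout this article, and this is an illustration rather than the repo's exact script:

import os
import shutil

# Assumed paths; adjust to your own --test_dir layout.
src_dir = '/home/hylink/eclipse-workspace/reID/Market/bounding_box_train'
dst_dir = '/home/hylink/eclipse-workspace/reID/Market/pytorch/train_all'

for name in os.listdir(src_dir):
    if not name.endswith('.jpg'):
        continue                                 # skip readme/non-image files
    pid = name[0:4]                              # person ID, e.g. '0002'
    id_dir = os.path.join(dst_dir, pid)          # one folder per identity
    os.makedirs(id_dir, exist_ok=True)
    shutil.copyfile(os.path.join(src_dir, name), os.path.join(id_dir, name))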

Train the model and test it: in train.py and test.py, set the --test_dir path to /home/hylink/eclipse-workspace/reID/Market/pytorch, then run:

python train.py
python test.py
python demo.py --query_index 777

Results:

[retrieval visualization screenshot omitted]

Modify test.py so that, instead of extracting features for both gallery and query, it builds only the gallery feature database:

# -*- coding: utf-8 -*-

from __future__ import print_function, division

import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import os
import scipy.io
from model import ft_net, ft_net_dense, PCB, PCB_test

######################################################################
# Options
# --------
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0  0,1,2  0,2')
parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last')
parser.add_argument('--test_dir',default='/home/hylink/eclipse-workspace/reID/Market/pytorch',type=str, help='./test_data')
parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path')
parser.add_argument('--batchsize', default=32, type=int, help='batchsize')
parser.add_argument('--use_dense', action='store_true', help='use densenet121' )
parser.add_argument('--PCB', action='store_true', help='use PCB' )
parser.add_argument('--multi', action='store_true', help='use multiple query' )

opt = parser.parse_args()

str_ids = opt.gpu_ids.split(',')
#which_epoch = opt.which_epoch
name = opt.name
test_dir = opt.test_dir

gpu_ids = []
for str_id in str_ids:
    id = int(str_id)
    if id >=0:
        gpu_ids.append(id)

# set gpu ids
if len(gpu_ids)>0:
    torch.cuda.set_device(gpu_ids[0])

######################################################################
# Load Data
# ---------
#
# We will use torchvision and torch.utils.data packages for loading the
# data.
#
data_transforms = transforms.Compose([
        transforms.Resize((288,144), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
############### Ten Crop
        #transforms.TenCrop(224),
        #transforms.Lambda(lambda crops: torch.stack(
        #    [transforms.ToTensor()(crop)
        #        for crop in crops]
        #    )),
        #transforms.Lambda(lambda crops: torch.stack(
        #    [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop)
        #        for crop in crops]
        #    ))
])

if opt.PCB:
    data_transforms = transforms.Compose([
        transforms.Resize((384,192), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

data_dir = test_dir

if opt.multi:
    image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query','multi-query']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                             shuffle=False, num_workers=16) for x in ['gallery','query','multi-query']}
else:
    image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                             shuffle=False, num_workers=16) for x in ['gallery']}
#class_names = image_datasets['query'].classes
use_gpu = torch.cuda.is_available()

######################################################################
# Load model
#---------------------------
def load_network(network):
    save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch)
    network.load_state_dict(torch.load(save_path))
    return network

######################################################################
# Extract feature
# ----------------------
#
# Extract feature from a trained model.
#
def fliplr(img):
    '''flip horizontal'''
    inv_idx = torch.arange(img.size(3)-1,-1,-1).long()  # N x C x H x W
    img_flip = img.index_select(3,inv_idx)
    return img_flip

def extract_feature(model,dataloaders):
    features = torch.FloatTensor()
    count = 0
    for data in dataloaders:
        img, label = data
        n, c, h, w = img.size()
        count += n
        print(count)
        if opt.use_dense:
            ff = torch.FloatTensor(n,1024).zero_()
        else:
            ff = torch.FloatTensor(n,2048).zero_()
        if opt.PCB:
            ff = torch.FloatTensor(n,2048,6).zero_() # we have six parts
        for i in range(2):
            if(i==1):
                img = fliplr(img)
            input_img = Variable(img.cuda())
            outputs = model(input_img)
            f = outputs.data.cpu()
            ff = ff+f
        # norm feature
        if opt.PCB:
            # feature size (n,2048,6)
            # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature.
            # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6).
            fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(6)
            ff = ff.div(fnorm.expand_as(ff))
            ff = ff.view(ff.size(0), -1)
        else:
            fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
            ff = ff.div(fnorm.expand_as(ff))
        features = torch.cat((features,ff), 0)
    return features

def get_id(img_path):
    camera_id = []
    labels = []
    for path, v in img_path:
        #filename = path.split('/')[-1]
        filename = os.path.basename(path)
        label = filename[0:4]
        camera = filename.split('c')[1]
        if label[0:2]=='-1':
            labels.append(-1)
        else:
            labels.append(int(label))
        camera_id.append(int(camera[0]))
    return camera_id, labels

gallery_path = image_datasets['gallery'].imgs
#query_path = image_datasets['query'].imgs

gallery_cam,gallery_label = get_id(gallery_path)
#query_cam,query_label = get_id(query_path)

if opt.multi:
    mquery_path = image_datasets['multi-query'].imgs
    mquery_cam,mquery_label = get_id(mquery_path)

######################################################################
# Load Collected data Trained model
print('-------test-----------')
if opt.use_dense:
    model_structure = ft_net_dense(751)
else:
    model_structure = ft_net(751)

if opt.PCB:
    model_structure = PCB(751)

model = load_network(model_structure)

# Remove the final fc layer and classifier layer
if not opt.PCB:
    model.model.fc = nn.Sequential()
    model.classifier = nn.Sequential()
else:
    model = PCB_test(model)

# Change to test mode
model = model.eval()
if use_gpu:
    model = model.cuda()

# Extract feature
gallery_feature = extract_feature(model,dataloaders['gallery'])
#query_feature = extract_feature(model,dataloaders['query'])
if opt.multi:
    mquery_feature = extract_feature(model,dataloaders['multi-query'])

# Save to Matlab for check
#result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam}
result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam}
scipy.io.savemat('pytorch_result.mat',result)

if opt.multi:
    result = {'mquery_f':mquery_feature.numpy(),'mquery_label':mquery_label,'mquery_cam':mquery_cam}
    scipy.io.savemat('multi_query.mat',result)
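
With this change, pytorch_result.mat contains only the gallery side: gallery_f, gallery_label, and gallery_cam. Because extract_feature L2-normalizes every feature, a plain dot product against the database is already the cosine score. Below is a minimal sketch of ranking the gallery against one query feature; query_feature is a hypothetical (2048,) numpy array produced by the same extract_feature pipeline (the default ft_ResNet50 feature size):

import numpy as np
import scipy.io

result = scipy.io.loadmat('pytorch_result.mat')
gallery_feature = result['gallery_f']          # (num_gallery, 2048)
gallery_label = result['gallery_label'][0]     # savemat stores lists as (1, N)
gallery_cam = result['gallery_cam'][0]

def rank_gallery(query_feature):
    # Features are L2-normalized, so the dot product equals cosine similarity.
    scores = np.dot(gallery_feature, query_feature.reshape(-1))
    order = np.argsort(scores)[::-1]           # highest score first
    return order, scores[order]

# Example usage (query_feature obtained from extract_feature elsewhere):
# order, scores = rank_gallery(query_feature)
# print(gallery_label[order[:10]])             # IDs of the top-10 matches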

Modify demo.py (extract features for the images under the query path, match them against the gallery feature database, and display the results):

# -*- coding: utf-8 -*-

from __future__ import print_function, division

import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import os
import scipy.io
import matplotlib.pyplot as plt
from model import ft_net, ft_net_dense, PCB, PCB_test

######################################################################
# Options
# --------
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0  0,1,2  0,2')
parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last')
parser.add_argument('--test_dir',default='/home/hylink/eclipse-workspace/reID/Market/pytorch',type=str, help='./test_data')
parser.add_argument('--name', default='ft_ResNet50', type=str, help='save model path')
parser.add_argument('--batchsize', default=32, type=int, help='batchsize')
parser.add_argument('--use_dense', action='store_true', help='use densenet121' )
parser.add_argument('--PCB', action='store_true', help='use PCB' )
parser.add_argument('--multi', action='store_true', help='use multiple query' )
parser.add_argument('--query_index', default=3, type=int, help='test_image_index')

opt = parser.parse_args()

str_ids = opt.gpu_ids.split(',')
#which_epoch = opt.which_epoch
name = opt.name
test_dir = opt.test_dir

gpu_ids = []
for str_id in str_ids:
    id = int(str_id)
    if id >=0:
        gpu_ids.append(id)

# set gpu ids
if len(gpu_ids)>0:
    torch.cuda.set_device(gpu_ids[0])

######################################################################
# Load Data
# ---------
#
# We will use torchvision and torch.utils.data packages for loading the
# data.
#
data_transforms = transforms.Compose([
        transforms.Resize((288,144), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
############### Ten Crop
        #transforms.TenCrop(224),
        #transforms.Lambda(lambda crops: torch.stack(
        #    [transforms.ToTensor()(crop)
        #        for crop in crops]
        #    )),
        #transforms.Lambda(lambda crops: torch.stack(
        #    [transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(crop)
        #        for crop in crops]
        #    ))
])

if opt.PCB:
    data_transforms = transforms.Compose([
        transforms.Resize((384,192), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 
    ])


data_dir = test_dir

if opt.multi:
    image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query','multi-query']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                             shuffle=False, num_workers=16) for x in ['gallery','query','multi-query']}
else:
    image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                             shuffle=False, num_workers=16) for x in ['gallery','query']}
class_names = image_datasets['query'].classes
use_gpu = torch.cuda.is_available()

######################################################################
# Load model
#---------------------------
def load_network(network):
    save_path = os.path.join('./model',name,'net_%s.pth'%opt.which_epoch)
    network.load_state_dict(torch.load(save_path))
    return network


######################################################################
# Extract feature
# ----------------------
#
# Extract feature from a trained model.
#
def fliplr(img):
    '''flip horizontal'''
    inv_idx = torch.arange(img.size(3)-1,-1,-1).long()  # N x C x H x W
    img_flip = img.index_select(3,inv_idx)
    return img_flip

def extract_feature(model,dataloaders):
    features = torch.FloatTensor()
    count = 0
    for data in dataloaders:
        img, label = data
        n, c, h, w = img.size()