1. 程式人生 > >fast rcnn 訓練自己的資料集(訓練和檢測)

fast rcnn 訓練自己的資料集(訓練和檢測)

1.預訓練模型介紹

首先在data目錄下,有兩個目錄就是之前在1中解壓好

  • fast_rcnn_models/
  • imagenet_models/

fast_rcnn_model資料夾下面是作者用fast rcnn訓練好的三個網路,分別對應著小、中、大型網路,大家可以試用一下這幾個網路,看一些檢測效果,他們訓練都迭代了40000次,資料集都是pascal_voc的資料集。

  1. caffenet_fast_rcnn_iter_40000.caffemodel
  2. vgg_cnn_m_1024_fast_rcnn_iter_40000.caffemodel
  3. vgg16_fast_rcnn_iter_40000.caffemodel

imagenet_model資料夾下面是在Imagenet上訓練好的通用模型,在這裡用來初始化網路的引數

  1. CaffeNet.v2.caffemodel
  2. VGG_CNN_M_1024.v2.caffemodel
  3. VGG16.v2.caffemodel

在這裡我比較推薦先用中型網路訓練,中型網路訓練和檢測的速度都比較快,效果也都比較理想,大型網路的話訓練速度比較慢,我當時是5000多個標註資訊,網路配置預設,中型網路訓練大概兩三個小時,大型網路的話用十幾個小時,需要注意的是網路訓練最好用GPU,CPU的話太慢了,我當時用的實驗室的伺服器,有16塊Tesla K80,用起來真的是灰常爽!

2. 修改模型檔案配置

模型檔案在models下面對應的網路資料夾下,在這裡我用中型網路的配置檔案修改為例子
比如:我的檢測目標物是car ,那麼我的類別就有兩個類別即 background 和 car


因此,首先開啟網路的模型資料夾,開啟train.prototxt
修改的地方重要有三個
分別是個地方

  1. 首先在data層把num_classes 從原來的21類 20類+背景 ,改成 2類 車+背景
  2. 接在在cls_score層把num_output 從原來的21 改成 2
  3. 在bbox_pred層把num_output 從原來的84 改成8, 為檢測類別個數乘以4,比如這裡是2類那就是2*4=8

OK,如果你要進一步修改網路訓練中的學習速率,步長,gamma值,以及輸出模型的名字,需要在同目錄下的solver.prototxt中修改。
如下圖:

train_net: "models/VGG_CNN_M_1024/train.prototxt"
base_lr: 0.001 lr_policy: "step" gamma: 0.1 stepsize: 30000 display: 20 average_loss: 100 momentum: 0.9 weight_decay: 0.0005 # We disable standard caffe solver snapshotting and implement our own snapshot # function snapshot: 0 # We still use the snapshot prefix, though snapshot_prefix: "vgg_cnn_m_1024_fast_rcnn" #debug_info: true

3.啟動Fast RCNN網路訓練

啟動訓練:
./tools/train_net.py --gpu 11 --solver models/VGG_CNN_M_1024_LOUYIHANG/solver.prototxt --weights data/imagenet_models/VGG_CNN_M_1024.v2.caffemodel --imdb KakouTrain

引數講解

  • 這裡的--是兩個-,markdown寫的,大家不要輸錯
  • train_net.py是網路的訓練檔案,之後的引數都是附帶的輸入引數
  • --gpu 代表機器上的GPU編號,如果是nvidia系列的tesla顯示卡,可以在終端中輸入nvidia-smi來檢視當前的顯示卡負荷,選擇合適的顯示卡
  • --solver 代表模型的配置檔案,train.prototxt的檔案路徑已經包含在這個檔案之中
  • --weights 代表初始化的權重檔案,這裡用的是Imagenet上預訓練好的模型,中型的網路我們選擇用VGG_CNN_M_1024.v2.caffemodel
  • --imdb 這裡給出的訓練的資料庫名字需要在factory.py的__sets中,我在檔案裡面有__sets['KakouTrain'],train_net.py這個檔案會呼叫factory.py再生成kakou這個類,來讀取資料

4.啟動Fast RCNN網路檢測

我修改了tools下面的demo.py這個檔案,用來做檢測,並且將檢測的座標結果輸出到相應的txt檔案中
可以看到原始的demo.py 是用網路測試了兩張影象,並做視覺化輸出,有具體的檢測效果,但是我是在Linux伺服器的終端下,沒有display device,因此部分程式碼要少做修改

下面是原始的demo.py:

#!/usr/bin/env python

# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""
Demo script showing detections in sample images.

See README.md for installation instructions before running.
"""

import _init_paths
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from utils.cython_nms import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import caffe, os, sys, cv2
import argparse

CLASSES = ('__background__',
           'aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor')

NETS = {'vgg16': ('VGG16',
                  'vgg16_fast_rcnn_iter_40000.caffemodel'),
        'vgg_cnn_m_1024': ('VGG_CNN_M_1024',
                           'vgg_cnn_m_1024_fast_rcnn_iter_40000.caffemodel'),
        'caffenet': ('CaffeNet',
                     'caffenet_fast_rcnn_iter_40000.caffemodel')}


def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
            )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')

    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()

def demo(net, image_name, classes):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load pre-computed Selected Search object proposals
    box_file = os.path.join(cfg.ROOT_DIR, 'data', 'demo',
                            image_name + '_boxes.mat')
    obj_proposals = sio.loadmat(box_file)['boxes']

    # Load the demo image
    im_file = os.path.join(cfg.ROOT_DIR, 'data', 'demo', image_name + '.jpg')
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(net, im, obj_proposals)
    timer.toc()
    print ('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0])

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    for cls in classes:
        cls_ind = CLASSES.index(cls)
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        print 'All {} detections with p({} | box) >= {:.1f}'.format(cls, cls,
                                                                    CONF_THRESH)
        vis_detections(im, cls, dets, thresh=CONF_THRESH)

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
                        default=0, type=int)
    parser.add_argument('--cpu', dest='cpu_mode',
                        help='Use CPU mode (overrides --gpu)',
                        action='store_true')
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',
                        choices=NETS.keys(), default='vgg16')

    args = parser.parse_args()

    return args

if __name__ == '__main__':
    args = parse_args()

    prototxt = os.path.join(cfg.ROOT_DIR, 'models', NETS[args.demo_net][0],
                            'test.prototxt')
    caffemodel = os.path.join(cfg.ROOT_DIR, 'data', 'fast_rcnn_models',
                              NETS[args.demo_net][1])

    if not os.path.isfile(caffemodel):
        raise IOError(('{:s} not found.\nDid you run ./data/script/'
                       'fetch_fast_rcnn_models.sh?').format(caffemodel))

    if args.cpu_mode:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
        caffe.set_device(args.gpu_id)
    net = caffe.Net(prototxt, caffemodel, caffe.TEST)

    print '\n\nLoaded network {:s}'.format(caffemodel)

    print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
    print 'Demo for data/demo/000004.jpg'
    demo(net, '000004', ('car',))

    print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'
    print 'Demo for data/demo/001551.jpg'
    demo(net, '001551', ('sofa', 'tvmonitor'))

    plt.show()

複製這個demo.py 修改成CarFaceTest.py,下面是修改後的檔案
修改後的檔案主要是添加了outputDetectionResult和runDetection兩個函式, 添加了部分註釋

#!/usr/bin/env python
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""
Demo script showing detections in sample images.

See README.md for installation instructions before running.
"""

import _init_paths
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from utils.cython_nms import nms
from utils.timer import Timer
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import caffe, os, sys, cv2
import argparse

#CLASSES = ('__background__','aeroplane','bicycle','bird','boat',
#       'bottle','bus','car','cat','chair','cow','diningtable','dog','horse'
#       'motorbike','person','pottedplant','sheep','sofa','train','tvmonitor')

CLASSES = ('__background__','car') #需要跟自己訓練的資料集中的類別一致,原來是21類的voc資料集,自己的資料集就是car和background

NETS = {'vgg16': ('VGG16',
                  'vgg16_fast_rcnn_iter_40000.caffemodel'),
        'vgg_cnn_m_1024': ('VGG_CNN_M_1024',
                           'vgg_cnn_m_1024_fast_rcnn_iter_40000.caffemodel'),
    'vgg_cnn_m_1024_louyihang': ('VGG_CNN_M_1024_LOUYIHANG',
               'vgg_cnn_m_1024_fast_rcnn_louyihang_iter_40000.caffemodel'),
        'caffenet': ('CaffeNet',
                     'caffenet_fast_rcnn_iter_40000.caffemodel'),
    'caffenet_louyihang':('CaffeNet_LOUYIHANG',
             'caffenet_fast_rcnn_louyihang_iter_40000.caffemodel'),
    'vgg16_louyihang':('VGG16_LOUYIHANG',
               'vgg16_fast_rcnn_louyihang_iter_40000.caffemodel')}#對映到對應的模型檔案

def outputDetectionResult(im, class_name, dets, thresh=0.5): #開啟相應的輸出檔案
    outputFile = open('CarDetectionResult.txt')
    inds = np.where(dets[:,-1] >= thresh)[0]
    if len(inds) == 0:
        return
def runDetection (net, basePath, testFileName,classes):#這個函式是自己後加的,取代了demo函式,給定測試資料列表
    ftest = open(testFileName,'r')
    imageFileName = basePath+'/' + ftest.readline().strip()
    num = 1
    outputFile = open('CarDetectionResult.txt','w')
    while imageFileName:
    print imageFileName
    print 'now is ', num
    num +=1
    imageFileBaseName = os.path.basename(imageFileName)
    imageFileDir = os.path.dirname(imageFileName)
    boxFileName = imageFileDir +'/'+imageFileBaseName.replace('.jpg','_boxes.mat')
    print boxFileName
    obj_proposals = sio.loadmat(boxFileName)['boxes']
    #obj_proposals[:,2] = obj_proposals[:, 2] + obj_proposals[:, 0]#這裡也需要注意,OP裡面的座標資料是否為x1y1x2y2還是x1y1wh
    #obj_proposals[:,3] = obj_proposals[:, 3] + obj_proposals[:, 1]
    im = cv2.imread(imageFileName)
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(net, im, obj_proposals)#檢測函式
    timer.toc()
    print ('Detection took {:.3f} for '
               '{:d} object proposals').format(timer.total_time, boxes.shape[0])
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3#NMS引數用來控制非極大值抑制
        for cls in classes:
            cls_ind = CLASSES.index(cls)
            cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
            cls_scores = scores[:, cls_ind]
            dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
            keep = nms(dets, NMS_THRESH)
            dets = dets[keep, :]
            print 'All {} detections with p({} | box) >= {:.1f}'.format(cls, cls,
                                                                    CONF_THRESH)
        inds = np.where(dets[:, -1] >= CONF_THRESH)[0]
        print 'inds.size', inds.size
        if len(inds) != 0:
            outputFile.write(imageFileName+' ')
        outputFile.write(str(inds.size)+' ')將檢測的結果寫出相應的檔案裡
            for i in inds:
            bbox = dets[i, :4]
            outputFile.write(str(int(bbox[0]))+' '+ str(int(bbox[1]))+' '+ str(int(bbox[2]))+' '+ str(int(bbox[3]))+' ')
            outputFile.write('\n')
        else:
            outputFile.write(imageFileName +' 0' '\n')
    temp = ftest.readline().strip()
    if temp:
        imageFileName = basePath+'/' + temp
    else:
        break
def vis_detections(im, class_name, dets, thresh=0.5):#這個函式需要加以說明,這個函式雖然沒有用,但是我的伺服器上沒有輸出裝置
    """Draw detected bounding boxes."""#因此要將部分用到顯示的函式給註釋掉,否則執行會報錯
    inds = np.where(dets[:, -1] >= thresh)[0]
    print 'inds.shape', inds.shape
    print inds
    print 'inds.size', inds.size
    if len(inds) == 0:
        return
        #im = im[:, :, (2, 1, 0)]
    #fig, ax = plt.subplots(figsize=(12, 12))
    #ax.imshow(im, aspect='equal')
    #for i in inds:
    #    bbox = dets[i, :4]
    #    score = dets[i, -1]

    #    ax.add_patch(
    #        plt.Rectangle((bbox[0], bbox[1]),
    #                      bbox[2] - bbox[0],
    #                      bbox[3] - bbox[1], fill=False,
    #                      edgecolor='red', linewidth=3.5)
    #        )
    #    ax.text(bbox[0], bbox[1] - 2,
    #            '{:s} {:.3f}'.format(class_name, score),
    #            bbox=dict(facecolor='blue', alpha=0.5),
    #            fontsize=14, color='white')

    #ax.set_title(('{} detections with '
    #              'p({} | box) >= {:.1f}').format(class_name, class_name,
    #                                              thresh),
    #              fontsize=14)
    #plt.axis('off')
    #plt.tight_layout()
    #plt.draw()

def demo(net, image_name, classes):#原來的demo函式,沒有修改
    """Detect object classes in an image using pre-computed object proposals."""

    # Load pre-computed Selected Search object proposals
    #box_file = os.path.join(cfg.ROOT_DIR, 'data', 'demo',image_name + '_boxes.mat')
    basePath='/home/chenjie/DataSet/500CarTestDataSet2'
    box_file = os.path.join(basePath,image_name + '_boxes.mat')
    obj_proposals = sio.loadmat(box_file)['boxes']
    # Load the demo image
    #im_file = os.path.join(cfg.ROOT_DIR, 'data', 'demo', image_name + '.jpg')
    im_file = os.path.join(basePath, image_name + '.jpg')
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(net, im, obj_proposals)
    timer.toc()
    print ('Detection took {:.3f}s for '
           '{:d} object proposals').format(timer.total_time, boxes.shape[0])

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    for cls in classes:
        cls_ind = CLASSES.index(cls)
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        print 'All {} detections with p({} | box) >= {:.1f}'.format(cls, cls,
                                                                    CONF_THRESH)

        vis_detections(im, cls, dets, thresh=CONF_THRESH)

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
    parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',
                        default=0, type=int)
    parser.add_argument('--cpu', dest='cpu_mode',
                        help='Use CPU mode (overrides --gpu)',
                        action='store_true')
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',
                        choices=NETS.keys(), default='vgg16')

    args = parser.parse_args()

    return args

if __name__ == '__main__':
    args = parse_args()

    prototxt = os.path.join(cfg.ROOT_DIR, 'models', NETS[args.demo_net][0],
                            'test.prototxt')
    #caffemodel = os.path.join(cfg.ROOT_DIR, 'data', 'fast_rcnn_models',
    #                          NETS[args.demo_net][1])
    #caffemodel = '/home/chenjie/fast-rcnn/output/default/KakouTrain/vgg16_fast_rcnn_louyihang_iter_40000.caffemodel'
    #caffemodel = '/home/chenjie/louyihang/fast-rcnn/output/default/KakouTrain/caffenet_fast_rcnn_louyihang_iter_40000.caffemodel'
    caffemodel = '/home/chenjie/fast-rcnn/output/default/KakouTrain/vgg_cnn_m_1024_fast_rcnn_louyihang_iter_40000.caffemodel'#我在這裡直接指定了訓練好的模型檔案,訓練好的模型檔案是在工程根目錄下的,output/default/對應的資料庫名字下面
    if not os.path.isfile(caffemodel):
        raise IOError(('{:s} not found.\nDid you run ./data/script/'
                       'fetch_fast_rcnn_models.sh?').format(caffemodel))

    if args.cpu_mode:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
        caffe.set_device(args.gpu_id)
    net = caffe.Net(prototxt, caffemodel, caffe.TEST)

    print '\n\nLoaded network {:s}'.format(caffemodel)

    #demo(net, 'Target0/000001', ('car',))
    #輸入對應的測試影象列表,需要在同級目錄下襬放同名的_boxes.mat檔案,它會自動的替換字尾名!
    #runDetection(net, '/home/chenjie/DataSet/temptest','/home/chenjie/DataSet/temptest/Imagelist.txt',('car',))
    runDetection(net, '/home/chenjie/DataSet/500CarTestDataSet2','/home/chenjie/DataSet/500CarTestDataSet2/Imagelist.txt',('car',))
    #runDetection(net, '/home/chenjie/DataSet/Kakou_Test_Scale0.25/','/home/chenjie/DataSet/Kakou_Test_Scale0.25/imagelist.txt',('car',))
    #runDetection(net, '/home/chenjie/DataSet/Images_Version1_Test_Boxes','/home/chenjie/DataSet/Images_Version1_Test_Boxes/ImageList_Version1_List.txt',('car',))
    #plt.show()

5.檢測結果

訓練資料集

首先給出我的訓練資料集,其實我的訓練資料集並不是太複雜的

測試資料集

輸出檢測結果到txt檔案中,

測試效果

**在複雜場景下的測試效果非常好,速度也非常快,中型網路監測平均每張在K80顯示卡下時0.1~0.2S左右,影象的尺寸是480*640,6000張測試資料集下達到的準確率是98%!!!**