1. 程式人生 > >用自己的資料訓練Faster-RCNN,tensorflow版本(一)

用自己的資料訓練Faster-RCNN,tensorflow版本(一)

我用的Faster-RCNN是tensorflow版本,fork自githubFaster-RCNN_TF

1.1、環境配置

按照該專案中的README.md ,將需要的幾個依賴cython, python-opencv, easydict都安裝好,並確保本地計算機中有tensorflow,沒有的話自行安裝;

1.2、克隆工程:在本地計算機的終端輸入

git clone --recursive https://github.com/smallcorgi/Faster-RCNN_TF.git

下載下來的內容都在目錄 Faster-RCNN_TF 下;

1.3、編譯Cython模組

cd $FRCN_ROOT/lib # 首先進入目錄Faster-RCNN_TF/lib
make #編譯

編譯成功之後,目錄Faster-RCNN_TF/lib/nms 和 Faster-RCNN_TF/lib/roi_pooling_layer/ 和Faster-RCNN_TF/lib/utils下面會出現一些.so檔案。

注意:如果在這時候,你將該工程原封不動的連帶著.so檔案一起移植到了另一臺電腦上,想重新執行程式的時候,記住,要先刪除這幾個.so檔案,並重新進行編譯。因為編譯生成的檔案是隻適應本臺計算機的,換一臺計算機之後,用原來的.so檔案,就行不通了,程式會出錯。並且,必須要先刪除舊的.so檔案,否則就會呼叫舊的.so檔案,而不生成新的.so檔案。

2、介紹一下pascal_voc資料集的資料讀寫介面

工程Faster-RCNN_TF中讀取資料的介面都在目錄Faster-RCNN_TF/lib/datasets下。

原工程提供了5種資料來訓練網路,並分別給出了各自的資料讀寫介面。
5種資料分別是pascal_voc,coco,kitti,nissan,nthu,各自的資料讀寫介面分別是Faster-RCNN_TF/lib/datasets 中的pascal_voc.py,coco.py,kitti.py,nissan.py,nthu.py。

我們可以看到Faster-RCNN_TF/lib/datasets目錄下還有一些.py檔案,分別是:
factory.py

:是個工廠類,用類生成imdb類並且返回資料庫供網路訓練和測試使用
imdb.py:是資料庫讀寫類的基類,分裝了許多db的操作,具體的一些檔案讀寫需要繼承繼續讀寫

我們要用自己的資料進行訓練,就得編寫自己資料的讀寫介面,下面參考pascal_voc.py來編寫。

2.1、首先說明一下pascal_voc資料集的格式

以VOC2007為例,資料都放在一個叫做VOCdevkit的目錄中,我們來看一下目錄VOCdevkit的結構:

VOCdevkit/
VOCdevkit/VOC2007/
VOCdevkit/VOC2007/Annotations #所有圖片的XML檔案,一張圖片對應一個XML檔案,XML檔案中給出的圖片gt的形式是左上角和右下角的座標

VOCdevkit/VOC2007/ImageSets/          
VOCdevkit/VOC2007/ImageSets/Layout #裡面有三個txt檔案,分別是train.txt,trainval.txt,val.txt,儲存的分別是訓練圖片的名字列表,訓練驗證集的圖片名字列表,驗證集圖片的名字列表(名字均沒有.jpg字尾)
VOCdevkit/VOC2007/ImageSets/Main
VOCdevkit/VOC2007/ImageSets/Segmentation

VOCdevkit/VOC2007/JPEGImages  #所有的圖片

VOCdevkit/VOC2007/SegmentationClass  #segmentations by class

VOCdevkit/VOC2007/SegmentationObject  #segmentations by object

Faster-RCNN_TF工程主要用到的是目錄Annotations中的XML檔案、目錄JPEGImages中的圖片、目錄ImageSets/Layout中的txt檔案。

2.2、然後解釋一下pascal_voc.py中每個的函式的作用
主函式 if name == ‘main在檔案pascal_voc.py的最下面

if __name__ == '__main__':
    from datasets.pascal_voc import pascal_voc
    d = pascal_voc('trainval', '2007') #pascal_voc是一個類
    res = d.roidb
    from IPython import embed; embed()

類 pascal_voc中的函式:
class pascal_voc(imdb):
def init(self, image_set, year, devkit_path=None)在檔案pascal_voc.py的最上面
是初始化函式,對應著的是pascal_voc的資料集訪問格式
(我會按照這個初始化函式裡面用到的子函式的順序來介紹每個子函式的作用,這樣看比較直觀。在這個初始化函式init中用到的每個子函式我都會有一個標號,方便介紹。)

'''
是初始化函式,對應著的是pascal_voc的資料集訪問格式

:param image_set: 是一個str,值是'train'或者'test'或者'trainval'或者'val',表示的意思是用(訓練集)或者(測試集)或者(訓練驗證集)或者(驗證集)裡面的資料;
:param year: 是一個str,是VOC資料的年份,值是'2007'或者'2012'
:param devkit_path: pascal_voc資料集所在的路徑
'''
'''
以下的image_set都以train為例
year都以2007為例
'''
def __init__(self, image_set, year, devkit_path=None): 
    imdb.__init__(self, 'voc_' + year + '_' + image_set) # 繼承了類imdb的初始化函式__init__(),傳進去的引數是voc_2007_train。類imdb在Faster-RCNN_TF_R2/lib/datasets/imdb.py裡面被定義
    self._year = year #年份,比如2007
    self._image_set = image_set # train 
    self._devkit_path = self._get_default_path() if devkit_path is None else devkit_path # 這個路徑是pascal_voc資料集所在的路徑。如果devkit_path is None,返回pascal_voc的預設路徑:目錄VOCdevkit;如果devkit_path有值,則返回devkit_path。預設路徑用函式_get_default_path()獲得,標號(1)
    self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)#就是VOCdevkit/VOC2007
    self._classes = ('__background__', # always index 0
                     'aeroplane', 'bicycle', 'bird', 'boat',
                     'bottle', 'bus', 'car', 'cat', 'chair',
                     'cow', 'diningtable', 'dog', 'horse',
                     'motorbike', 'person', 'pottedplant',
                     'sheep', 'sofa', 'train', 'tvmonitor') #資料集中所包含的全部的object類別
    self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes))) # 構建字典{'__background__':'0','aeroplane':'1', 'bicycle':'2', 'bird':'3', 'boat':'4','bottle':'5', 'bus':'6', 'car':'7', 'cat':'8', 'chair':'9','cow':'10', 'diningtable':'11', 'dog':'12', 'horse':'13','motorbike':'14', 'person':'15', 'pottedplant':'16','sheep':'17', 'sofa':'18', 'train':'19', 'tvmonitor':'20'}  self.num_classes是object的類別總數21(背景也算一類),這個函式繼承自Faster-RCNN_TF_R2/lib/datasets/imdb.py
    self._image_ext = '.jpg' # 圖片字尾名
    self._image_index = self._load_image_set_index() #載入了樣本的list檔案,標號(2)
    # Default to roidb handler
    #self._roidb_handler = self.selective_search_roidb #當沒有RPN的時候,讀取並返回候選框ROI的db。函式selective_search_roidb是fast-rcnn提取候選框的方式(fast-rcnn沒有RPN),下面會具體講
    self._roidb_handler = self.gt_roidb # 當有RPN的時候,讀取並返回圖片gt的db。函式gt_roidb裡面並沒有提取圖片的ROI,因為faster-rcnn有RPN,用RPN來提取ROI。函式gt_roidb返回的是圖片的gt。標號(3)
    self._salt = str(uuid.uuid4())
    self._comp_id = 'comp4'

    # PASCAL specific config options
    self.config = {'cleanup'     : True,
                   'use_salt'    : True,
                   'use_diff'    : False,
                   'matlab_eval' : False,
                   'rpn_file'    : None,
                   'min_size'    : 2}

    assert os.path.exists(self._devkit_path), \
        'VOCdevkit path does not exist: {}'.format(self._devkit_path) #如果路徑self._devkit_path(也就是目錄VOCdevkit)不存在,退出
    assert os.path.exists(self._data_path), \
        'Path does not exist: {}'.format(self._data_path)  #如果路徑self._data_path(也就是VOCdevkit/VOC2007)不存在,退出

標號(1)def _get_default_path(self)

def _get_default_path(self):
    """
    Return the default path where PASCAL VOC is expected to be installed.
    返回資料集pascal_voc的預設路徑:Faster-RCNN_TF/data/VOCdevkit/2007
    """
    return os.path.join(cfg.DATA_DIR, 'VOCdevkit') # cfg.DATA_DIR是在Faster-RCNN_TF/lib/fast_rcnn/config.py裡面定義的,

Faster-RCNN_TF/lib/fast_rcnn/config.py中定義DATA_DIR的地方是這樣的(在220-224行):

# Root directory of project
__C.ROOT_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', '..')) # 這個路徑指的就是目錄Faster-RCNN_TF

# Data directory
__C.DATA_DIR = osp.abspath(osp.join(__C.ROOT_DIR, 'data')) # 這個路徑是Faster-RCNN_TF/data

標號(2)def _load_image_set_index(self)

def _load_image_set_index(self):
    """
    Load the indexes listed in this dataset's image set file.
    得到一個list,這個list裡面是集合self._image_set中所有圖片的名字(注意,圖片名字沒有後綴.jpg)
    """
    image_set_file = os.path.join(self._data_path, 'ImageSets', 'Layout',
                                  self._image_set + '.txt') 
    # image_set_file就是Faster-RCNN_TF/data/VOCdevkit/VOC2007/ImageSets/Layout/train.txt
    #之所以要讀這個train.txt檔案,是因為train.txt檔案裡面寫的是集合train中所有圖片的名字(沒有後綴.jpg)
    assert os.path.exists(image_set_file), \
            'Path does not exist: {}'.format(image_set_file)
    with open(image_set_file) as f: # 讀上面的train.txt檔案
        image_index = [x.strip() for x in f.readlines()] #將train.txt的內容(圖片名字)讀取出來放在image_index裡面
    return image_index #得到image_set裡面所有圖片的名字(沒有後綴.jpg)

標號(3)def gt_roidb(self)

def gt_roidb(self):
    """
    Return the database of ground-truth regions of interest.
    讀取並返回圖片gt的db。這個函式就是將圖片的gt載入進來。
    其中,pascal_voc圖片的gt資訊在XML檔案中(這個XML檔案是pascal_voc資料集本身提供的)
    並且,圖片的gt被提前放在了一個.pkl檔案裡面。(這個.pkl檔案需要我們自己生成,程式碼就在該函式中)

    This function loads/saves from/to a cache file to speed up future calls.
    之所以會將圖片的gt提前放在一個.pkl檔案裡面,是為了不用每次都再重新讀圖片的gt,直接載入這個檔案就可以了,可以提升速度。

    """
    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl') #給.pkl檔案起個名字。引數self.cache_path和self.name繼承自類imdb,類imdb在Faster-RCNN_TF_R2/lib/datasets/imdb.py裡面被定義
    if os.path.exists(cache_file): # 如果這個.pkl檔案存在(說明之前執行過本函式,生成了這個pkl檔案)
        with open(cache_file, 'rb') as fid: #開啟
            roidb = cPickle.load(fid) #將裡面的資料載入進來
        print '{} gt roidb loaded from {}'.format(self.name, cache_file)
        return roidb #返回

    # 如果這個.pkl檔案不存在,說明是第一次執行本函式。
    gt_roidb = [self._load_pascal_annotation(index) 
                for index in self.image_index] #那麼首先要做的就是獲取圖片的gt,函式_load_pascal_annotation的作用就是獲取圖片gt。標號(4)
    with open(cache_file, 'wb') as fid:
        cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL) #將圖片的gt儲存在.pkl檔案裡面
    print 'wrote gt roidb to {}'.format(cache_file)

    return gt_roidb

引數self.cache_path和self.name繼承自類imdb,類imdb在Faster-RCNN_TF_R2/lib/datasets/imdb.py裡面被定義。類imdb中定義函式self.cache_path的地方在imdb.py中的69-74行:

@property
def cache_path(self):
    cache_path = osp.abspath(osp.join(cfg.DATA_DIR, 'cache'))  # 該路徑是Faster-RCNN_TF/data/cache
    if not os.path.exists(cache_path):
        os.makedirs(cache_path)
    return cache_path

類imdb中定義函式self.name的地方在imdb.py中的21-36行:

def __init__(self, name): #是類imdb的初始化函式,在pascal_voc.py的第26行被用到
    # name是形參,傳進來的引數是'voc_2007_train' or ‘voc_2007_test’ or 'voc_2007_val' or 'voc_2007_trainval'
    self._name = name # 'voc_2007_train' or ‘voc_2007_test’ or 'voc_2007_val' or 'voc_2007_trainval'
    self._num_classes = 0
    self._classes = []
    self._image_index = []
    self._obj_proposer = 'selective_search'
    self._roidb = None
    print self.default_roidb
    self._roidb_handler = self.default_roidb  # self._roidb_handler在Faster-RCNN_TF/lib/datasets/icdar_2015.py中,又被重新定義了
    # Use this dict for storing dataset specific config options
    self.config = {}

@property
def name(self): #類imdb中定義函式self.name的地方
    return self._name #返回的是本檔案imdb.py中的self._name,往上面看

注意:如果你再次訓練的時候修改了train資料庫,增加或者刪除了一些資料,再想重新訓練的時候,一定要先刪除這個.pkl檔案!!!!!!因為如果不刪除的話,就會自動載入舊的pkl檔案,而不會生成新的pkl檔案。一定別忘了!

標號(4)def _load_pascal_annotation(self, index):這個函式是讀取圖片gt的具體實現

def _load_pascal_annotation(self, index):
   """
   :param index: 一張圖片的名字(沒有後綴.jpg)
   Load image and bounding boxes info from XML file in the PASCAL VOC
   format.從XML檔案中獲取圖片資訊和gt。
   這個XML檔案儲存的是PASCAL VOC圖片的資訊和gt的資訊,我們在下載VOC資料集的時候,XML檔案是一塊下載下來的。在資料夾Annotation裡面。
   """
   filename = os.path.join(self._data_path, 'Annotations', index + '.xml') #這個filename就是一個XML檔案的路徑,其中index是一張圖片的名字(沒有後綴)。例如VOCdevkit/VOC2007/Annotations/000005.xml
   tree = ET.parse(filename)
   objs = tree.findall('object')
   if not self.config['use_diff']:
       # Exclude the samples labeled as difficult
       non_diff_objs = [
           obj for obj in objs if int(obj.find('difficult').text) == 0]
       # if len(non_diff_objs) != len(objs):
       #     print 'Removed {} difficult objects'.format(
       #         len(objs) - len(non_diff_objs))
       objs = non_diff_objs
   num_objs = len(objs)  # 輸進來的圖片上的物體object的個數

   boxes = np.zeros((num_objs, 4), dtype=np.uint16)
   gt_classes = np.zeros((num_objs), dtype=np.int32)
   overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
   # "Seg" area for pascal is just the box area
   seg_areas = np.zeros((num_objs), dtype=np.float32)

   # Load object bounding boxes into a data frame.
   for ix, obj in enumerate(objs): # 對於該圖片上每一個object
       bbox = obj.find('bndbox') # pascal_voc的XML檔案中給出的圖片gt的形式是左上角和右下角的座標
       # Make pixel indexes 0-based
       x1 = float(bbox.find('xmin').text) - 1 
       y1 = float(bbox.find('ymin').text) - 1
       x2 = float(bbox.find('xmax').text) - 1
       y2 = float(bbox.find('ymax').text) - 1 #為什麼要減去1?是因為VOC的資料,座標-1,預設座標從0開始(這個還有待商榷,先忽略)
       cls = self._class_to_ind[obj.find('name').text.lower().strip()]#找到該object的類別
       boxes[ix, :] = [x1, y1, x2, y2]
       gt_classes[ix] = cls
       overlaps[ix, cls] = 1.0
       seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1) # seg_areas[ix]是該object gt的面積

   overlaps = scipy.sparse.csr_matrix(overlaps)

   return {'boxes' : boxes,
           'gt_classes': gt_classes,
           'gt_overlaps' : overlaps,
           'flipped' : False,
           'seg_areas' : seg_areas}

分析到現在,pascal_voc.py中還剩下一些函式,這些函式並沒有在pascal_voc.py裡面用到,但是在別的地方用到了,下面也分析一下:

def image_path_at(self, i)

'''
根據第i個影象樣本返回其對應的path,其呼叫了image_path_from_index(self, index)作為其具體實現;
'''
def image_path_at(self, i):
    """
    Return the absolute path to image i in the image sequence.
    """
    return self.image_path_from_index(self._image_index[i])

def image_path_from_index(self, index)

def image_path_from_index(self, index):
    """
    :param index: 是一張圖片的名字,假如說有一張圖片叫lsq.jpg,這個值就是lsq,沒有後綴名

    Construct an image path from the image's "index" identifier.
    返回圖片所在的路徑
    """
    image_path = os.path.join(self._data_path, 'JPEGImages',
                              index + self._image_ext) #這個就是圖片本身所在的路徑。其中index是一張圖片的名字(沒有後綴),_image_ext是圖片字尾名.jpg。例如VOCdevkit/VOC2007/JPEGImages/000005.jpg
    assert os.path.exists(image_path), \
            'Path does not exist: {}'.format(image_path) # 如果該路徑不存在,退出
    return image_path

def selective_search_roidb(self)

def selective_search_roidb(self):
    """
    Return the database of selective search regions of interest.
    Ground-truth ROIs are also included.
    沒有RPN的fast-rcnn提取候選框的方式。返回的是提取出來的ROI以及圖片的gt。
    這個函式在Faster-RCNN裡面用不到,在fast-rcnn裡面才會用到

    This function loads/saves from/to a cache file to speed up future calls.
    """
    cache_file = os.path.join(self.cache_path,
                              self.name + '_selective_search_roidb.pkl')

    if os.path.exists(cache_file):
        with open(cache_file, 'rb') as fid:
            roidb = cPickle.load(fid)
        print '{} ss roidb loaded from {}'.format(self.name, cache_file)
        return roidb

    if int(self._year) == 2007 or self._image_set != 'test':
        gt_roidb = self.gt_roidb()
        ss_roidb = self._load_selective_search_roidb(gt_roidb)
        roidb = imdb.merge_roidbs(gt_roidb, ss_roidb)
    else:
        roidb = self._load_selective_search_roidb(None)
    with open(cache_file, 'wb') as fid:
        cPickle.dump(roidb, fid, cPickle.HIGHEST_PROTOCOL)
    print 'wrote ss roidb to {}'.format(cache_file)

    return roidb

def _load_selective_search_roidb(self, gt_roidb)

def _load_selective_search_roidb(self, gt_roidb):
    '''
    載入預選框的檔案
    這個函式在Faster-RCNN裡面用不到,在fast-rcnn裡面才會用到。這個我還沒有研究
    '''
     filename = os.path.abspath(os.path.join(cfg.DATA_DIR,
                                             'selective_search_data',
                                             self.name + '.mat'))
     assert os.path.exists(filename), \
            'Selective search data not found at: {}'.format(filename)
     raw_data = sio.loadmat(filename)['boxes'].ravel()

     box_list = []
     for i in xrange(raw_data.shape[0]):
         boxes = raw_data[i][:, (1, 0, 3, 2)] - 1
         keep = ds_utils.unique_boxes(boxes)
         boxes = boxes[keep, :]
         keep = ds_utils.filter_small_boxes(boxes, self.config['min_size'])
         boxes = boxes[keep, :]
         box_list.append(boxes)

     return self.create_roidb_from_box_list(box_list, gt_roidb)

3、編寫自己的資料讀寫介面

我們要用自己的資料進行訓練,就得編寫自己資料的讀寫介面,下面參考pascal_voc.py來編寫。根據上面對pascal_voc.py檔案的分析,發現,pascal_voc.py用了非常多的路徑拼接,很麻煩,我們不用這麼麻煩,只要設定好自己資料的路徑就可以了。

詳情見下篇。