1. 程式人生 > >Faster RCNN程式碼詳解(三):資料處理的整體結構

Faster RCNN程式碼詳解(三):資料處理的整體結構

在上一篇部落格中介紹了Faster RCNN網路結構的構建:Faster RCNN程式碼詳解(二):網路結構構建。網路結構是Faster RCNN演算法中最重要兩部分之一,這篇部落格將介紹非常重要的另一部分:資料處理

資料處理是通過AnchorLoader類實現的,該類所在指令碼:~mx-rcnn/rcnn/core/loader.py,該類實現了資料處理的整體架構,是比較巨集觀的。細節方面是通過assign_anchor函式實現的,該函式實現了關於anchor生成、正負樣本界定等,程式碼在~/mx-rcnn/rcnn/io/rpn.py中,下一篇會介紹。

接下來就看看AnchorLoader類是怎麼實現資料讀取的吧(重點在於get_batch方法)。

class AnchorLoader(mx.io.DataIter):
    def __init__(self, feat_sym, roidb, batch_size=1, shuffle=False, ctx=None, work_load_list=None,
                 feat_stride=16, anchor_scales=(8, 16, 32), anchor_ratios=(0.5, 1, 2), allowed_border=0, aspect_grouping=False):
        """
        This Iter will provide roi data to Fast R-CNN network
        :param feat_sym: to infer shape of assign_output
        :param roidb: must be preprocessed
        :param batch_size: must divide BATCH_SIZE(128)
        :param shuffle: bool
        :param ctx: list of contexts
        :param work_load_list: list of work load
        :param aspect_grouping: group images with similar aspects
        :return: AnchorLoader
        """
super(AnchorLoader, self).__init__() # save parameters as properties self.feat_sym = feat_sym self.roidb = roidb self.batch_size = batch_size self.shuffle = shuffle self.ctx = ctx if self.ctx is None: self.ctx = [mx.cpu()] self.work_load_list = work_load_list self.feat_stride = feat_stride self.anchor_scales = anchor_scales self.anchor_ratios = anchor_ratios self.allowed_border = allowed_border self.aspect_grouping = aspect_grouping # infer properties from roidb
self.size = len(roidb) self.index = np.arange(self.size) # decide data and label names # 這部分定義的data_name、label_name和定義網路結構以及用module介面初始化model時定義的資料輸入是一一對應的。 # 關於網路結構的輸入可以參考~mx-rcnn/rcnn/symbol/symbol_resnet.py指令碼的get_resnet_train函式。 if config.TRAIN.END2END: self.data_name = ['data', 'im_info', 'gt_boxes'] else: self.data_name = ['data'] self.label_name = ['label', 'bbox_target', 'bbox_weight'] # status variable for synchronization between get_data and get_label self.cur = 0 self.batch = None self.data = None self.label = None # get first batch to fill in provide_data and provide_label # 初始化呼叫reset方法進行一些變數的重置,get_batch方法用來讀取第一個batch的資料, # get_batch方法非常重要,包含了資料讀取和處理相關的細節。 self.reset() self.get_batch() @property def provide_data(self): return [(k, v.shape) for k, v in zip(self.data_name, self.data)] @property def provide_label(self): return [(k, v.shape) for k, v in zip(self.label_name, self.label)] def reset(self): self.cur = 0 if self.shuffle: if self.aspect_grouping: widths = np.array([r['width'] for r in self.roidb]) heights = np.array([r['height'] for r in self.roidb]) horz = (widths >= heights) vert = np.logical_not(horz) horz_inds = np.where(horz)[0] vert_inds = np.where(vert)[0] inds = np.hstack((np.random.permutation(horz_inds), np.random.permutation(vert_inds))) extra = inds.shape[0] % self.batch_size inds_ = np.reshape(inds[:-extra], (-1, self.batch_size)) row_perm = np.random.permutation(np.arange(inds_.shape[0])) inds[:-extra] = np.reshape(inds_[row_perm, :], (-1,)) self.index = inds else: np.random.shuffle(self.index) def iter_next(self): return self.cur + self.batch_size <= self.size # next方法是資料迭代器每次迭代資料時候呼叫的,在該方法中還是先通過get_batch() # 得到一個batch資料,然後通過mx.io.DataBatch將資料封裝成指定格式作為模型的輸入。 def next(self): if self.iter_next(): self.get_batch() self.cur += self.batch_size return mx.io.DataBatch(data=self.data, label=self.label, pad=self.getpad(), index=self.getindex(), provide_data=self.provide_data, provide_label=self.provide_label) else: raise StopIteration def getindex(self): return self.cur / self.batch_size def getpad(self): if self.cur + self.batch_size > self.size: return self.cur + self.batch_size - self.size else: return 0 def infer_shape(self, max_data_shape=None, max_label_shape=None): """ Return maximum data and label shape for single gpu """ if max_data_shape is None: max_data_shape = [] if max_label_shape is None: max_label_shape = [] max_shapes = dict(max_data_shape + max_label_shape) input_batch_size = max_shapes['data'][0] im_info = [[max_shapes['data'][2], max_shapes['data'][3], 1.0]] _, feat_shape, _ = self.feat_sym.infer_shape(**max_shapes) label = assign_anchor(feat_shape[0], np.zeros((0, 5)), im_info, self.feat_stride, self.anchor_scales, self.anchor_ratios, self.allowed_border) label = [label[k] for k in self.label_name] label_shape = [(k, tuple([input_batch_size] + list(v.shape[1:]))) for k, v in zip(self.label_name, label)] return max_data_shape, label_shape # get_batch方法是讀取資料的主要方法,該方法包含anchor的初始化、anchor標籤的確定、 # 正負樣本的確定等。該方法在資料初始化的時候會直接呼叫一次用來讀取第一個batch的資料, # 之後通過next方法每次迭代讀取資料時候都會呼叫。 def get_batch(self): # slice roidb # 這部分是根據batch size的大小選擇對應數量的輸入資料。 cur_from = self.cur cur_to = min(cur_from + self.batch_size, self.size) roidb = [self.roidb[self.index[i]] for i in range(cur_from, cur_to)] # decide multi device slice work_load_list = self.work_load_list ctx = self.ctx if work_load_list is None: work_load_list = [1] * len(ctx) assert isinstance(work_load_list, list) and len(work_load_list) == len(ctx), \ "Invalid settings for work load. " slices = _split_input_slice(self.batch_size, work_load_list) # get testing data for multigpu data_list = [] label_list = [] for islice in slices: iroidb = [roidb[i] for i in range(islice.start, islice.stop)] # get_rpn_batch()會對輸入影象做短邊resize到指定尺寸(預設是600),另外長邊最大值 # 是1000,所以在對短邊做resize後如果長邊超過1000,則以長邊resize到1000為準 # (短邊從600按對應比例繼續縮小)。需要注意的是box的標註座標也會做對應的縮放。 # 得到的資料就放在data_list中,標註資訊就放在label_list中 data, label = get_rpn_batch(iroidb) data_list.append(data) label_list.append(label) # pad data first and then assign anchor (read label) data_tensor = tensor_vstack([batch['data'] for batch in data_list]) for data, data_pad in zip(data_list, data_tensor): data['data'] = data_pad[np.newaxis, :] new_label_list = [] for data, label in zip(data_list, label_list): # infer label shape data_shape = {k: v.shape for k, v in data.items()} del data_shape['im_info'] # self.feat_sym.infer_shape(**data_shape)是計算指定size的資料(data_shape:{'data':(1,3,600,800)}) # 通過指定symbol(self.feat_sym)得到的輸出size(feat_shape)。infer_shape方法的輸入除了這種形式, # 還可以用self.feat_sym.infer_shape(data=(1,3,600,800)),這裡關鍵字data是在網路結構中定義的輸入層名稱。 _, feat_shape, _ = self.feat_sym.infer_shape(**data_shape) feat_shape = [int(i) for i in feat_shape[0]] # add gt_boxes to data for e2e data['gt_boxes'] = label['gt_boxes'][np.newaxis, :, :] # assign anchor for label # assign_anchor函式是給anchor分配標籤的操作。輸入中feat_shape是用於生成anchor的feature map維度, # list格式,比如1*18*38*50,18是2*9的意思,9是anchor數量,2是背景和非背景2個類。 # label['gt_boxes']是x*5的numpy array,表示x個object的座標和類別資訊,是標註資訊, # 也就是ground truth,標註座標是和影象大小對應的,ground truth主要用在anchor標籤定義上。 # data['im_info']是1*3的numpy array,表示影象大小和縮放尺度資訊。 # self.feat_stride是指特徵縮放比例,比如16。self.anchor_scales預設是[8,16,32]。 # self.anchor_ratios預設是[0.5,1,2]。輸出label是包含3個鍵值對的字典,分別是label["label"]、 # label["bbox_target"]、label["bbox_weights"],這3個值都在RPN網路中用到 。 # 該函式的細節在~/mx-rcnn/rcnn/io/rpn.py中。 label = assign_anchor(feat_shape, label['gt_boxes'], data['im_info'], self.feat_stride, self.anchor_scales, self.anchor_ratios, self.allowed_border) new_label_list.append(label) all_data = dict() for key in self.data_name: all_data[key] = tensor_vstack([batch[key] for batch in data_list]) all_label = dict() for key in self.label_name: pad = -1 if key == 'label' else 0 all_label[key] = tensor_vstack([batch[key] for batch in new_label_list], pad=pad) # 最後返回的data是長度為3的列表,列表中每個值都是NDArray,分別是4維的影象內容 # data:(1,3,600,800);2維的影象寬高和scale資訊im_info:(1,3); # 3維的原始bounding box標註資訊:gt_boxes:(1,x,5),x是object的數量。 # lable也是長度為3的列表,列表中每個值都是NDArray,分別是2維的anchor標籤資訊label:(1,17100); # 4維的anchor座標迴歸target資訊bbox_target:(1,36,38,50), # 4維的anchor座標權重資訊bbox_weight:(1,36,38,50)。 self.data = [mx.nd.array(all_data[key]) for key in self.data_name] self.label = [mx.nd.array(all_label[key]) for key in self.label_name]

這篇部落格介紹了Faster RCNN演算法中關於資料處理的整體結構,比較巨集觀,重要內容都在get_batch方法中,get_batch方法描述了Faster RCNN演算法對資料讀取的整體結構,這在很多後續的演算法中都通用,因此當你瞭解了資料讀取的整體結構後,就能舉一反三了。

在get_batch方法中有個關於anchor的函式:assign_anchor。這應該也是本系列部落格第一次提到anchor。anchor在Faster RCNN中是非常重要的概念,但很多新手對anchor的理解可能模稜兩可,比如anchor是什麼?怎麼生成的?anchor的標籤是怎麼定義的?bbox(bounding box)的迴歸目標是怎麼定義的?bbox和anchor是什麼區別?如果你有這些疑問,那麼下一篇部落格我將為你解開anchor這個神祕的面紗:Faster RCNN程式碼詳解(四):關於anchor的前世今生