程式人生 > 七扭八歪解faster rcnn(keras版)(四)

七扭八歪解faster rcnn(keras版)(四)

def get_data(input_path):
    """Parse a comma-separated annotation file into per-image records.

    Each non-blank line of *input_path* must have the form
    ``filename,x1,y1,x2,y2,class_name``. Several lines may refer to the
    same image; their boxes are accumulated into that image's record.

    Args:
        input_path: path to the annotation text file.

    Returns:
        A tuple ``(all_data, classes_count, class_mapping)``:
        - all_data: shuffled list with one dict per image, holding the keys
          'filepath', 'width', 'height', 'bboxes' (a list of dicts with
          'class', 'x1', 'x2', 'y1', 'y2') and 'imageset' ('trainval' or
          'test').
        - classes_count: number of boxes per class name, plus 'bg' -> 0.
        - class_mapping: class name -> integer index in order of first
          appearance, with 'bg' appended as the last index.

    Raises:
        ValueError: if a non-blank line does not split into exactly six
            comma-separated fields, or a coordinate is not an integer.
        IOError: if an image referenced by the file cannot be read by
            ``cv2.imread``.
    """
    all_imgs = {}
    classes_count = {}
    class_mapping = {}

    with open(input_path, 'r') as f:
        print('Parsing annotation files')
        for line in f:
            line = line.strip()
            # Skip blank lines (e.g. a trailing empty line at end of file)
            # instead of failing on the 6-field unpack below.
            if not line:
                continue
            (filename, x1, y1, x2, y2, class_name) = line.split(',')

            classes_count[class_name] = classes_count.get(class_name, 0) + 1
            if class_name not in class_mapping:
                class_mapping[class_name] = len(class_mapping)

            if filename not in all_imgs:
                img = cv2.imread(filename)
                # cv2.imread returns None instead of raising when the file is
                # missing or unreadable; report the offending path explicitly.
                if img is None:
                    raise IOError('Could not read image: {}'.format(filename))
                (rows, cols) = img.shape[:2]
                # randint(0, 6) draws from 0..5, so about 5/6 of the images
                # are assigned to 'trainval' and about 1/6 to 'test'.
                imageset = 'trainval' if np.random.randint(0, 6) > 0 else 'test'
                all_imgs[filename] = {
                    'filepath': filename,
                    'width': cols,
                    'height': rows,
                    'bboxes': [],
                    'imageset': imageset,
                }

            all_imgs[filename]['bboxes'].append(
                {'class': class_name, 'x1': int(x1), 'x2': int(x2),
                 'y1': int(y1), 'y2': int(y2)})

    all_data = list(all_imgs.values())

    # Register the background class last so it gets the highest index.
    classes_count['bg'] = 0
    class_mapping['bg'] = len(class_mapping)

    random.shuffle(all_data)
    print('Training images per class ({} classes) :'.format(len(classes_count)))
    pprint.pprint(classes_count)
    return all_data, classes_count, class_mapping

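程式碼很簡單:從命令列取得 input_path,逐行讀取標註檔(每行格式為 filename,x1,y1,x2,y2,class_name),分割出檔案路徑、框的座標以及最後的類別字串標註,分別記錄各類別的框數、類別到編號的映射以及每張圖片的資訊(路徑、寬、高、框);接著隨機標註每張圖片是訓練資料(trainval,機率約 5/6)還是測試資料(test,機率約 1/6);最後加入背景類別 'bg',打亂順序後返回。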

def classifier(base_layers, input_rois, num_rois, nb_classes=21, trainable=False):
    """Build the RoI classification head on top of *base_layers* (excerpt).

    RoI pooling is applied to *base_layers* at the boxes in *input_rois*.
    The pooled features are then passed through ``classifier_layers``.

    Args:
        base_layers: backbone feature map (presumably a Keras tensor --
            TODO confirm against caller).
        input_rois: region-of-interest input fed to ``RoiPoolingConv``.
        num_rois: number of RoIs per batch; also the leading dimension of
            ``input_shape``.
        nb_classes: not used in the lines shown in this excerpt.
        trainable: not used in the lines shown in this excerpt (see note
            below).
    """
    pooling_regions = 14
    # Derive the spatial size from pooling_regions so the two cannot drift
    # apart. The trailing 1024 is presumably the feature-map channel count --
    # TODO confirm against the backbone.
    input_shape = (num_rois, pooling_regions, pooling_regions, 1024)

    out_roi_pool = RoiPoolingConv(pooling_regions, num_rois)([base_layers, input_rois])
    # NOTE(review): trainable=True is hard-coded, so this function's own
    # `trainable` argument is ignored here. It is kept as-is because passing
    # the argument (default False) would change which layers are trained.
    out = classifier_layers(out_roi_pool, input_shape=input_shape, trainable=True)
    # NOTE(review): the listing ends here without a return statement; the
    # remainder of the function appears to be omitted from this excerpt.
將 input_shape(num_rois 預設設定為 4,因此整個形狀為 (4, 14, 14, 1024))和 out_roi_pool 一起放入分類器 classifier_layers 裡。

關於 TimeDistributed 的用法,請參看:

http://blog.csdn.net/xiaojiajia007/article/details/76665016