七扭八歪解faster rcnn(keras版)(四)
阿新 • • 發佈:2019-01-02
def get_data(input_path):
    """Parse a comma-separated annotation file into per-image records.

    Every non-blank line of the file must have exactly six comma-separated
    fields: ``filename,x1,y1,x2,y2,class_name``. Several lines may refer to
    the same image; their boxes are collected under one record.

    Args:
        input_path: Path of the annotation text file.

    Returns:
        A tuple ``(all_data, classes_count, class_mapping)`` where

        * ``all_data`` is a shuffled list with one dict per image holding the
          keys ``'filepath'``, ``'width'``, ``'height'``, ``'bboxes'`` and
          ``'imageset'`` (``'trainval'`` or ``'test'``);
        * ``classes_count`` maps each class name to the number of boxes seen,
          plus a ``'bg'`` entry set to 0;
        * ``class_mapping`` maps each class name to an integer index in order
          of first appearance, with ``'bg'`` assigned the next free index.

    Raises:
        ValueError: If a line does not have exactly six fields, a coordinate
            is not an integer, or an image file cannot be read.
    """
    all_imgs = {}
    classes_count = {}
    class_mapping = {}

    with open(input_path, 'r') as f:
        print('Parsing annotation files')
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at the end
                # of the file) instead of failing on the unpack below.
                continue

            (filename, x1, y1, x2, y2, class_name) = line.split(',')

            classes_count[class_name] = classes_count.get(class_name, 0) + 1
            if class_name not in class_mapping:
                class_mapping[class_name] = len(class_mapping)

            if filename not in all_imgs:
                # The image is only opened once, to record its dimensions.
                img = cv2.imread(filename)
                if img is None:
                    # cv2.imread signals failure by returning None rather
                    # than raising; report the offending path explicitly.
                    raise ValueError(
                        'Could not read image: {}'.format(filename))
                (rows, cols) = img.shape[:2]

                # The split is decided once per image, when it is first seen:
                # a draw of 0 marks it 'test', any other draw 'trainval'.
                if np.random.randint(0, 6) > 0:
                    imageset = 'trainval'
                else:
                    imageset = 'test'

                all_imgs[filename] = {
                    'filepath': filename,
                    'width': cols,
                    'height': rows,
                    'bboxes': [],
                    'imageset': imageset,
                }

            all_imgs[filename]['bboxes'].append({
                'class': class_name,
                'x1': int(x1),
                'x2': int(x2),
                'y1': int(y1),
                'y2': int(y2),
            })

    all_data = list(all_imgs.values())

    # The background class is always registered, with a zero box count.
    classes_count['bg'] = 0
    class_mapping['bg'] = len(class_mapping)

    random.shuffle(all_data)
    print('Training images per class ({} classes) :'.format(len(classes_count)))
    pprint.pprint(classes_count)
    return all_data, classes_count, class_mapping
程式碼很簡單:從命令列讀出input_path,逐行讀取標註檔,分割出檔案路徑、框的座標資訊以及最後的類別字串標註,分別記錄下來;每張圖片會被隨機標註上是訓練資料(trainval)還是測試資料(test),最後返回all_data、classes_count和class_mapping。
def classifier(base_layers, input_rois, num_rois, nb_classes = 21, trainable=False): pooling_regions = 14 input_shape = (num_rois,14,14,1024) out_roi_pool = RoiPoolingConv(pooling_regions, num_rois)([base_layers, input_rois]) out = classifier_layers(out_roi_pool, input_shape=input_shape, trainable=True)
將input_shape(num_rois預設設定為4,因此整體為(4,14,14,1024))和out_roi_pool放入分類器(classifier_layers)裡。
關於TimeDistributed的說明,請參看:
http://blog.csdn.net/xiaojiajia007/article/details/76665016