1. 程式人生 > Tensorflow object detection API 原始碼閱讀筆記:架構

Tensorflow object detection API 原始碼閱讀筆記:架構

在之前的博文中介紹過用tf提供的預訓練模型進行inference,非常簡單。這裡我們深入原始碼,瞭解檢測API的程式碼架構,每個部分的深入閱讀留待後續。

'''構建自己模型的介面是抽象基類DetectionModel,具體有5個抽象函式需要實現:preprocess、predict、postprocess、loss和restore_map。
'''
object_detection/core/model.py
  def groundtruth_lists(self, field):
    """Access list of groundtruth tensors."""
    # Body not shown in this excerpt.

  def groundtruth_has_field(self, field):
    """Determines whether the groundtruth includes the given field."""
    # Body not shown in this excerpt.

  def provide_groundtruth(self,
                          groundtruth_boxes_list,
                          groundtruth_classes_list,
                          groundtruth_masks_list=None,
                          groundtruth_keypoints_list=None):
    """Provide groundtruth tensors.

    Masks and keypoints are optional (both default to None).
    """
    # Body not shown in this excerpt.

  # The five abstract methods below are the interface a concrete detection
  # model (e.g. FasterRCNNMetaArch) has to implement.
  @abstractmethod
  def preprocess(self, inputs):
    """Preprocesses `inputs`; implemented by each subclass."""

  @abstractmethod
  def predict(self, preprocessed_inputs):
    """Runs prediction on already-preprocessed inputs; subclass-defined."""

  @abstractmethod
  def postprocess(self, prediction_dict, **params):
    """Post-processes `prediction_dict` (presumably the output of `predict`).

    Extra keyword arguments are forwarded via `**params`.
    """

  @abstractmethod
  def loss(self, prediction_dict):
    """Computes the loss from `prediction_dict`; subclass-defined."""

  @abstractmethod
  def restore_map(self, from_detection_checkpoint=True):
    """Presumably returns the variable map used for checkpoint restoring.

    Args:
      from_detection_checkpoint: defaults to True; exact semantics are
        subclass-defined (TODO confirm against an implementation).
    """
object_detection/meta_architectures/faster_rcnn_meta_arch.py

class FasterRCNNFeatureExtractor(object):
  """Faster R-CNN Feature Extractor definition."""

  def __init__(self,
               is_training,
               first_stage_features_stride,
               batch_norm_trainable=False,
               reuse_weights=None,
               weight_decay=0.0):
    """Constructor; body not shown in this excerpt."""

  @abstractmethod
  def preprocess(self, resized_inputs):
    """Feature-extractor specific preprocessing (minus image resizing)."""

  def extract_proposal_features(self, preprocessed_inputs, scope):
    """Extracts first stage RPN features."""
    # Body not shown in this excerpt; presumably it delegates to the
    # overridable `_extract_proposal_features` hook below.

  @abstractmethod
  def _extract_proposal_features(self, preprocessed_inputs, scope):
    """Extracts first stage RPN features, to be overridden."""

  def extract_box_classifier_features(self, proposal_feature_maps, scope):
    """Extracts second stage box classifier features."""
    # Body not shown in this excerpt; presumably it delegates to the
    # overridable `_extract_box_classifier_features` hook below.

  @abstractmethod
  def _extract_box_classifier_features(self, proposal_feature_maps, scope):
    """Extracts second stage box classifier features, to be overridden."""

  def restore_from_classification_checkpoint_fn(
      self,
      first_stage_feature_extractor_scope,
      second_stage_feature_extractor_scope):
    """Returns a map of variables to load from a foreign checkpoint."""
    # Body not shown in this excerpt.


class FasterRCNNMetaArch(model.DetectionModel):
  """Faster R-CNN Meta-architecture definition.

  The detection model that is finally built is an instance of this class.
  For now, the main thing to trace is where `feature_extractor` (a
  FasterRCNNFeatureExtractor object) is used. Swapping in a different CNN
  backbone is fairly simple: only a new
  faster_rcnn_<new_cnn>_feature_extractor needs to be written.
  """

  def preprocess(self, inputs):
    """For Faster R-CNN, we perform image resizing in the base class.

    Each class subclassing FasterRCNNMetaArch is responsible for any
    additional preprocessing (e.g., scaling pixel values to be in [-1, 1]).
    See the `preprocess` implemented by FasterRCNNResnetV1FeatureExtractor in
    object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py.
    """
    # Body not shown in this excerpt.
object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py
"""這一塊和slim結合緊密,我們仔細看看。
"""

class FasterRCNNResnetV1FeatureExtractor(
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
  """Faster R-CNN Resnet V1 feature extractor implementation.

  This part is tightly coupled with TF-Slim, so it is worth reading closely.
  """

  def __init__(self,
               architecture,
               resnet_model,
               is_training,
               first_stage_features_stride,
               batch_norm_trainable=False,
               reuse_weights=None,
               weight_decay=0.0):
    """Constructor.

    Args:
      architecture: Resnet V1 architecture name, e.g. 'resnet_v1_152'.
      resnet_model: Resnet V1 model definition, e.g.
        resnet_v1.resnet_v1_152.
      is_training: See base class.
      first_stage_features_stride: See base class. Must be 8 or 16.
      batch_norm_trainable: See base class.
      reuse_weights: See base class.
      weight_decay: See base class.

    Raises:
      ValueError: If `first_stage_features_stride` is not 8 or 16.
    """
    if first_stage_features_stride not in (8, 16):
      raise ValueError('`first_stage_features_stride` must be 8 or 16.')
    # Only this class consumes the architecture name and the model function;
    # subclasses such as the Resnet-152 extractor just pass them in.
    self._architecture = architecture
    self._resnet_model = resnet_model
    super(FasterRCNNResnetV1FeatureExtractor, self).__init__(
        is_training, first_stage_features_stride, batch_norm_trainable,
        reuse_weights, weight_decay)

  def preprocess(self, resized_inputs):
    """Faster R-CNN Resnet V1 preprocessing.

    Subtracts a fixed per-channel mean from every pixel. `[[channel_means]]`
    has shape [1, 1, 3] and broadcasts over the trailing axis, so the input
    is assumed to be channels-last with 3 channels (presumably RGB -- TODO
    confirm).
    """
    channel_means = [123.68, 116.779, 103.939]
    return resized_inputs - [[channel_means]]

  def _extract_proposal_features(self, preprocessed_inputs, scope):
    """Extracts first stage RPN features.

    Uses the Resnet end_points to output the activations of block3.
    """
    # Body not shown in this excerpt.

  def _extract_box_classifier_features(self, proposal_feature_maps, scope):
    """Extracts second stage box classifier features.

    Splits off Resnet block4. Note how variable_scope and arg_scope are used.
    """
    # Body not shown in this excerpt.

class FasterRCNNResnet152FeatureExtractor(FasterRCNNResnetV1FeatureExtractor):
  """Faster R-CNN Resnet 152 feature extractor implementation."""

  def __init__(self,
               is_training,
               first_stage_features_stride,
               batch_norm_trainable=False,
               reuse_weights=None,
               weight_decay=0.0):
    """Constructor.

    Args:
      is_training: See base class.
      first_stage_features_stride: See base class.
      batch_norm_trainable: See base class.
      reuse_weights: See base class.
      weight_decay: See base class.

    Raises:
      ValueError: If `first_stage_features_stride` is not 8 or 16,
        or if `architecture` is not supported.
    """
    # Walking up the __init__ chain: 'resnet_v1_152' and
    # resnet_v1.resnet_v1_152 are only used by the parent class
    # FasterRCNNResnetV1FeatureExtractor.
    super(FasterRCNNResnet152FeatureExtractor, self).__init__(
        'resnet_v1_152', resnet_v1.resnet_v1_152, is_training,
        first_stage_features_stride, batch_norm_trainable,
        reuse_weights, weight_decay)

同樣建議跑一跑test指令碼。test中會匯入如下模組,按照它們在test中出現的順序逐個閱讀這些檔案,以及各自對應的test指令碼。

"""Builder function to construct tf-slim arg_scope for convolution, fc ops.
看一下這個指令碼的test,很容易理解超引數配置是怎麼讀取的了,類似OpenFOAM中的dict。object_detection.protos.hyperparams_pb2.Hyperparams。
"""
from object_detection.builders import hyperparams_builder

"""Contains routines for printing protocol messages in text format.
同樣是上面這個test指令碼,目前主要用在    
conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
其中conv_hyperparams_text_proto是包含引數配置的字串,conv_hyperparams_proto是hyperparams.proto object,hyperparams_builder.build的第一個引數。
"""
from google.protobuf import text_format

"""Function to build box predictor from configuration.
Box predictors are classes that take a high level
image feature map as input and produce two predictions,
(1) a tensor encoding box locations, and
(2) a tensor encoding classes for each box.
object_detection/core/box_predictor.py留待後續研讀。注意conv_hyperparams_text_proto是放進box_predictor_text_proto然後一起傳遞給class ConvolutionalBoxPredictor(BoxPredictor)的。
"""
from object_detection.builders import box_predictor_builder

"""Generates grid anchors on the fly as used in Faster RCNN.
下次細看。
"""
from object_detection.anchor_generators import grid_anchor_generator

"""Builder function for post processing operations."""
from object_detection.builders import post_processing_builder

"""Classification and regression loss functions for object detection."""
from object_detection.core import losses

"""proto檔案,下次再結合相應的core和builder來具體研究如何編寫和讀取這些檔案"""
from object_detection.protos import box_predictor_pb2
from object_detection.protos import hyperparams_pb2
from object_detection.protos import post_processing_pb2
"""A function to build a DetectionModel from configuration.
很多內容在faster_rcnn_meta_arch_test_lib.py測試過了。
"""
object_detection/builders/model_builder.py