1. 程式人生 > >TensorFlow 物件檢測 API 教程5

TensorFlow 物件檢測 API 教程5

TensorFlow 物件檢測 API 教程 - 第5部分:儲存和部署模型

在本教程的這一步,認為已經選擇了預先訓練的物件檢測模型,調整現有的資料集或建立自己的資料集,並將其轉換為 TFRecord 檔案,修改模型配置檔案並開始訓練。但是,現在需要儲存模型並將其部署到專案中。

一. 將檢查點模型 (.ckpt) 儲存為 .pb 檔案

回到 TensorFlow 物件檢測資料夾,並將 export_inference_graph.py 檔案複製到包含模型配置檔案的資料夾中。


python export_inference_graph.py --input_type image_tensor
--pipeline_config_path ./rfcn_resnet101_coco.config --trained_checkpoint_prefix ./models/train/model.ckpt-5000 --output_directory ./fine_tuned_model

這將建立一個新的目錄 fine_tuned_model ,其中模型名為 frozen_inference_graph.pb

二.在專案中使用模型

在本指南中一直在研究的專案是建立一個交通燈分類器。在 Python 中,可以將這個分類器作為一個類來實現。在類的初始化部分中,可以建立一個 TensorFlow

會話,以便在每次需要分類時都不需要建立它。


class TrafficLightClassifier(object):
    def __init__(self):
        PATH_TO_MODEL = 'frozen_inference_graph.pb'
        self.detection_graph = tf.Graph()
        with self.detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            # Works up to here.
            with
tf.gfile.GFile(PATH_TO_MODEL, 'rb') as fid: serialized_graph = fid.read() od_graph_def.ParseFromString(serialized_graph) tf.import_graph_def(od_graph_def, name='') self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0') self.d_boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0') self.d_scores = self.detection_graph.get_tensor_by_name('detection_scores:0') self.d_classes = self.detection_graph.get_tensor_by_name('detection_classes:0') self.num_d = self.detection_graph.get_tensor_by_name('num_detections:0') self.sess = tf.Session(graph=self.detection_graph)

在這個類中,建立了一個函式,在影象上執行分類,並返回影象中分類的邊界框,分數和類。


def get_classification(self, img):
    # Bounding Box Detection.
    with self.detection_graph.as_default():
        # Expand dimension since the model expects image to have shape [1, None, None, 3].
        img_expanded = np.expand_dims(img, axis=0)  
        (boxes, scores, classes, num) = self.sess.run(
            [self.d_boxes, self.d_scores, self.d_classes, self.num_d],
            feed_dict={self.image_tensor: img_expanded})
    return boxes, scores, classes, num

此時,需要過濾低於指定分數閾值的結果。結果自動從最高分到最低分,所以這相當容易。用上面的函式返回分類結果,做完以上這些就完成了!

下面可以看到交通燈分類器在行動

ai-1