
Image Object Detection with the TensorFlow Object Detection API

  1. Export the model

    After training, ssd_mobilenet_train_logs contains a set of checkpoint files, for example:

    • graph.pbtxt
    • model.ckpt-200000.data-00000-of-00001
    • model.ckpt-200000.index
    • model.ckpt-200000.meta

    Here the .meta file stores the graph and its metadata, while the ckpt data file stores the network weights.
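Because every file of one checkpoint shares a common prefix such as model.ckpt-200000, it can help to pick that prefix out of a directory listing automatically. The following is a minimal sketch using only the standard library; the function name latest_checkpoint_prefix and the model.ckpt-150000 entry are made up for illustration, and the suffix pattern assumes the TensorFlow 1.x naming (.index, .meta, .data-NNNNN-of-NNNNN). TensorFlow's own tf.train.latest_checkpoint serves a similar purpose by reading the checkpoint state file.

```python
import re

# Assumed TensorFlow 1.x naming: <prefix>.index, <prefix>.meta, <prefix>.data-00000-of-00001
_CKPT_FILE = re.compile(
    r"^(?P<prefix>.+\.ckpt-(?P<step>\d+))\.(index|meta|data-\d+-of-\d+)$"
)


def latest_checkpoint_prefix(filenames):
    """Return the checkpoint prefix with the highest step number, or None."""
    best_step, best_prefix = -1, None
    for name in filenames:
        match = _CKPT_FILE.match(name)
        if match and int(match.group("step")) > best_step:
            best_step = int(match.group("step"))
            best_prefix = match.group("prefix")
    return best_prefix


# Illustrative listing; the 150000 entry is hypothetical.
files = [
    "graph.pbtxt",
    "model.ckpt-150000.index",
    "model.ckpt-200000.data-00000-of-00001",
    "model.ckpt-200000.index",
    "model.ckpt-200000.meta",
]
print(latest_checkpoint_prefix(files))  # model.ckpt-200000
```

In practice the list would come from os.listdir on the training log directory.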

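    Prediction needs only the model graph and the weights, not the training metadata, so the official script can be used to generate an inference graph: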

    python object_detection/export_inference_graph.py \
        --input_type image_tensor \
        --pipeline_config_path object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config \
        --trained_checkpoint_prefix object_detection/VOC2012/ssd_mobilenet_train_logs/model.ckpt-200000 \
        --output_directory object_detection/VOC2012
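If the export has to be repeated for several checkpoints, the same command can be assembled from Python. This is only a sketch: build_export_command is a hypothetical helper, and the flags are just the ones used in the command above. Passing the arguments as a list avoids the line-continuation problems that shell backslashes invite.

```python
import shlex


def build_export_command(pipeline_config, checkpoint_prefix, output_dir):
    """Assemble the export_inference_graph.py call as an argument list."""
    return [
        "python", "object_detection/export_inference_graph.py",
        "--input_type", "image_tensor",
        "--pipeline_config_path", pipeline_config,
        "--trained_checkpoint_prefix", checkpoint_prefix,
        "--output_directory", output_dir,
    ]


cmd = build_export_command(
    "object_detection/VOC2012/ssd_mobilenet_v1_voc2012.config",
    "object_detection/VOC2012/ssd_mobilenet_train_logs/model.ckpt-200000",
    "object_detection/VOC2012",
)
# Print a shell-safe rendering of the command for inspection.
print(" ".join(shlex.quote(part) for part in cmd))
# To actually run it: import subprocess; subprocess.run(cmd, check=True)
```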
  2. Test an image

    • Run object_detection_tutorial.ipynb, changing the paths in it as needed.

    • Or write your own inference script, for example tensorflow/models/object_detection/infer.py:

      import sys
      sys.path.append('..')  # add the parent directory to the import path
      import os
      import time

      import tensorflow as tf
      import numpy as np
      from PIL import Image
      from matplotlib import pyplot as plt

      from utils import label_map_util
      from utils import visualization_utils as vis_util

      PATH_TEST_IMAGE = sys.argv[1]
      PATH_TO_CKPT = 'VOC2012/frozen_inference_graph.pb'
      PATH_TO_LABELS = 'VOC2012/pascal_label_map.pbtxt'
      NUM_CLASSES = 21
      IMAGE_SIZE = (18, 12)  # figure size in inches

      # Build the class-id -> category index from the label map.
      label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
      categories = label_map_util.convert_label_map_to_categories(
          label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
      category_index = label_map_util.create_category_index(categories)

      # Load the frozen inference graph.
      detection_graph = tf.Graph()
      with detection_graph.as_default():
          od_graph_def = tf.GraphDef()
          with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
              serialized_graph = fid.read()
              od_graph_def.ParseFromString(serialized_graph)
              tf.import_graph_def(od_graph_def, name='')

      config = tf.ConfigProto()
      config.gpu_options.allow_growth = True  # allocate GPU memory on demand

      with detection_graph.as_default():
          with tf.Session(graph=detection_graph, config=config) as sess:
              start_time = time.time()
              print(time.ctime())
              image = Image.open(PATH_TEST_IMAGE)
              image_np = np.array(image).astype(np.uint8)
              # The model expects a batch, so add a leading dimension.
              image_np_expanded = np.expand_dims(image_np, axis=0)
              image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
              boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
              scores = detection_graph.get_tensor_by_name('detection_scores:0')
              classes = detection_graph.get_tensor_by_name('detection_classes:0')
              num_detections = detection_graph.get_tensor_by_name('num_detections:0')
              (boxes, scores, classes, num_detections) = sess.run(
                  [boxes, scores, classes, num_detections],
                  feed_dict={image_tensor: image_np_expanded})
              print('{} elapsed time: {:.3f}s'.format(
                  time.ctime(), time.time() - start_time))
              # Draw the detected boxes and labels onto image_np in place.
              vis_util.visualize_boxes_and_labels_on_image_array(
                  image_np,
                  np.squeeze(boxes),
                  np.squeeze(classes).astype(np.int32),
                  np.squeeze(scores),
                  category_index,
                  use_normalized_coordinates=True,
                  line_thickness=8)
              plt.figure(figsize=IMAGE_SIZE)
              plt.imshow(image_np)
              plt.show()  # display the figure when run as a script

      Then run: python infer.py test_images/image1.jpg
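The script above only draws the detections. To use them as data, the raw outputs can be filtered by score and converted to pixel coordinates. The sketch below assumes the Object Detection API convention that detection_boxes rows are normalized [ymin, xmin, ymax, xmax]; the helper name filter_detections, the 0.5 threshold, and the sample values are illustrative, not real model output. It is written against plain Python lists, so the arrays from sess.run would first need something like np.squeeze(boxes).tolist().

```python
def filter_detections(boxes, scores, classes, image_width, image_height,
                      min_score=0.5):
    """Keep detections scoring at least min_score; return pixel-space boxes.

    boxes holds normalized [ymin, xmin, ymax, xmax] rows (assumed convention).
    """
    kept = []
    for box, score, cls in zip(boxes, scores, classes):
        if score < min_score:
            continue
        ymin, xmin, ymax, xmax = box
        kept.append({
            "class_id": int(cls),
            "score": float(score),
            # (left, top, right, bottom) in pixels
            "box": (
                int(round(xmin * image_width)),
                int(round(ymin * image_height)),
                int(round(xmax * image_width)),
                int(round(ymax * image_height)),
            ),
        })
    return kept


# Made-up values for a 640x480 image.
boxes = [[0.25, 0.5, 0.75, 1.0], [0.0, 0.0, 0.5, 0.5]]
scores = [0.9, 0.3]
classes = [12.0, 7.0]
detections = filter_detections(boxes, scores, classes, 640, 480)
for det in detections:
    print(det)
```

With a PIL image, the width and height are available as image.size, and class_id can be looked up in category_index to get the display name.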