程式人生 > Tensorflow學習——結合ROS呼叫模型實現目標識別

Tensorflow學習——結合ROS呼叫模型實現目標識別

環境:Ubuntu16.04+Tensorflow-cpu-1.6.0+ROS Kinetic+OpenCV3.3.1

前期準備:

  1. 完成Object Detection API配置
  2. 完成OpenCV配置

完成模型訓練後就是模型的應用,這裡通過ROS利用Object Detection API呼叫模型實現目標物體的識別。

一、模型匯入

模型路徑設定如下方程式碼所示,注意根據自己的模型設定目標物體的類別數目(NUM_CLASSES)。

		# Get models: load the frozen detection graph and its label map.
		rospy.loginfo("begin initialization...")
		# Relative paths are resolved against the node's current working directory.
		self.PATH_TO_CKPT = '../frozen_inference_graph.pb'
		self.PATH_TO_LABELS = '../bottel.pbtxt'
		# Number of target object classes; set this for your own model
		# (it is passed as max_num_classes when the label map is converted).
		self.NUM_CLASSES = 2
		self.detection_graph = self._load_model()
		self.category_index = self._load_label_map()
		# Input and output tensors of the imported graph, looked up by name.
		self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
		self.boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
		self.scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
		self.classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
		self.num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')

二、資料處理    

呼叫模型識別目標物體前需進行資料處理,流程如下:

  1. 相機獲取的影像資訊會以ROS Image Message(sensor_msgs/Image)的格式發佈在ROS平臺上,然後通過CvBridge對獲取的影像資訊進行轉換,將其從ROS Image Message格式轉變為OpenCV影像格式(在Python中即numpy陣列)。
  2. 對獲取的影像資料進行預處理,轉換為模型所需的numpy陣列(形狀為[1, 高, 寬, 3])後,呼叫Object Detection API進行識別。
  3. 完成影像中目標物體的識別後,識別結果以陣列的形式發佈到相關話題中,同時視覺識別程式會將識別出來的目標物體使用帶有顏色的矩形框出來,並在其上方標識識別物體的標籤及其概率,然後再轉換為ROS Image Message格式發佈到相應話題中。(注:本文的程式只發佈了標註後的影像,boxes、scores、classes等識別結果陣列並未發佈,如有需要可另行建立發佈器。)


程式碼實現

	# detect object from the image
	def imgprogress(self, image_msg):
		"""Run object detection on one ROS image and publish the annotated image.

		Args:
			image_msg: incoming sensor_msgs/Image; converted to "rgb8" for the model.
		"""
		# Reuse one tf.Session for all callbacks (created on the first call)
		# instead of building and tearing down a new session per message.
		sess = getattr(self, '_sess', None)
		if sess is None:
			sess = tf.Session(graph=self.detection_graph)
			self._sess = sess

		# Convert the ROS message to an HxWx3 uint8 RGB array. np.array() makes a
		# writable copy; the visualization call below draws into it in place.
		cv_image = self._cv_bridge.imgmsg_to_cv2(image_msg, "rgb8")
		image_np = np.array(cv_image, dtype=np.uint8)
		# Expand dimensions since the model expects images to have shape: [1, None, None, 3]
		image_np_expanded = np.expand_dims(image_np, axis=0)

		# Actual detection.
		(boxes, scores, classes, num_detections) = sess.run(
			[self.boxes, self.scores, self.classes, self.num_detections],
			feed_dict={self.image_tensor: image_np_expanded})

		# Visualization of the results of a detection (drawn onto image_np).
		vis_util.visualize_boxes_and_labels_on_image_array(
			image_np,
			np.squeeze(boxes),
			np.squeeze(classes).astype(np.int32),
			np.squeeze(scores),
			self.category_index,
			use_normalized_coordinates=True,
			line_thickness=8)

		# Publish the annotated image, keeping the input header (stamp/frame_id)
		# so downstream nodes can match it with the source image.
		annotated_msg = self._cv_bridge.cv2_to_imgmsg(image_np, encoding="rgb8")
		annotated_msg.header = image_msg.header
		self._pub.publish(annotated_msg)
	

三、觸發識別

因通過Object Detection API進行物體識別需要佔用大量資源,若對相機影像流進行連續(動態)識別會非常卡頓,所以這裡採用觸發的方式進行識別:本程式設定了一個訂閱器self._sub,用於接收待識別的圖片,當需要進行識別時,向image_topic(/image/rgb/object)發佈一張圖片即可觸發程式,識別結果會通過self._pub發佈到object_detection話題中。

		# Create the publisher first: imgprogress() publishes on self._pub, so it
		# must exist before the subscriber can deliver the first trigger image.
		self._pub = rospy.Publisher('object_detection', ROSImage, queue_size=1)

		# Subscribe to the trigger topic; each image received runs one detection.
		self._sub = rospy.Subscriber(image_topic, ROSImage, self.imgprogress, queue_size=10)

完整程式

#!/usr/bin/env python

import rospy
from sensor_msgs.msg import Image as ROSImage
from cv_bridge import CvBridge
import cv2
import matplotlib
import numpy as np
import os
import six.moves.urllib as urllib
import sys
import tarfile
import tensorflow as tf
import zipfile
import uuid
from collections import defaultdict
from io import StringIO
from PIL import Image
from math import isnan

# This is needed since the notebook is stored in the object_detection folder.
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util

class ObjectDetectionDemo():
    """ROS node that runs a frozen TensorFlow Object Detection API model.

    Each image received on the trigger topic is run through the detection
    graph; the same image, with boxes, labels and scores drawn on it, is then
    published on the ``object_detection`` topic.

    Private ROS parameters (defaults equal the former hard-coded values):
        ~image_topic  trigger topic            (default "/image/rgb/object")
        ~model_path   frozen inference graph   (default "../frozen_inference_graph.pb")
        ~label_path   label map (.pbtxt)       (default "../bottel.pbtxt")
        ~num_classes  max classes in label map (default 2)
    """

    def __init__(self):
        rospy.init_node('tfobject')

        self.vfc = 0  # NOTE(review): never read in this file; kept in case external code uses it
        self._cv_bridge = CvBridge()
        # Persistent TF session, created below once the graph is loaded and
        # released in shutdown(). Initialised to None first so shutdown() is
        # safe even if it runs before initialisation completes.
        self._sess = None

        # Set the shutdown function (releases resources before exit)
        rospy.on_shutdown(self.shutdown)

        # Topic whose images each trigger one detection run.
        image_topic = rospy.get_param("~image_topic", "/image/rgb/object")

        # Get models. Relative paths resolve against the node's working
        # directory; override them with absolute paths via private params.
        rospy.loginfo("begin initialization...")
        self.PATH_TO_CKPT = rospy.get_param("~model_path", '../frozen_inference_graph.pb')
        self.PATH_TO_LABELS = rospy.get_param("~label_path", '../bottel.pbtxt')
        self.NUM_CLASSES = int(rospy.get_param("~num_classes", 2))
        self.detection_graph = self._load_model()
        self.category_index = self._load_label_map()
        # Input and output tensors of the imported graph, looked up by name.
        self.image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
        self.boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
        self.scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
        self.classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
        self.num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')

        # Create the session once here rather than once per message in
        # imgprogress(); every callback reuses it.
        self._sess = tf.Session(graph=self.detection_graph)

        # The publisher must exist before the subscriber: imgprogress()
        # publishes on self._pub and may be invoked as soon as the subscriber
        # is registered.
        self._pub = rospy.Publisher('object_detection', ROSImage, queue_size=1)
        # Subscribe to the trigger topic.
        self._sub = rospy.Subscriber(image_topic, ROSImage, self.imgprogress, queue_size=10)
        rospy.loginfo("initialization has finished...")

    def _load_model(self):
        """Load the frozen GraphDef at ``self.PATH_TO_CKPT`` into a new graph.

        Returns:
            A tf.Graph containing the imported nodes. They are imported with
            name='' so tensor names such as 'image_tensor:0' carry no prefix.
        """
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
            od_graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(od_graph_def, name='')
        return detection_graph

    def _load_label_map(self):
        """Build the category index used to label detections.

        Reads the label map at ``self.PATH_TO_LABELS`` and keeps at most
        ``self.NUM_CLASSES`` categories, using their display names.
        """
        label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
        categories = label_map_util.convert_label_map_to_categories(
            label_map, max_num_classes=self.NUM_CLASSES, use_display_name=True)
        return label_map_util.create_category_index(categories)

    # detect object from the image
    def imgprogress(self, image_msg):
        """Run object detection on one ROS image and publish the annotated image.

        Args:
            image_msg: incoming sensor_msgs/Image; converted to "rgb8" for the model.
        """
        sess = self._sess
        if sess is None:
            # Session not created yet or already released by shutdown().
            return

        # Convert the ROS message to an HxWx3 uint8 RGB array. np.array() makes a
        # writable copy; the visualization call below draws into it in place.
        cv_image = self._cv_bridge.imgmsg_to_cv2(image_msg, "rgb8")
        image_np = np.array(cv_image, dtype=np.uint8)
        # The model expects a batch of images: shape [1, None, None, 3].
        image_np_expanded = np.expand_dims(image_np, axis=0)

        # Actual detection.
        (boxes, scores, classes, num_detections) = sess.run(
            [self.boxes, self.scores, self.classes, self.num_detections],
            feed_dict={self.image_tensor: image_np_expanded})

        # Draw the detection results onto image_np.
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            self.category_index,
            use_normalized_coordinates=True,
            line_thickness=8)

        # Publish the annotated image, keeping the input header (stamp/frame_id)
        # so downstream nodes can match it with the source image.
        annotated_msg = self._cv_bridge.cv2_to_imgmsg(image_np, encoding="rgb8")
        annotated_msg.header = image_msg.header
        self._pub.publish(annotated_msg)

    # stop node
    def shutdown(self):
        """Log shutdown, release the TensorFlow session and pause briefly."""
        rospy.loginfo("Stopping the tensorflow object detection...")
        sess = getattr(self, '_sess', None)
        if sess is not None:
            self._sess = None
            sess.close()
        rospy.sleep(1)
	
if __name__ == '__main__':
    try:
        # Keep an explicit reference to the node object for the lifetime of spin().
        node = ObjectDetectionDemo()
        rospy.loginfo("RosTensorFlow_ObjectDetectionDemo has started.")
        rospy.spin()
    except rospy.ROSInterruptException:
        # Raised when ROS shuts down (e.g. Ctrl-C) during initialisation;
        # the previous message wrongly reported "has started." here.
        rospy.loginfo("RosTensorFlow_ObjectDetectionDemo has been shut down.")