
Using the TensorFlow Object Detection API to Crop Pigs out of Video

The pig-detection code, along with the follow-up pig-classification program, has been open-sourced on GitHub.

The main modifications made on top of the official demo code are:

  1. Expand each detected box so that it wraps the target better, clamping the crop so it never goes beyond the image boundary [expand_ratio].
  2. When both "pig" and "animal" are detected, both may well be correct; the acceptance rule can be tuned based on actual run results to suppress irrelevant classes that are detected with fairly high probability, improving robustness [class_keep].
  3. Run detection in mini-batches to make full use of the GPU and improve throughput.
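The expand-and-clamp step in point 1 can be sketched as follows; the function name `expand_box` and the 10% default ratio are illustrative assumptions, not the repository's actual code:

```python
def expand_box(box, img_w, img_h, expand_ratio=0.1):
    """Expand a pixel-coordinate box (xmin, ymin, xmax, ymax) by
    expand_ratio of its width/height on each side, then clamp the
    result to the image boundary so the crop never leaves the image."""
    xmin, ymin, xmax, ymax = box
    dw = (xmax - xmin) * expand_ratio
    dh = (ymax - ymin) * expand_ratio
    xmin = max(0, int(xmin - dw))
    ymin = max(0, int(ymin - dh))
    xmax = min(img_w, int(xmax + dw))
    ymax = min(img_h, int(ymax + dh))
    return xmin, ymin, xmax, ymax

# The crop itself is then simply frame[ymin:ymax, xmin:xmax].
```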

Now let's focus on the core code that interacts with the Object Detection API:

# Load a (frozen) Tensorflow model into memory
'''
tf.GraphDef():
The GraphDef class is an object created by the ProtoBuf.
See https://www.tensorflow.org/extend/tool_developers/ for details.
graph_def:
A GraphDef proto containing operations to be imported into the default graph
'''
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
'''
Several utility functions are used here.
'''
label_map = label_map_util.load_labelmap(PATH_TO_LABELS)
categories = label_map_util.convert_label_map_to_categories(
    label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
category_index = label_map_util.create_category_index(categories)
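For reference, the resulting category_index is just a dict keyed by class id. A hand-built equivalent (with made-up class ids, for illustration only) looks like this:

```python
# Hypothetical categories, mirroring the structure that
# label_map_util.create_category_index returns.
categories = [{'id': 1, 'name': 'pig'}, {'id': 2, 'name': 'animal'}]
category_index = {cat['id']: cat for cat in categories}

# Mapping a detected class id back to its display name:
class_name = category_index[1]['name']  # 'pig'
```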
'''
Now for the key part: how the computation graph is defined.
In this script, images are passed to the graph via
feed_dict={image_tensor: image_np_expanded}. An earlier post covered how to use
self-generated tfrecords; the Dataset API introduced in TF 1.4 is another option.
As for get_tensor_by_name, it simply retrieves a tensor by its name; see the
small test snippet below for a concrete demonstration.
It is still not obvious why this graph works: the code appears to merely fetch a
few tensors, so presumably the detection-box tensors (and the others) depend on
image_tensor. Let's confirm this in the source. In
object_detection/inference/detection_inference.py, the function
build_inference_graph "Loads the inference graph and connects it to the input
image". Specifically:
  tf.import_graph_def(
      graph_def, name='', input_map={'image_tensor': image_tensor})
From the official docs: input_map: A dictionary mapping input names (as strings)
in graph_def to Tensor objects. The values of the named input tensors in the
imported graph will be re-mapped to the respective Tensor values.
So where is build_inference_graph called? It is indeed called under the
inference folder, but the feed-based approach used here never goes through that
function. The guess is that the name image_tensor was fixed when the network was
exported; sure enough, in object_detection/exporter.py image_tensor is a
placeholder, as expected. The concrete wiring of the graph is the model
definition itself, which we will look at when analyzing the training code.
'''
with detection_graph.as_default():
    with tf.Session(graph=detection_graph) as sess:
        # Definite input and output Tensors for detection_graph
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        # Each box represents a part of the image where a particular
        # object was detected.
        detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        # Each score represents the level of confidence for each of the
        # objects. The score is shown on the result image, together with
        # the class label.
        detection_scores = detection_graph.get_tensor_by_name('detection_scores:0')
        detection_classes = detection_graph.get_tensor_by_name('detection_classes:0')
        num_detections = detection_graph.get_tensor_by_name('num_detections:0')
        '''Note: some code is omitted here.'''
        (boxes, scores, classes, num) = sess.run(
            [detection_boxes, detection_scores, detection_classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
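Point 3 above (mini-batching) amounts to feeding an [N, H, W, 3] array instead of a single [1, H, W, 3] frame; a minimal NumPy sketch, assuming same-sized frames (the sizes here are made up):

```python
import numpy as np

# Stand-ins for consecutive, same-sized video frames.
frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(4)]

# Single-frame feed: add a leading batch axis, giving shape (1, H, W, 3).
image_np_expanded = np.expand_dims(frames[0], axis=0)

# Mini-batch feed: stack N frames into shape (N, H, W, 3); each fetched
# tensor (boxes, scores, ...) then has N entries per sess.run call.
batch = np.stack(frames, axis=0)
```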
import tensorflow as tf

c = tf.constant([[1.0, 2.0], [3.0, 4.0]])
d = tf.constant([[1.0, 1.0], [0.0, 1.0]])
e = tf.matmul(c, d, name='example')

with tf.Session() as sess:
    test = sess.run(e)
    print(e.name)  # example:0
    print(test)
    test = tf.get_default_graph().get_tensor_by_name("example:0")
    print(test)  # Tensor("example:0", shape=(2, 2), dtype=float32)
    print(test.eval())
'''
The output is:
example_2:0
[[ 1.  3.]
 [ 3.  7.]]
Tensor("example:0", shape=(2, 2), dtype=float32)
[[ 1.  3.]
 [ 3.  7.]]
(The name example_2:0 rather than example:0 appears because the snippet was run
several times in the same default graph, so TensorFlow appended a suffix to keep
the op names unique.)
'''