1. 程式人生 > >【CV】如何使用Tensorflow提供的Object Detection API--4--開始訓練模型

【CV】如何使用Tensorflow提供的Object Detection API--4--開始訓練模型

至此已經學習瞭如何選擇預訓練模型,將資料集轉為TFRecord格式。模型和資料都準備好了,是時候開啟訓練了。

這些在COCO資料集上的模型都是針對90類進行識別的,如果自己的任務沒有這麼多類,或者類不同怎麼辦呢?

如果是我們不是用物體檢測的API的話,答案是移除最後的90個類的分類器層,替換為一個新的神經網路層

shape = (fc_2nd_last_get_shape().as_list()[-1],nb_classes)
fc_last_W = tf.Variable(tf.truncated_normal(shape, stddev=1e-2))
fc_last_b = tf.Variable(
tf.zeros(nb_classes)) logits = tf.nn.xw_plus_b(fc_2nd_last, fc_last_W, fc_last_b)

但是對於物體檢測的API而言,我們只需要修改一下配置檔案即可。

object_detection/samples/configs資料夾下,有各種預訓練模型的配置檔案

以Faster-RCNN舉例子:

# Faster R-CNN with Inception Resnet v2, Atrous version, with Cosine
# Learning Rate schedule.
# Trained on COCO, initialized from Imagenet classification checkpoint
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.

——TBD—

參考: