1. 程式人生 > >在谷歌目標檢測(Google object_detection) API 上訓練自己的資料集

在谷歌目標檢測(Google object_detection) API 上訓練自己的資料集

知乎連結:https://zhuanlan.zhihu.com/p/28218410

應公司要求,利用谷歌最近開源的Google object_detection API對公司收集的資料集進行訓練,並檢測訓練效果。通過一兩天的研究以及維持四天的訓練(GTX 1060  6GB),終於成功的在自己資料集上訓練的任務。測試效果感覺還行,雖沒有達到谷歌官方公佈的資料集上跑的識別效果,但是識別率也還過得去,這主要是因為資料集沒有官方做的那麼規範。下圖為本人挑選的一張識別率較好的圖片(識別哈爾濱啤酒):


下面把本人如何一步步在自己的資料集上訓練的詳細步驟做個總結,一是方便自己以後操作起來更快的再次上手訓練,二是方便大家能好的實現該API的一些需求。

需要說明的:

2:資料集格式需要為轉換成tensorflow要求的tfrecord的形式。

3:本文在GTX 1060  6GB的顯示卡上訓練了四天

4:如何安裝tensorflow等一些依賴庫,本文不再贅述,請參考:安裝依賴庫教程連結

過程:

1:下載Google object_detection API

下載地址

2:資料集準備:

資料集需要符合API所需的TFRecord格式,官方提供的資料集格式為PASCAL VOC格式,API已經為我們提供了將此格式轉為TFRecord的程式碼. 但是這裡我們需要注意一個細節:create_pascal_tf_record.py中的

examples_path = os.path.join(data_dir, 
year, 'ImageSets', 'Main', 'aeroplane_' + FLAGS.set + '.txt')
去掉'aeroplane_'。

同時,將檔案中的

flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt',
'Path to label map proto')

data/pascal_labe_map.pbtxt改為自己的資料集label

然後在“tensorflow/models/object_detection/”目錄下運行以下命令

#生成訓練集record
python create_pascal_tf_record.py --data_dir=`自己的訓練資料集路徑` \
    --year=VOC2007 --set=train --output_path=`你想儲存的訓練集的record路徑`

#生成驗證集record
 python create_pascal_tf_record.py --data_dir=`自己的驗證資料集路徑` \

    --year=VOC2007 --set=val --output_path=`你想儲存的驗證集的record路徑`

注意,在data目錄下選擇一個.pbtxt檔案,將該檔案改為自己資料集的label。
執行上述兩個命令後會在data資料夾下生成兩個record檔案。 

3:下載預訓練模型

按照上文“需要說明的”第一條下載預訓練模型,將下載好的模型進行解壓,並將.ckpt的三個檔案拷貝到models目錄下。將object_detection/samples/configs/faster_rcnn_resnet101_voc07.config複製到models目錄下並做如下修改:
1)num_classes:修改為之前修的的.pbtxt檔案中的類別數目
2)將所有'PATH_TO_BE_CONFIGURED'修改為自己之前設定的路徑

4:開始訓練

執行上述三步之後我們可以開始訓練了,此處需要注意兩點,不然會出現模組匯出錯誤,在tensorflow/models分別執行:

protoc object_detection/protos/*.proto --python_out=.

export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim

然後進入到obeject_detection目錄下,執行一下命令:

python train.py --train_dir='想要儲存訓練模型的路徑' --pipeline_config_path='你採用的.config檔案路徑'

5:模型視覺化

執行上述四步之後您基本上只需等著模型執行完成即可,如果您想要視覺化您的模型,可以執行:

tensorboard --logdir=’上面第4點提到的train_dir路徑‘

然後在你的瀏覽器輸入0.0.0.0:6006就能看到模型一些相關引數的視覺化結果了。

訓練完成後會生成三個.cpkt的檔案,將這三個檔案複製到tensorflow/models下,可利用這三個檔案生成一個.pb檔案,生成程式碼如下:

python object_detection/export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path ’你的.config檔案路徑’ \
    --checkpoint_path model.ckpt-‘CHECKPOINT_NUMBER’ \
    --inference_graph_path output_inference_graph.pb
這樣你就可以利用.pb檔案進行目標檢測了,具體步驟請參考:github.com/tensorflow/m

6:參考

如您覺得本文對你有幫助,請酌情讚賞。同時本文如有不完善的地方歡迎指正!謝謝!