用Tensorflow Object Detection API 訓練自己的資料集
阿新 • • 發佈:2019-01-07
一、準備資料集
Tensorflow Object Detection API 用 TFRecord 檔案格式讀取資料,需把 VOC 格式的資料集進行轉換(我自己的資料集是VOC2007)
1、修改 tensorflow/models/object_detection/create_pascal_tf_record.py 檔案第84行和162行。
2、修改tensorflow/models/object_detection/data/pascal_label_map.pbtxt 檔案裡的類別.
3、執行命令:
# From tensorflow/models
python object_detection/create_pascal_tf_record.py \
--label_map_path=object_detection/data/pascal_label_map.pbtxt \
--data_dir=VOCdevkit --year=VOC2007 --set=train \
--output_path=pascal_train.record
python object_detection/create_pascal_tf_record.py \
--label_map_path=object_detection/data/pascal_label_map.pbtxt \
--data_dir=VOCdevkit --year=VOC2007 --set=val \
--output_path=pascal_val.record
執行後會在object_detection資料夾下生成pascal_train.record和pascal_val.record兩個檔案。
二、下載預訓練模型
tar -xzvf ssd_mobilenet_v1_coco.tar.gz
三、修改配置檔案
修改 object_detection/samples/configs/faster_rcnn_inception_resnet_v2_atrous_pets.config檔案:
(1)num_classes:修改為自己的classes num
(2)將所有PATH_TO_BE_CONFIGURED的地方修改為自己之前設定的路徑(共5處)
四、訓練
進入object_detection目錄,執行:
tensorflow/models$ python object_detection/train.py --train_dir='/home/anngic/tensorflow/train' --pipeline_config_path='/home/anngic/tensorflow/models/object_detection/samples/configs/faster_rcnn_inception_resnet_v2_atrous_coco.config'
五、tensorboad
輸入命令:
tensorboard --logdir=/home/shz/TF-OD-Test/train