
TensorFlow Object Detection: Generating Your Own TFRecord Training Dataset

Google's Object Detection API

Parts of this article draw on other authors' work. Since I have forgotten where that content came from, no reprint link is included; my apologies. If an original author sees this, please contact me and I will add one.

======== Please credit the source when reposting ==========

Place this Python file under the Object Detection API's dataset_tools directory.

How you generate your own training set depends mainly on the format of your annotation files. In my case, every image has its own annotation file; for example:

for image xxx.jpg, the annotation file is xxx.box.

Each line of the .box file has the form:

Xmin Ymin Xmax Ymax  label   (as shown below). If there are multiple labels, keep appending them, one per line:

Xmin Ymin Xmax Ymax  label \n

Xmin Ymin Xmax Ymax  label
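For instance, a hypothetical xxx.box describing two objects might look like the following (the class names person and car are placeholder examples, not from my dataset):

23 45 210 180 person
60 80 150 200 car

The full conversion script is listed below.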

 

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import io
import os
import PIL.Image
import tensorflow as tf
import cv2
from functools import reduce
import operator
from object_detection.utils import dataset_util

flags = tf.app.flags
flags.DEFINE_string('train_imgs_dir', '/home/ai/Downloads/competition_change_box_img/img', 'Root directory to bc train dataset.')
flags.DEFINE_string('train_labels', '/home/ai/Downloads/competition_change_box_img/box',
                    '(Relative) path to annotations directory.')
flags.DEFINE_string('train_output', '../All_tf_record/competition_img_test.record', 'Path to output TFRecord')
FLAGS = flags.FLAGS

# Map each label-name string that appears in the .box files to an integer class id;
# the ids must be consistent with your label_map.pbtxt. The entry below is only a
# placeholder; replace it with your own classes.
classMap = {'your_label_name': 1}


def create_coordinate_info_of_content_list(image_dir,label_dir):
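    """For each .box annotation file in label_dir, read the matching .jpg in image_dir
    and build one flat list per image:
    [filename, height, width, depth, xmin, ymin, xmax, ymax, label, ...]
    with one (xmin, ymin, xmax, ymax, label) group appended per bounding box."""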
    content_list_all = []
    for file_name in os.listdir(label_dir):
        img = cv2.imread(os.path.join(image_dir,file_name.replace('.box','.jpg')))
        height = img.shape[0]
        width = img.shape[1]
        depth = img.shape[2]
        content_list = [[file_name.replace('.box', '.jpg'), height, width, depth]]
        with open(os.path.join(label_dir, file_name), 'r') as f:
            lines = f.readlines()
        for line in lines:
            # strip the trailing newline and split on whitespace so the label does not
            # keep a '\n' and repeated spaces do not produce empty fields
            new_line = line.strip().split()
            if len(new_line) < 5:    # skip blank or malformed lines
                continue
            content_one = new_line[:5]    # Xmin, Ymin, Xmax, Ymax, label
            content_list.append(content_one)
        # flatten [[filename, h, w, depth], [box1...], [box2...], ...] into one flat list
        flat_list = reduce(operator.add, content_list)
        content_list_all.append(flat_list)
   
    return content_list_all

def create_tf_example(content_list, imgs_dir):
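    """Turn one flattened annotation list (as produced above) into a tf.train.Example
    holding the fields the Object Detection API expects; box coordinates are
    normalised to [0, 1] by dividing by the image width and height."""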
    height = int(content_list[1])
    width = int(content_list[2])
    filename = content_list[0]
    img_path = os.path.join(imgs_dir, filename)
    with tf.gfile.GFile(img_path, 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = PIL.Image.open(encoded_jpg_io)
    if image.format != 'JPEG':
        raise ValueError('Image format not JPEG')
    key = hashlib.sha256(encoded_jpg).hexdigest()

    xmin = []
    ymin = []
    xmax = []
    ymax = []
    classes = []
    classes_text = []
    box_num = int((len(content_list) - 4) / 5)    # one image may contain several labeled boxes
    for i in range(box_num):
        xmin.append(float(content_list[5 * i + 4 + 0]) / width)
        ymin.append(float(content_list[5 * i + 4 + 1]) / height)
        xmax.append(float(content_list[5 * i + 4 + 2]) / width)
        ymax.append(float(content_list[5 * i + 4 + 3]) / height)
        classes_text.append(content_list[5 * i + 4 + 4].encode('utf8'))
        classes.append(classMap[content_list[5 * i + 4 + 4]])
        print('the class id is {} '.format(classMap[content_list[5 * i + 4 + 4]]))
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(
            filename.encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(
            filename.encode('utf8')),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return example

def main(_):
    # train tfrecord generate
    print("Reading from {}".format(FLAGS.train_imgs_dir))
    writer = tf.python_io.TFRecordWriter(FLAGS.train_output)
    content_list_all = create_coordinate_info_of_content_list(FLAGS.train_imgs_dir, FLAGS.train_labels)
    for line in content_list_all:
        content_list = line
        tf_example = create_tf_example(content_list, FLAGS.train_imgs_dir)
        writer.write(tf_example.SerializeToString())
    writer.close()

if __name__ == '__main__':
    tf.app.run()
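
Once the script has run with the three flags above, you can quickly sanity-check the generated file. The following is a minimal sketch using the same TF 1.x API the script relies on; the record path shown is simply the default value of the train_output flag.

import tensorflow as tf

record_path = '../All_tf_record/competition_img_test.record'   # default value of the train_output flag
num_examples = 0
for record in tf.python_io.tf_record_iterator(record_path):
    example = tf.train.Example()
    example.ParseFromString(record)
    if num_examples == 0:
        # print the filename stored in the first example as a spot check
        print(example.features.feature['image/filename'].bytes_list.value[0])
    num_examples += 1
print('TFRecord contains {} examples'.format(num_examples))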