谷歌(Google)Object Detection API:制作自己的 TFRecord 训练数据集
本文部分内容参考了其他作者的文章,但由于忘记了具体出处,未能附上转载链接,敬请谅解;如果原作者看到,可以联系我补充。
========转载请注明出处==========
此 Python 文件需要放在 Object Detection API 的 dataset_tools 目录下。
如何生成自己的训练数据集,主要取决于你的 annotation 文件是什么格式。我这里每张图片都有各自对应的 annotation 文件,例如:
图片xxx.jpg,其annotation文件为xxx.box
box文件内容为:
每一行的格式为:Xmin Ymin Xmax Ymax label(各字段之间用空格分隔,如下图)。如果一张图上有多个 label,可以继续追加在下一行:
Xmin Ymin Xmax Ymax label \n
Xmin Ymin Xmax Ymax label
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import io
import os
import PIL.Image
import tensorflow as tf
import pandas as pd
import cv2
from functools import reduce
import operator
from object_detection.utils import dataset_util
# Command-line flags for the converter. The defaults below are absolute paths
# from the original author's machine and will normally need to be overridden.
flags = tf.app.flags
# Directory holding the .jpg training images.
flags.DEFINE_string('train_imgs_dir', '/home/ai/Downloads/competition_change_box_img/img', 'Root directory to bc train dataset.')
# Directory holding the per-image .box annotation files. The help text says
# "(Relative)", but the default value here is an absolute path.
flags.DEFINE_string('train_labels', '/home/ai/Downloads/competition_change_box_img/box',
'(Relative) path to annotations directory.')
# Destination path of the generated TFRecord file.
flags.DEFINE_string('train_output', '../All_tf_record/competition_img_test.record', 'Path to output TFRecord')
FLAGS = flags.FLAGS
def create_coordinate_info_of_content_list(image_dir, label_dir):
    """Collect image metadata and box annotations for every ``.box`` file.

    For each ``<name>.box`` file found in ``label_dir`` the matching
    ``<name>.jpg`` is read from ``image_dir`` with OpenCV to obtain its
    shape, and every annotation line (``xmin ymin xmax ymax label``,
    whitespace separated) is appended to one flat record for that image.

    Args:
        image_dir: Directory containing the ``.jpg`` images.
        label_dir: Directory containing the ``.box`` annotation files.

    Returns:
        A list with one flat list per annotation file, laid out as
        ``[filename, height, width, depth, xmin, ymin, xmax, ymax, label,
        xmin, ymin, ...]``. Height, width and depth are ints taken from the
        image shape; the box fields are the strings read from the file.

    Raises:
        IOError: If the image belonging to an annotation file cannot be read.
        ValueError: If a non-blank annotation line has fewer than five fields.
    """
    box_suffix = '.box'
    content_list_all = []
    for file_name in os.listdir(label_dir):
        # Skip stray entries (editor backups, sub-directories, ...) instead
        # of trying to load an image for them.
        if not file_name.endswith(box_suffix):
            continue

        # Swap only the extension; str.replace would also rewrite a '.box'
        # that happens to appear in the middle of the name.
        image_name = file_name[:-len(box_suffix)] + '.jpg'
        image_path = os.path.join(image_dir, image_name)

        img = cv2.imread(image_path)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise IOError('Could not read image {!r} for annotation file {!r}'
                          .format(image_path, file_name))
        height = img.shape[0]
        width = img.shape[1]
        depth = img.shape[2]

        record = [image_name, height, width, depth]
        with open(os.path.join(label_dir, file_name), 'r') as f:
            for line_no, line in enumerate(f, 1):
                # split() with no argument drops the trailing newline and
                # tolerates repeated spaces or tabs between fields.
                fields = line.split()
                if not fields:
                    continue  # blank line, e.g. at the end of the file
                if len(fields) < 5:
                    raise ValueError(
                        '{}:{}: expected "xmin ymin xmax ymax label", got {!r}'
                        .format(file_name, line_no, line))
                record.extend(fields[:5])
        content_list_all.append(record)
    return content_list_all
def create_tf_example(content_list, imgs_dir):
    """Build a ``tf.train.Example`` for one image from its flat record.

    Args:
        content_list: Flat list laid out as ``[filename, height, width,
            depth, xmin, ymin, xmax, ymax, label, ...]`` with five entries
            per bounding box after the first four. Trailing entries that do
            not form a complete group of five are ignored.
        imgs_dir: Directory that contains the image named by
            ``content_list[0]``.

    Returns:
        A ``tf.train.Example`` holding the encoded image bytes, its size,
        file name, SHA-256 key, and per-box coordinates (divided by the
        image width/height), class texts and class ids.

    Raises:
        ValueError: If the file is not decoded by PIL as a JPEG.
    """
    height, width = int(content_list[1]), int(content_list[2])
    filename = content_list[0]

    with tf.gfile.GFile(os.path.join(imgs_dir, filename), 'rb') as fid:
        encoded_jpg = fid.read()
    # Only the container format is inspected; the pixels are not decoded here.
    if PIL.Image.open(io.BytesIO(encoded_jpg)).format != 'JPEG':
        raise ValueError('Image format not JPEG')
    key = hashlib.sha256(encoded_jpg).hexdigest()

    xmin, ymin, xmax, ymax = [], [], [], []
    classes, classes_text = [], []

    # An image may carry several boxes: five fields each after the 4-item header.
    box_count = (len(content_list) - 4) // 5
    for box_index in range(box_count):
        base = 4 + 5 * box_index
        label = content_list[base + 4]
        xmin.append(float(content_list[base]) / width)
        ymin.append(float(content_list[base + 1]) / height)
        xmax.append(float(content_list[base + 2]) / width)
        ymax.append(float(content_list[base + 3]) / height)
        classes_text.append(label.encode('utf8'))
        # NOTE(review): classMap is not defined in the visible source; it is
        # presumably a module-level mapping from label text to integer class
        # id that must be supplied before this function is called - confirm.
        class_id = classMap[label]
        classes.append(class_id)
        print('the class id is {} '.format(class_id))

    encoded_name = filename.encode('utf8')
    feature = {
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(encoded_name),
        'image/source_id': dataset_util.bytes_feature(encoded_name),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))
def main(_):
    """Convert the annotated training images into a single TFRecord file.

    Reads the image and annotation directories from ``FLAGS.train_imgs_dir``
    and ``FLAGS.train_labels`` and writes one serialized example per
    annotation file to ``FLAGS.train_output``.

    Args:
        _: Unused positional argument passed in by ``tf.app.run``.
    """
    # train tfrecord generate
    print("Reading from {}".format(FLAGS.train_imgs_dir))

    # Parse every annotation before opening the writer, so that a bad input
    # file does not leave behind (or truncate) the output record.
    content_list_all = create_coordinate_info_of_content_list(
        FLAGS.train_imgs_dir, FLAGS.train_labels)

    writer = tf.python_io.TFRecordWriter(FLAGS.train_output)
    try:
        for content_list in content_list_all:
            tf_example = create_tf_example(content_list, FLAGS.train_imgs_dir)
            writer.write(tf_example.SerializeToString())
    finally:
        # Always release the file handle, even if building an example fails.
        writer.close()


if __name__ == '__main__':
    # Hand control to TensorFlow's app runner, which calls main().
    tf.app.run()