#
#作者:韦访
#博客:https://blog.youkuaiyun.com/rookie_wei
#微信:1007895847
#添加微信的备注一下是优快云的
#欢迎大家一起学习
#
1、概述
接着分析一下Object Detection API的源码,请结合前面的三篇关于Object Detection API的博客一起看,链接如下:
https://blog.youkuaiyun.com/rookie_wei/article/details/81143814
https://blog.youkuaiyun.com/rookie_wei/article/details/81210499
https://blog.youkuaiyun.com/rookie_wei/article/details/81275663
从最后一篇,训练自己的模型开始分析。
2、将VOC2012数据集转成tfrecord格式源码分析
根据命令,
python dataset_tools/create_pascal_tf_record.py --data_dir=my_images/VOCdevkit/ --year=VOC2012 --output_path=my_images/VOCdevkit/pascal_train.record --set=train
我们从dataset_tools/create_pascal_tf_record.py文件入手,看看怎么将图片数据转成tfrecord格式。从main函数开始,
def main(_):
if FLAGS.set not in SETS:
raise ValueError('set must be in : {}'.format(SETS))
if FLAGS.year not in YEARS:
raise ValueError('year must be in : {}'.format(YEARS))
看来set和year参数非设置不可,不过默认设置成train和VOC2007了。看看其他参数,定义如下,
flags = tf.app.flags
flags.DEFINE_string('data_dir', '', 'Root directory to raw PASCAL VOC dataset.')
flags.DEFINE_string('set', 'train', 'Convert training set, validation set or '
'merged set.')
flags.DEFINE_string('annotations_dir', 'Annotations',
'(Relative) path to annotations directory.')
flags.DEFINE_string('year', 'VOC2007', 'Desired challenge year.')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', 'data/pascal_label_map.pbtxt',
'Path to label map proto')
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
'difficult instances')
FLAGS = flags.FLAGS
SETS = ['train', 'val', 'trainval', 'test']
YEARS = ['VOC2007', 'VOC2012', 'merged']
比较简单不解释了,接着往下看,
data_dir = FLAGS.data_dir
years = ['VOC2007', 'VOC2012']
if FLAGS.year != 'merged':
years = [FLAGS.year]
#用于保存TFRecord文件
writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
#解析label文件(data/pascal_label_map.pbtxt),结果如下,
#{'pottedplant': 16, 'diningtable': 11, 'sheep': 17, 'aeroplane': 1, 'bicycle': 2,
# 'person': 15, 'bus': 6, 'train': 19, 'sofa': 18, 'car': 7, 'chair': 9, 'dog': 12,
# 'bottle': 5, 'bird': 3, 'motorbike': 14, 'cow': 10, 'tvmonitor': 20, 'horse': 13,
# 'boat': 4, 'cat': 8}
label_map_dict = label_map_util.get_label_map_dict(FLAGS.label_map_path)
上面其实就是解析一下data/pascal_label_map.pbtxt文件,这个文件的格式如下,
item {
id: 1
name: 'aeroplane'
}
item {
id: 2
name: 'bicycle'
}
item {
id: 3
name: 'bird'
}
item {
id: 4
name: 'boat'
}
item {
id: 5
name: 'bottle'
}
item {
id: 6
name: 'bus'
}
item {
id: 7
name: 'car'
}
item {
id: 8
name: 'cat'
}
item {
id: 9
name: 'chair'
}
item {
id: 10
name: 'cow'
}
item {
id: 11
name: 'diningtable'
}
item {
id: 12
name: 'dog'
}
item {
id: 13
name: 'horse'
}
item {
id: 14
name: 'motorbike'
}
item {
id: 15
name: 'person'
}
item {
id: 16
name: 'pottedplant'
}
item {
id: 17
name: 'sheep'
}
item {
id: 18
name: 'sofa'
}
item {
id: 19
name: 'train'
}
item {
id: 20
name: 'tvmonitor'
}
这就是我们要识别的20个种类。继续,
for year in years:
logging.info('Reading from PASCAL %s dataset.', year)
# data_dir/VOC2012/ImageSets/Main/aeroplane_train.txt
examples_path = os.path.join(data_dir, year, 'ImageSets', 'Main',
'aeroplane_' + FLAGS.set + '.txt')
# data_dir/VOC2012/Annotations
annotations_dir = os.path.join(data_dir, year, FLAGS.annotations_dir)
# 将aeroplane_train.txt文件的每行第一个字段放到数组里,其实就是不带后缀的文件名
examples_list = dataset_util.read_examples_list(examples_path)
print(len(examples_list))
上面因为我们传入的FLAGS.year参数是VOC2012,所以这里的year也就是’VOC2012’。所以examples_path就是”data_dir/VOC2012/ImageSets/Main/aeroplane_train.txt”文件,annotations_dir就是”data_dir/VOC2012/Annotations”文件夹。
dataset_util.read_examples_list函数就是读取aeroplane_train.txt文件,然后,将每一行的第一个字符串保存到数组里。这些字符串其实就是data_dir/VOC2012/Annotations文件夹下对应的文件名,只是不带后缀而已,aeroplane_train.txt文件内容如下,
data_dir/VOC2012/Annotations文件夹下的文件如下,
注意的是,aeroplane_train.txt文件内容并没有包含所有的data_dir/VOC2012/Annotations文件夹下的文件。接着往下看,
for idx, example in enumerate(examples_list):
if idx % 100 == 0:
logging.info('On image %d of %d', idx, len(examples_list))
#获取 data_dir/VOC2012/Annotations 文件夹下对应的 20xx_xxxxxx.xml文件
path = os.path.join(annotations_dir, example + '.xml')
#打开上面获取的xml的文件,解析它
with tf.gfile.GFile(path, 'r') as fid:
xml_str = fid.read()
#解析xml文件
xml = etree.fromstring(xml_str)
data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
tf_example = dict_to_tf_example(data, FLAGS.data_dir, label_map_dict,
FLAGS.ignore_difficult_instances)
#保存成record
writer.write(tf_example.SerializeToString())
上面就是解析aeroplane_train.txt文件里所有的对应的data_dir/VOC2012/Annotations文件夹下的文件。我们主要看看dict_to_tf_example函数,这个函数的代码要结合data_dir/VOC2012/Annotations文件夹下的xml文件的格式来看,就很容易理解了,格式如下,
def dict_to_tf_example(data,
dataset_directory,
label_map_dict,
ignore_difficult_instances=False,
image_subdirectory='JPEGImages'):
"""Convert XML derived dict to tf.Example proto.
Notice that this function normalizes the bounding box coordinates provided
by the raw data.
Args:
data: dict holding PASCAL XML fields for a single image (obtained by
running dataset_util.recursive_parse_xml_to_dict)
dataset_directory: Path to root directory holding PASCAL dataset
label_map_dict: A map from string label names to integers ids.
ignore_difficult_instances: Whether to skip difficult instances in the
dataset (default: False).
image_subdirectory: String specifying subdirectory within the
PASCAL dataset directory holding the actual image data.
Returns:
example: The converted tf.Example.
Raises:
ValueError: if the image pointed to by data['filename'] is not a valid JPEG
"""
#获取该xml对应的图片的文件, 比如 VOC2012/JPEGImages/2008_000008.jpg
img_path = os.path.join(data['folder'], image_subdirectory, data['filename'])
# 再加上 dataset_directory 的路径, 比如 my_images/VOCdevkit/VOC2012/JPEGImages/2008_000008.jpg
full_path = os.path.join(dataset_directory, img_path)
首先,根据folder和filename关键字找到该xml文件对应的图片的路径。
接着往下看,
#读取图片
with tf.gfile.GFile(full_path, 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = PIL.Image.open(encoded_jpg_io)
if image.format != 'JPEG':
raise ValueError('Image format not JPEG')
#图片哈希值
key = hashlib.sha256(encoded_jpg).hexdigest()
#获取图片的宽和高
width = int(data['size']['width'])
height = int(data['size']['height'])
读取图片,生成哈希值,获取图片的宽高,接着往下看,
xmin = []
ymin = []
xmax = []
ymax = []
classes = []
classes_text = []
truncated = []
poses = []
difficult_obj = []
if 'object' in data:
for obj in data['object']:
#目标是否难以检测
difficult = bool(int(obj['difficult']))
if ignore_difficult_instances and difficult:
continue
difficult_obj.append(int(difficult))
#获取检测框,左下角和右上角坐标
xmin.append(float(obj['bndbox']['xmin']) / width)
ymin.append(float(obj['bndbox']['ymin']) / height)
xmax.append(float(obj['bndbox']['xmax']) / width)
ymax.append(float(obj['bndbox']['ymax']) / height)
#目标名称
classes_text.append(obj['name'].encode('utf8'))
#对应于label文件夹里的种类的数字
classes.append(label_map_dict[obj['name']])
#目标有没有被遮挡
truncated.append(int(obj['truncated']))
#pose
poses.append(obj['pose'].encode('utf8'))
也是解析xml文件的一些字段,接着看,
example = tf.train.Example(features=tf.train.Features(feature={
'image/height': dataset_util.int64_feature(height),
'image/width': dataset_util.int64_feature(width),
'image/filename': dataset_util.bytes_feature(
data['filename'].encode('utf8')),
'image/source_id': dataset_util.bytes_feature(
data['filename'].encode('utf8')),
'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
'image/object/class/label': dataset_util.int64_list_feature(classes),
'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
'image/object/truncated': dataset_util.int64_list_feature(truncated),
'image/object/view': dataset_util.bytes_list_feature(poses),
}))
return example
将上面解析的结果,传到tf.train.Example里,这里看不懂的话,可以看看下面的博客:
https://www.jianshu.com/p/b480e5fcb638
最后,转成tfrecord以后,记得关闭writer。
writer.close()
这部分比较简单,对着源码看看就可以看明白了。那里不明白就print看看。
如果您感觉本篇博客对您有帮助,请打开支付宝,领个红包支持一下,祝您扫到99元,谢谢~~