上一篇介绍了faster rcnn 的原理,现在开始解读源码,这里因为我看得是keras版的,所以用keras版的来讲解。
源码的网址:keras版faster rcnn网络
第一讲 读取文件(pascal_voc_parser.py)
这个文件相对比较简单,就是从pascal里面读取数据,解析数据,具体我们看下面的代码:
下面是导入包,其中ET是解析xml文件的
import os
import cv2
import xml.etree.ElementTree as ET
import numpy as np
以下是定义了一个函数,函数的输入是数据的路径
def get_data(input_path):
all_imgs = []
classes_count = {}
class_mapping = {}
visualise = False
上面的class_count 是记录每一个类的数量class_mapping 代表的是类的映射,就是把对映射为对应的编号,这样后面好处理,比如后面使用softmax那么就需要one-hot编码。visualise 是是否可视化。
data_paths = [os.path.join(input_path,s) for s in ['VOC2007', 'VOC2012']
这句话是其实就是一个路径的拼接过程,
for data_path in data_paths:
annot_path = os.path.join(data_path, 'Annotations')
imgs_path = os.path.join(data_path, 'JPEGImages')
imgsets_path_trainval = os.path.join(data_path, 'ImageSets','Main','trainval.txt')
imgsets_path_test = os.path.join(data_path, 'ImageSets','Main','test.txt')
trainval_files = []
test_files = []
上面这些代码就是从data_path里面读取数据,因为之前是voc2007和2011,这就是一个遍历过程,上面也是一个拼接路径的过程,然后存放在lsit中。
for annot in annots:
try:
idx += 1
et = ET.parse(annot)
element = et.getroot()
element_objs = element.findall('object')
element_filename = element.find('filename').text
element_width = int(element.find('size').find('width').text)
element_height = int(element.find('size').find('height').text)
if len(element_objs) > 0:
annotation_data = {'filepath': os.path.join(imgs_path, element_filename), 'width': element_width,
'height': element_height, 'bboxes': []}
if element_filename in trainval_files:
annotation_data['imageset'] = 'trainval'
elif element_filename in test_files:
annotation_data['imageset'] = 'test'
else:
annotation_data['imageset'] = 'trainval'
for element_obj in element_objs:
class_name = element_obj.find('name').text
if class_name not in classes_count:
classes_count[class_name] = 1
else:
classes_count[class_name] += 1
if class_name not in class_mapping:
class_mapping[class_name] = len(class_mapping)
obj_bbox = element_obj.find('bndbox')
x1 = int(round(float(obj_bbox.find('xmin').text)))
y1 = int(round(float(obj_bbox.find('ymin').text)))
x2 = int(round(float(obj_bbox.find('xmax').text)))
y2 = int(round(float(obj_bbox.find('ymax').text)))
difficulty = int(element_obj.find('difficult').text) == 1
annotation_data['bboxes'].append(
{'class': class_name, 'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'difficult': difficulty})
all_imgs.append(annotation_data)
以上代码就是开始读取数据,annots里面是xml文件,xml里面包括目标框的信息,以及目标的类别,以上这段代码就是从xml解析出坐标信息,类别,以及文件名等,然后这些信息都存放在all_imgs里面,all_imgs是一个list,但是他的每个元素都是字典,也就是存放的是每一张图片里面的信息,比如某张图像的总共有多少个目标,有多少个类等。
return all_imgs, classes_count, class_mapping
以上是这个函数的返回值,all_imgs 是一个字典,class_count 是各类总共的数量,是一个字典,key是类名,value是该类的数量,class_mapping 是每个类对应的标签。
总结起来,pascal_voc_parser.py就是解析文件,读取文件