faster rcnn 源码解读(一)

上一篇介绍了faster rcnn 的原理,现在开始解读源码,这里因为我看得是keras版的,所以用keras版的来讲解。

源码的网址:keras版faster rcnn网络

第一讲  读取文件(pascal_voc_parser.py)

    这个文件相对比较简单,就是从pascal里面读取数据,解析数据,具体我们看下面的代码:

下面是导入包,其中ET是解析xml文件的

import os
import cv2
import xml.etree.ElementTree as ET
import numpy as np

以下是定义了一个函数,函数的输入是数据的路径

def get_data(input_path):
	all_imgs = []

	classes_count = {}

	class_mapping = {}

	visualise = False

上面的class_count 是记录每一个类的数量class_mapping 代表的是类的映射,就是把对映射为对应的编号,这样后面好处理,比如后面使用softmax那么就需要one-hot编码。visualise 是是否可视化。

data_paths = [os.path.join(input_path,s) for s in ['VOC2007', 'VOC2012']

这句话是其实就是一个路径的拼接过程,

	for data_path in data_paths:

		annot_path = os.path.join(data_path, 'Annotations')
		imgs_path = os.path.join(data_path, 'JPEGImages')
		imgsets_path_trainval = os.path.join(data_path, 'ImageSets','Main','trainval.txt')
		imgsets_path_test = os.path.join(data_path, 'ImageSets','Main','test.txt')

		trainval_files = []
		test_files = []

上面这些代码就是从data_path里面读取数据,因为之前是voc2007和2011,这就是一个遍历过程,上面也是一个拼接路径的过程,然后存放在lsit中。

		for annot in annots:
			try:
				idx += 1
				et = ET.parse(annot)
				element = et.getroot()

				element_objs = element.findall('object')
				element_filename = element.find('filename').text
				element_width = int(element.find('size').find('width').text)
				element_height = int(element.find('size').find('height').text)

				if len(element_objs) > 0:
					annotation_data = {'filepath': os.path.join(imgs_path, element_filename), 'width': element_width,
									   'height': element_height, 'bboxes': []}

					if element_filename in trainval_files:
						annotation_data['imageset'] = 'trainval'
					elif element_filename in test_files:
						annotation_data['imageset'] = 'test'
					else:
						annotation_data['imageset'] = 'trainval'

				for element_obj in element_objs:
					class_name = element_obj.find('name').text
					if class_name not in classes_count:
						classes_count[class_name] = 1
					else:
						classes_count[class_name] += 1

					if class_name not in class_mapping:
						class_mapping[class_name] = len(class_mapping)

					obj_bbox = element_obj.find('bndbox')
					x1 = int(round(float(obj_bbox.find('xmin').text)))
					y1 = int(round(float(obj_bbox.find('ymin').text)))
					x2 = int(round(float(obj_bbox.find('xmax').text)))
					y2 = int(round(float(obj_bbox.find('ymax').text)))
					difficulty = int(element_obj.find('difficult').text) == 1
					annotation_data['bboxes'].append(
						{'class': class_name, 'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'difficult': difficulty})
				all_imgs.append(annotation_data)

以上代码就是开始读取数据,annots里面是xml文件,xml里面包括目标框的信息,以及目标的类别,以上这段代码就是从xml解析出坐标信息,类别,以及文件名等,然后这些信息都存放在all_imgs里面,all_imgs是一个list,但是他的每个元素都是字典,也就是存放的是每一张图片里面的信息,比如某张图像的总共有多少个目标,有多少个类等。

return all_imgs, classes_count, class_mapping
以上是这个函数的返回值,all_imgs 是一个字典,class_count 是各类总共的数量,是一个字典,key是类名,value是该类的数量,class_mapping 是每个类对应的标签。

总结起来,pascal_voc_parser.py就是解析文件,读取文件




评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值