def __init__(self, voc_root, transforms, train_set=True, year="2012",
             class_json="./pascal_voc_classes.json"):
    """Index one PASCAL VOC split and load the class-name -> index mapping.

    Args:
        voc_root: directory that contains the ``VOCdevkit`` folder.
        transforms: preprocessing to apply to samples; stored unchanged as
            ``self.transforms``.
        train_set: if True read ``ImageSets/Main/train.txt``, otherwise read
            ``ImageSets/Main/val.txt``.
        year: VOC release to use; selects ``VOCdevkit/VOC<year>``. Defaults
            to ``"2012"``, the value that used to be hard-coded.
        class_json: path of the JSON file holding the class-name -> index
            mapping. Defaults to the previously hard-coded relative path,
            which is resolved against the current working directory.

    Raises:
        OSError: (e.g. FileNotFoundError) if the split list file cannot be
            opened.
        SystemExit: with code -1 if ``class_json`` cannot be opened or is
            not valid JSON; the error is printed before exiting.
    """
    # Root of the chosen VOC release, plus its image and annotation folders.
    self.root = os.path.join(voc_root, "VOCdevkit", f"VOC{year}")
    self.img_root = os.path.join(self.root, "JPEGImages")
    self.annotations_root = os.path.join(self.root, "Annotations")

    # Pick the split list: one image id per line.
    split_name = "train.txt" if train_set else "val.txt"
    txt_list = os.path.join(self.root, "ImageSets", "Main", split_name)
    with open(txt_list) as read:
        # Map every image id to the path of its annotation XML. Blank lines
        # (e.g. a trailing empty line) are skipped; otherwise they would
        # become a bogus "<Annotations>/.xml" entry.
        self.xml_list = [
            os.path.join(self.annotations_root, line.strip() + ".xml")
            for line in read
            if line.strip()
        ]

    # read class_indict: JSON mapping class names to indices.
    try:
        with open(class_json, "r") as json_file:
            self.class_dict = json.load(json_file)
    except (OSError, ValueError) as e:
        # Keep the original fail-fast behaviour (print, then exit with -1).
        # Use sys.exit rather than the bare ``exit`` builtin, which is
        # installed by the ``site`` module for interactive use only.
        import sys
        print(e)
        sys.exit(-1)

    self.transforms = transforms
def __len__(self):
    """Return the dataset size: one sample per entry in ``self.xml_list``."""
    annotation_paths = self.xml_list
    return len(annotation_paths)
def __getitem__(self, idx):-》idx为索引值
# read xml
xml_path = self.xml_list[idx]-》获取xml文件的路径
with open(xml_path) as fid:-》打开xml文件
xml_str = fid.read()
xml = etree.fromstring(xml_str)-》读取xml文件
data = self.parse_xml_to_dict(xml)["annotation"]-》再将xml文件信息传入到parse_xml_to_dict(xml文件信息转化为字典)方法中
img_path = os.path.join(self.img_root, data["filename"])-》拼接成图像路径
image = Image.open(img_path)-》打开图片路径
if image.format != "JPEG":
raise ValueError("Image format not JPEG")-》如果不是jepg格式报错
boxes = []
labels = []
iscrowd = []