本文的代码是参考b站up主霹雳吧啦Wz1.2Faster RCNN源码解析(pytorch)_哔哩哔哩_bilibili
最终目的是复现fasterrcnn网络
先附上整体代码
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2022/5/15 10:28
# @Author : 半岛铁盒
# @File : mydataset.py
# @Software: win10 python3.6
#定义一个自己的数据集
import random
from torch.utils.data import Dataset
import os
import json
from PIL import Image
from lxml import etree
import torch
import matplotlib.pyplot as plt
#继承自torch.utils.data.Dataset
class VOCDataSet(Dataset):
def __init__(self,voc_root,transforms,train_set=True):
self.root=os.path.join(voc_root,"VOCdevkit","VOC2007")
self.img_root=os.path.join(self.root,"JPEGImages")
self.Annotations_root=os.path.join(self.root,"Annotations")
#如果是训练集则训练集,否则验证集
if train_set==True:
txt_list=os.path.join(self.root,"ImageSets","Main","train.txt")
else:
txt_list=os.path.join(self.root,"ImageSets","Main","val.txt")
#readlines()一次性先将整个文件内容按行读取完
#strip() 方法用于移除字符串头尾指定的字符(默认为空格或换行符)或字符序列,此处用来去除换行符
#xml_list代表存放训练集或验证集的txt文件
with open(txt_list) as read:
self.xml_list=[os.path.join(self.Annotations_root,line.strip()+".xml") for line in read.readlines()]
json_file=open("./pascal_voc_classes.json","r")
self.class_dict=json.load(json_file)
self.transforms=transforms
# print(self.xml_list[154])
def __len__(self):
return len(self.xml_list)
def __getitem__(self, idx):
#返回索引对应的xml文件
xml_path=self.xml_list[idx]
with open(xml_path,encoding="UTF-8") as fid:
#读取xml文件
xml_str=fid.read()
#etree.fromstring()(将字符串转化为Element对象) 用etree来读取xml文件
xml=etree.fromstring(xml_str)
#parse_xml_to_dict返回标签对应的xml信息
data=self.parse_xml_to_dict(xml)["annotation"]
#如果filename不是jpg后缀就把它改成jpg后缀
file_name=data["filename"].split(".")[0]+".jpg"
ima_path=os.path.join(self.img_root,file_name)
image=Image.open(ima_path)
boxes=[]
labels=[]
iscrowd=[]
#将标签信息中objec里的bndbox内的xy坐标信息读入
for obj in data["object"]:
xim=float(obj["bndbox"]["xmin"])
xmax=float(obj["bndbox"]["xmax"])
ymin=float(obj["bndbox"]["ymin"])
ymax=float(obj["bndbox"]["ymax"])
#将xy坐标信息存入boxes,一个标签信息中可能有多个objext所以boxes也有多个
boxes.append([xim,ymin,xmax,ymax])
#将该标签中每一个object的name对应json文件中的key值 sheet 1
labels.append(self.class_dict[obj["name"]])
#obj里的difficult代表目标是否重叠,可以简单理解为目标难不难检测
iscrowd.append(int(obj["difficult"]))
boxes=torch.as_tensor(boxes,dtype=torch.float32)
labels=torch.as_tensor(labels,dtype=torch.int64)
iscrowd=torch.as_tensor(iscrowd,dtype=torch.int64)
image_id=torch.tensor([idx])
#boxes[:,3]第三个也就是ymax,计算一下目标面积
area=(boxes[:,3]-boxes[:,1])*(boxes[:,2]-boxes[:,0])
#将上述信息打包到target字典里
target={}
target["boxes"]=boxes
target["labels"]=labels
target["image_id"]=image_id
target["area"]=area
target["iscrowd"]=iscrowd
if self.transforms is not None:
image,target=self.transforms(image,target)
return image,target
def get_height_and_width(self,idx):
xml_path=self.xml_list[idx]
with open(xml_path,encoding="UTF-8") as fid:
xml_str=fid.read()
xml=etree.fromstring(xml_str)
data=self.parse_xml_to_dict(xml)["annotation"]
data_height=int(data["size"]["height"])
data_width=int(data["size"]["width"])
return data_height,data_width
def parse_xml_to_dict(self, xml):
#将xml文件解析成字典形式,递归
if len(xml)==0:
return {xml.tag:xml.text}
result={}
for child in xml:
child_result=self.parse_xml_to_dict(child)
if child.tag!="object":
result[child.tag]=child_result[child.tag]
else:
if child.tag not in result:
result[child.tag]=[]
result[child.tag].append(child_result[child.tag])
return {xml.tag:result}
首先定义一个继承自nn.Module的类,里面的数据集类型是pascolvoc的文件格式,按照如下格式划分
图片用labelimg打标签,做好准备工作后就可以在自己的类中定义路径。
用以下代码划分train.txt和val.txt
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2022/5/15 10:28
# @Author : 半岛铁盒
# @File : spiltdataset.py
# @Software: win10 python3.6
import os
import random
#划分训练集和验证集
#定义自己的标件文件路径
file_path="VOC/VOCdevkit/VOC2007/Annotations"
# os.listdir(path) 返回指定路径下所有文件和文件夹的名字,并存放于一个列表中
# print(os.listdir(file_path))
#用.分割,取出第0个元素也就是标签名称
file_name=sorted([file.split('.')[0] for file in os.listdir(file_path)])
# print(file_name)
file_nums=len(file_name)
# print(file_nums)
# andom.sample,多用于截取列表的指定长度的随机数,但不改变列表本身的顺序
vate=0.5
#从文件个数中抽取总数量的一半,作为验证集的索引
val_index=random.sample(range(0, file_nums),k=int(file_nums*vate))
print(val_index)
#定义两个空的训练集和验证集列表
train_files=[]
val_files=[]
# enumerate将其组成一个索引序列,利用它可以同时获得索引和值
for index,file_name in enumerate(file_name):
#如果索引值在验证集索引里,就把文件加入到验证集列表,否则加入训练集列表
if index in val_index:
val_files.append(file_name)
else:
train_files.append(file_name)
#"x"写模式,新建一个文件
#把train_files和val_files放入新建的txt文件内
train_f=open("train.txt","x")
val_f=open("val.txt","x")
train_f.write("\n".join(train_files))
val_f.write("\n".join(val_files))
本继承自Module的数据集类需要实现两个方法,一是__len__返回数据样本数,二是__getitem__,通过索引返回图片信息和一些列打包好的target信息,其中包括bndbox,也就是xml文件中的xmin,xmax,ymin,ymax;lables,标签对应信息;image_id,图片的索引信息;area,图像的面积(ymax-ymin)*(xmax-xmin);iscrowd,xml文件中difficult代表目标是否重叠,可以简单理解为目标难不难检测,iscrowd为0则不重叠。
其次,根据自己的识别类别更改json文件,以下是我要识别的四类
{ "sheet": 1, "rod": 2, "gelatinous": 3, "Wire": 4 }
通过parse_xml_to_dict这个方法递归xml文件信息读取xml文件中的图片和目标信息。
定义完自己的数据集后就可以使用自己的数据集了。
作者还附了一个画框的代码
import transforms
import torchvision.transforms as ts
from draw_box_utils import draw_objs,draw_masks,draw_text
import numpy
import random
category_index={}
json_file=open("pascal_voc_classes.json","r")
class_dict=json.load(json_file)
category_index={str(values):key for key,values in class_dict.items()}
# print(category_index)
data_transforms={
"train": transforms.Compose([transforms.ToTensor(),
transforms.RandomHorizontalFlip(0.5)]),
"val": transforms.Compose([transforms.ToTensor()])
}
train_data_set=VOCDataSet(voc_root="./VOC",transforms=data_transforms["train"],train_set=True)
# print(train_data_set.__len__())
# print(train_data_set.__getitem__(150))
x=len(train_data_set)
# print(random.sample(range(0,x),k=5))
# #随机取5张图
for index in random.sample(range(0,x),k=5):
# print(train_data_set[index])
img,target=train_data_set[index]
# print(img,target)
trans=ts.ToPILImage()
img =trans(img)
draw_objs(img,
target["boxes"].numpy(),
target["labels"].numpy(),
#分数都为1
scores=numpy.array([1 for i in range(len(target["labels"].numpy()))]),
category_index=category_index,
box_thresh=0.5,
line_thickness=5)
plt.imshow(img)
plt.show()
drawbox部分的代码在up主视频下有。