pytorch创健自己的数据集

本文详细介绍了如何自定义VOC数据集,包括XML文件解析、图片和目标信息提取,以及如何使用Faster R-CNN模型进行对象检测。通过实例代码展示了如何划分训练集和验证集,以及数据预处理和模型应用的过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文的代码是参考b站up主霹雳吧啦Wz1.2Faster RCNN源码解析(pytorch)_哔哩哔哩_bilibili

最终目的是复现fasterrcnn网络

先附上整体代码

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2022/5/15 10:28
# @Author  : 半岛铁盒
# @File    : mydataset.py
# @Software: win10  python3.6
#定义一个自己的数据集
import random

from torch.utils.data import Dataset
import os
import json
from PIL import Image
from lxml import etree
import torch
import matplotlib.pyplot as plt

#继承自torch.utils.data.Dataset
class VOCDataSet(Dataset):
    def __init__(self,voc_root,transforms,train_set=True):
        self.root=os.path.join(voc_root,"VOCdevkit","VOC2007")
        self.img_root=os.path.join(self.root,"JPEGImages")
        self.Annotations_root=os.path.join(self.root,"Annotations")
        #如果是训练集则训练集,否则验证集
        if train_set==True:
            txt_list=os.path.join(self.root,"ImageSets","Main","train.txt")
        else:
            txt_list=os.path.join(self.root,"ImageSets","Main","val.txt")

        #readlines()一次性先将整个文件内容按行读取完
        #strip() 方法用于移除字符串头尾指定的字符(默认为空格或换行符)或字符序列,此处用来去除换行符
        #xml_list代表存放训练集或验证集的txt文件
        with open(txt_list) as read:
            self.xml_list=[os.path.join(self.Annotations_root,line.strip()+".xml") for line in read.readlines()]

        json_file=open("./pascal_voc_classes.json","r")
        self.class_dict=json.load(json_file)
        self.transforms=transforms

        # print(self.xml_list[154])

    def __len__(self):
        return len(self.xml_list)

    def __getitem__(self, idx):
        #返回索引对应的xml文件
        xml_path=self.xml_list[idx]
        with open(xml_path,encoding="UTF-8") as fid:
            #读取xml文件
            xml_str=fid.read()
        #etree.fromstring()(将字符串转化为Element对象) 用etree来读取xml文件
        xml=etree.fromstring(xml_str)
        #parse_xml_to_dict返回标签对应的xml信息
        data=self.parse_xml_to_dict(xml)["annotation"]
        #如果filename不是jpg后缀就把它改成jpg后缀
        file_name=data["filename"].split(".")[0]+".jpg"
        ima_path=os.path.join(self.img_root,file_name)
        image=Image.open(ima_path)

        boxes=[]
        labels=[]
        iscrowd=[]
        #将标签信息中objec里的bndbox内的xy坐标信息读入
        for obj in data["object"]:
            xim=float(obj["bndbox"]["xmin"])
            xmax=float(obj["bndbox"]["xmax"])
            ymin=float(obj["bndbox"]["ymin"])
            ymax=float(obj["bndbox"]["ymax"])
            #将xy坐标信息存入boxes,一个标签信息中可能有多个objext所以boxes也有多个
            boxes.append([xim,ymin,xmax,ymax])

            #将该标签中每一个object的name对应json文件中的key值 sheet 1
            labels.append(self.class_dict[obj["name"]])
            #obj里的difficult代表目标是否重叠,可以简单理解为目标难不难检测
            iscrowd.append(int(obj["difficult"]))
        boxes=torch.as_tensor(boxes,dtype=torch.float32)
        labels=torch.as_tensor(labels,dtype=torch.int64)
        iscrowd=torch.as_tensor(iscrowd,dtype=torch.int64)
        image_id=torch.tensor([idx])
        #boxes[:,3]第三个也就是ymax,计算一下目标面积
        area=(boxes[:,3]-boxes[:,1])*(boxes[:,2]-boxes[:,0])
        #将上述信息打包到target字典里
        target={}
        target["boxes"]=boxes
        target["labels"]=labels
        target["image_id"]=image_id
        target["area"]=area
        target["iscrowd"]=iscrowd

        if self.transforms is not None:
            image,target=self.transforms(image,target)

        return image,target

    def get_height_and_width(self,idx):
        xml_path=self.xml_list[idx]
        with open(xml_path,encoding="UTF-8") as fid:
            xml_str=fid.read()
        xml=etree.fromstring(xml_str)
        data=self.parse_xml_to_dict(xml)["annotation"]
        data_height=int(data["size"]["height"])
        data_width=int(data["size"]["width"])
        return data_height,data_width


    def parse_xml_to_dict(self, xml):
        #将xml文件解析成字典形式,递归
        if len(xml)==0:
            return {xml.tag:xml.text}
        result={}
        for child in xml:
            child_result=self.parse_xml_to_dict(child)
            if child.tag!="object":
                result[child.tag]=child_result[child.tag]
            else:
                if child.tag not in result:
                    result[child.tag]=[]
                result[child.tag].append(child_result[child.tag])
        return {xml.tag:result}

首先定义一个继承自nn.Module的类,里面的数据集类型是pascolvoc的文件格式,按照如下格式划分

 图片用labelimg打标签,做好准备工作后就可以在自己的类中定义路径。

用以下代码划分train.txt和val.txt

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2022/5/15 10:28
# @Author  : 半岛铁盒
# @File    : spiltdataset.py
# @Software: win10  python3.6
import os
import random
#划分训练集和验证集

#定义自己的标件文件路径
file_path="VOC/VOCdevkit/VOC2007/Annotations"
# os.listdir(path) 返回指定路径下所有文件和文件夹的名字,并存放于一个列表中
# print(os.listdir(file_path))
#用.分割,取出第0个元素也就是标签名称
file_name=sorted([file.split('.')[0] for file in os.listdir(file_path)])
# print(file_name)
file_nums=len(file_name)
# print(file_nums)

# andom.sample,多用于截取列表的指定长度的随机数,但不改变列表本身的顺序
vate=0.5

#从文件个数中抽取总数量的一半,作为验证集的索引
val_index=random.sample(range(0, file_nums),k=int(file_nums*vate))
print(val_index)


#定义两个空的训练集和验证集列表
train_files=[]
val_files=[]

# enumerate将其组成一个索引序列,利用它可以同时获得索引和值
for index,file_name in enumerate(file_name):
    #如果索引值在验证集索引里,就把文件加入到验证集列表,否则加入训练集列表
    if index in val_index:
        val_files.append(file_name)
    else:
        train_files.append(file_name)

#"x"写模式,新建一个文件
#把train_files和val_files放入新建的txt文件内
train_f=open("train.txt","x")
val_f=open("val.txt","x")
train_f.write("\n".join(train_files))
val_f.write("\n".join(val_files))

本继承自Module的数据集类需要实现两个方法,一是__len__返回数据样本数,二是__getitem__,通过索引返回图片信息和一些列打包好的target信息,其中包括bndbox,也就是xml文件中的xmin,xmax,ymin,ymax;lables,标签对应信息;image_id,图片的索引信息;area,图像的面积(ymax-ymin)*(xmax-xmin);iscrowd,xml文件中difficult代表目标是否重叠,可以简单理解为目标难不难检测,iscrowd为0则不重叠。

其次,根据自己的识别类别更改json文件,以下是我要识别的四类

{
    "sheet": 1,
    "rod": 2,
    "gelatinous": 3,
    "Wire": 4
}

通过parse_xml_to_dict这个方法递归xml文件信息读取xml文件中的图片和目标信息。

定义完自己的数据集后就可以使用自己的数据集了。

作者还附了一个画框的代码

import transforms
import torchvision.transforms as ts
from draw_box_utils import draw_objs,draw_masks,draw_text
import numpy
import random
category_index={}
json_file=open("pascal_voc_classes.json","r")
class_dict=json.load(json_file)
category_index={str(values):key for key,values in class_dict.items()}
# print(category_index)

data_transforms={
    "train": transforms.Compose([transforms.ToTensor(),
                                transforms.RandomHorizontalFlip(0.5)]),
    "val": transforms.Compose([transforms.ToTensor()])
}


train_data_set=VOCDataSet(voc_root="./VOC",transforms=data_transforms["train"],train_set=True)

# print(train_data_set.__len__())
# print(train_data_set.__getitem__(150))
x=len(train_data_set)

# print(random.sample(range(0,x),k=5))

# #随机取5张图
for index in random.sample(range(0,x),k=5):
    # print(train_data_set[index])
    img,target=train_data_set[index]
    # print(img,target)
    trans=ts.ToPILImage()
    img =trans(img)
    draw_objs(img,
              target["boxes"].numpy(),
              target["labels"].numpy(),
              #分数都为1
              scores=numpy.array([1 for i in range(len(target["labels"].numpy()))]),
              category_index=category_index,
              box_thresh=0.5,
              line_thickness=5)


    plt.imshow(img)
    plt.show()

drawbox部分的代码在up主视频下有。

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值