maskrcnn benchmark自定义数据集的方法

本文介绍如何为 Mask R-CNN 框架自定义数据集,包括数据集的放置位置、配置文件的修改、自定义数据加载类的编写及 YAML 配置文件的调整等关键步骤。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

本文系转载,作者:风与树影

maskrcnn benchmark自定义数据集的方法


1、拷贝数据集到根目录的datasets下(和demo同级目录)如

maskrcnn-benchmark/datasets/jinnan/jinnan2_round1_train_20190305

2、修改paths_catalog.py

路径为maskrcnn-benchmark/maskrcnn_benchmark/config/paths_catalog.py

a、在paths_catalog中的DATASETS字典中添加你需要的路径,如

"jinnan_train": {
"img_dir": "jinnan2_round1_train_20190305",
"ann_file": "jinnan2_round1_train_20190305/train_no_poly.json"
},

注意:自定义数据集的话,img_dirann_file会作为形参传到你自己创建的MyDataset类里面

b、修改paths_catalog中部静态函数get(name)方法

添加一个if else,把你创建的数据集相关内容放进去,如

elif "jinnan" in name:  # name对应yaml文件传过来的数据集名字
    data_dir = DatasetCatalog.DATA_DIR
    attrs = DatasetCatalog.DATASETS[name]
    args = dict(
        root=os.path.join(data_dir, attrs["img_dir"]),  # img_dir就是a步骤里面的内容
        ann_file=os.path.join(data_dir, attrs["ann_file"]),  # ann_file就是a步骤里面的内容
    )
    return dict(
        factory="MyDataset",  # 这个MyDataset对应
        args=args,
    )

上面参数解释(主要是MyDataset):

  1. 这个MyDataset就是你自己建的那个类,返回值是image, boxlist, idx,具体实现参考git官网(很容易)

  2. 比如我实现好了MyDataset类,然后这个py文件取名为jinnan.py

  3. 然后放在maskrcnn-benchmark/maskrcnn_benchmark/data/datasets路径下

  4. 接着配置那个目录里面的__init__.py文件,第四行和all最后一个元素是自己加的

from .coco import COCODataset
from .voc import PascalVOCDataset
from .concat_dataset import ConcatDataset
from .jinnan import MyDataset

all = ["COCODataset", "ConcatDataset", "PascalVOCDataset", "MyDataset"]
  1. 注意,实现MyDataset要实现__len____getitem__get_img_info,还有__init__,其中__init__会得到第一个步骤传来的attrs__init__的一个参数参考:
def __init__(self,ann_file=None, root=None, remove_images_without_annotations=None, transforms=None)

不知参数是什么意思得去看maskrcnn-benchmark/maskrcnn_benchmark/data/build.py

3、修改yaml文件

主要是修改数据load部分

MODEL:
  MASK_ON: False
DATASETS:
  TRAIN: ("jinnan_train", "jinnan_val")
  TEST: ("jinnan_test",)

上面三个值都是自己设的,其实有用的就jinnan_train,当然首先重要的是要把MASK_ON关闭。

4、 我自己写的数据加载的凌乱的参考

maskrcnn-benchmark/maskrcnn_benchmark/data/datasets/jinnan.py

from maskrcnn_benchmark.structures.bounding_box import BoxList
from PIL import Image
import os
import json
import torch

class MyDataset(object):
    def __init__(self,ann_file=None, root=None, remove_images_without_annotations=None, transforms=None):
        # as you would do normally

        self.transforms = transforms

        self.train_path = root
        with open(ann_file, 'r') as f:
            self.data = json.load(f)

        self.idxs = list(range(len(self.data['images'])))  # 看要训练的图像有多少张,把id用个列表存储方便随机
        self.bbox_label = {}
        for anno in self.data['annotations']:
            bbox = anno['bbox']
            bbox[2] += bbox[0]
            bbox[3] += bbox[1]
            cate = anno['category_id']
            image_id = anno['image_id']
            if not image_id in self.bbox_label:
                self.bbox_label[image_id] = [[bbox], [cate]]
            else:
                self.bbox_label[image_id][0].append(bbox)
                self.bbox_label[image_id][1].append(cate)

    def __getitem__(self, idx):
        # load the image as a PIL Image
        idx = self.idxs[idx % len(self.data['images'])]
        if idx not in self.bbox_label:  # 210, 262, 690, 855 have no bbox
            idx += 1
        path = self.data['images'][idx]['file_name']

        folder = 'restricted' if idx < 981 else 'normal'

        image = Image.open(os.path.join(self.train_path, folder, path)).convert('RGB')
        # load the bounding boxes as a list of list of boxes
        # in this case, for illustrative purposes, we use
        # x1, y1, x2, y2 order.
        # boxes = [[0, 0, 10, 10], [10, 20, 50, 50]]
        boxes = self.bbox_label[idx][0]
        category = self.bbox_label[idx][-1]

        # and labels
        labels = torch.tensor(category)

        # create a BoxList from the boxes
        boxlist = BoxList(boxes, image.size, mode="xyxy")
        # add the labels to the boxlist
        boxlist.add_field("labels", labels)

        if self.transforms:
            image, boxlist = self.transforms(image, boxlist)

        # return the image, the boxlist and the idx in your dataset
        return image, boxlist, idx
    def __len__(self):
        return len(self.data['images'])

    def get_img_info(self, idx):
        idx = self.idxs[idx % len(self.data['images'])]
        height = self.data['images'][idx]['height']
        width = self.data['images'][idx]['width']
        # get img_height and img_width. This is used if
        # we want to split the batches according to the aspect ratio
        # of the image, as it can be more efficient than loading the
        # image from disk
        return {"height": height, "width": width}
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值