【mmdetection】使用kitti数据集进行训练

本文详细介绍了如何在MMDetection框架中配置和使用Kitti数据集,包括环境配置、数据集准备、仓库文件修改、训练流程、结果可视化以及相关参考。具体步骤涉及Kitti数据集的目录结构、KittiDataset类的定义、配置文件的更新、评估指标的修改等,同时提供了训练命令和可视化代码。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

一、环境配置

二、Kitti数据集准备

三、仓库中需要修改的文件

3.1 mmdet/datasets中添加kitti.py,内容如下

3.2 修改mmdet/datasets/__init__.py,修改位置已注释标出

3.3 configs/_base_/datasets中添加kitti_detection.py,内容如下

3.4 修改mmdet\core\evaluation文件夹中的class_names.py文件

3.5 修改mmdet\core\evaluation/__init__.py文件,修改位置已注释标出

3.6 修改configs文件夹中需要使用的配置文件

四、训练

五、可视化

5.1 修改mmdet\apis文件夹中的inference.py文件

5.2 使用visualization.py(见下)可视化。(注:需要将visualization.py放在到mmdetection目录下)

5.3 或者使用DetVisGUI可视化

六、参考

七、附录

7.1 kitti标签类别合并

7.2 kitti转voc(可能有问题,不建议使用)


写在前面:官方给了一个demo程序将Kitti转为COCO格式,但是加载数据、修改配置、训练、测试、可视化这些东西都都放在一起总觉得不舒服,用那个比较好,于是把官方的示例改成了一个新的数据集kitti。

一、环境配置

mmdet 2.7.0

mmcv 1.2.1

克隆仓库中的源码,并在目录下创建data文件夹。按照get_started.md文件进行配置,不再赘述。

二、Kitti数据集准备

按照以下文件夹结构准备数据。

mmdetection
├── mmdet
├── tools
├── configs
├── data
│   ├── kitti
│   │   ├── training
│   │   │   ├── image_2
│   │   │   ├── label_2
│   │   ├── train.txt
│   │   ├── val.txt
│   │   ├── trainval.txt

三、仓库中需要修改的文件

3.1 mmdet/datasets中添加kitti.py,内容如下

import os.path as osp

import mmcv
import numpy as np

from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset


@DATASETS.register_module()
class KittiDataset(CustomDataset):
    CLASSES = ('Car', 'Pedestrian', 'Cyclist')

    def load_annotations(self, ann_file):
        cat2label = {k: i for i, k in enumerate(self.CLASSES)}
        # load image list from file
        image_list = mmcv.list_from_file(self.ann_file)

        data_infos = []
        # convert annotations to middle format
        for image_id in image_list:
            filename = f'{self.img_prefix}/{image_id}.png'
            image = mmcv.imread(filename)
            height, width = image.shape[:2]

            data_info = dict(filename=f'{image_id}.png', width=width, height=height)

            # load annotations
            label_prefix = self.img_prefix.replace('image_2', 'label_2')
            lines = mmcv.list_from_file(osp.join(label_prefix, f'{image_id}.txt'))

            content = [line.strip().split(' ') for line in lines]
            bbox_names = [x[0] for x in content]
            bboxes = [[float(info) for info in x[4:8]] for x in content]

            gt_bboxes = []
            gt_labels = []
            gt_bboxes_ignore = []
            gt_labels_ignore = []

            # filter 'DontCare'
            for bbox_name, bbox in zip(bbox_names, bboxes):
                if bbox_name in cat2label:
                    gt_labels.append(cat2label[bbox_name])
                    gt_bboxes.append(bbox)
                else:
                    gt_labels_ignore.append(-1)
                    gt_bboxes_ignore.append(bbox)

            data_anno = dict(
                bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
                labels=np.array(gt_labels, dtype=np.long),
                bboxes_ignore=np.array(gt_bboxes_ignore,
                                       dtype=np.float32).reshape(-1, 4),
                labels_ignore=np.array(gt_labels_ignore, dtype=np.long))

            data_info.update(ann=data_anno)
            data_infos.append(data_info)

        return data_infos

3.2 修改mmdet/datasets/__init__.py,修改位置已注释标出

from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .cityscapes import CityscapesDataset
from .coco import CocoDataset
from .custom import CustomDataset
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
                               RepeatDataset)
from .deepfashion import DeepFashionDataset
from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset
from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler
from .utils import replace_ImageToTensor
from .voc import VOCDataset
from .wider_face import WIDERFaceDataset
from .xml_style import XMLDataset
from .kitti import KittiDataset            #新加

__all__ = [
    #下面的KittiDataset为新加
    'KittiDataset','CustomDataset', 'XMLDataset', 'CocoDataset', 'DeepFashionDataset',
    'VOCDataset', 'CityscapesDataset', 'LVISDataset', 'LVISV05Dataset',
    'LVISV1Dataset', 'GroupSampler', 'DistributedGroupSampler',
    'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
    'ClassBalancedDataset', 'WIDERFaceDataset', 'DATASETS', 'PIPELINES',
    'build_dataset', 'replace_ImageToTensor'
]

3.3 configs/_base_/datasets中添加kitti_detection.py,内容如下

# dataset settings
dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值