【mmdetection】使用kitti数据集进行训练

本文详细介绍了如何在MMDetection框架中配置和使用Kitti数据集,包括环境配置、数据集准备、仓库文件修改、训练流程、结果可视化以及相关参考。具体步骤涉及Kitti数据集的目录结构、KittiDataset类的定义、配置文件的更新、评估指标的修改等,同时提供了训练命令和可视化代码。

目录

一、环境配置

二、Kitti数据集准备

三、仓库中需要修改的文件

3.1 mmdet/datasets中添加kitti.py,内容如下

3.2 修改mmdet/datasets/__init__.py,修改位置已注释标出

3.3 configs/_base_/datasets中添加kitti_detection.py,内容如下

3.4 修改mmdet\core\evaluation文件夹中的class_names.py文件

3.5 修改mmdet\core\evaluation/__init__.py文件,修改位置已注释标出

3.6 修改configs文件夹中需要使用的配置文件

四、训练

五、可视化

5.1 修改mmdet\apis文件夹中的inference.py文件

5.2 使用visualization.py(见下)可视化。(注:需要将visualization.py放在到mmdetection目录下)

5.3 或者使用DetVisGUI可视化

六、参考

七、附录

7.1 kitti标签类别合并

7.2 kitti转voc(可能有问题,不建议使用)


写在前面:官方给了一个demo程序将Kitti转为COCO格式,但是加载数据、修改配置、训练、测试、可视化这些东西都都放在一起总觉得不舒服,用那个比较好,于是把官方的示例改成了一个新的数据集kitti。

一、环境配置

mmdet 2.7.0

mmcv 1.2.1

克隆仓库中的源码,并在目录下创建data文件夹。按照get_started.md文件进行配置,不再赘述。

二、Kitti数据集准备

按照以下文件夹结构准备数据。

mmdetection
├── mmdet
├── tools
├── configs
├── data
│   ├── kitti
│   │   ├── training
│   │   │   ├── image_2
│   │   │   ├── label_2
│   │   ├── train.txt
│   │   ├── val.txt
│   │   ├── trainval.txt

三、仓库中需要修改的文件

3.1 mmdet/datasets中添加kitti.py,内容如下

import os.path as osp

import mmcv
import numpy as np

from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset


@DATASETS.register_module()
class KittiDataset(CustomDataset):
    CLASSES = ('Car', 'Pedestrian', 'Cyclist')

    def load_annotations(self, ann_file):
        cat2label = {k: i for i, k in enumerate(self.CLASSES)}
        # load image list from file
        image_list = mmcv.list_from_file(self.ann_file)

        data_infos = []
        # convert annotations to middle format
        for image_id in image_list:
            filename = f'{self.img_prefix}/{image_id}.png'
            image = mmcv.imread(filename)
            height, width = image.shape[:2]

            data_info = dict(filename=f'{image_id}.png', width=width, height=height)

            # load annotations
            label_prefix = self.img_prefix.replace('image_2', 'label_2')
            lines = mmcv.list_from_file(osp.join(label_prefix, f'{image_id}.txt'))

            content = [line.strip().split(' ') for line in lines]
            bbox_names = [x[0] for x in content]
            bboxes = [[float(info) for info in x[4:8]] for x in content]

            gt_bboxes = []
            gt_labels = []
            gt_bboxes_ignore = []
            gt_labels_ignore = []

            # filter 'DontCare'
            for bbox_name, bbox in zip(bbox_names, bboxes):
                if bbox_name in cat2label:
                    gt_labels.append(cat2label[bbox_name])
                    gt_bboxes.append(bbox)
                else:
                    gt_labels_ignore.append(-1)
                    gt_bboxes_ignore.append(bbox)

            data_anno = dict(
                bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
                labels=np.array(gt_labels, dtype=np.long),
                bboxes_ignore=np.array(gt_bboxes_ignore,
                                       dtype=np.float32).reshape(-1, 4),
                labels_ignore=np.array(gt_labels_ignore, dtype=np.long))

            data_info.update(ann=data_anno)
            data_infos.append(data_info)

        return data_infos

3.2 修改mmdet/datasets/__init__.py,修改位置已注释标出

from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .cityscapes import CityscapesDataset
from .coco import CocoDataset
from .custom import CustomDataset
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
                               RepeatDataset)
from .deepfashion import DeepFashionDataset
from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset
from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler
from .utils import replace_ImageToTensor
from .voc import VOCDataset
from .wider_face import WIDERFaceDataset
from .xml_style import XMLDataset
from .kitti import KittiDataset            #新加

__all__ = [
    #下面的KittiDataset为新加
    'KittiDataset','CustomDataset', 'XMLDataset', 'CocoDataset', 'DeepFashionDataset',
    'VOCDataset', 'CityscapesDataset', 'LVISDataset', 'LVISV05Dataset',
    'LVISV1Dataset', 'GroupSampler', 'DistributedGroupSampler',
    'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
    'ClassBalancedDataset', 'WIDERFaceDataset', 'DATASETS', 'PIPELINES',
    'build_dataset', 'replace_ImageToTensor'
]

3.3 configs/_base_/datasets中添加kitti_detection.py,内容如下

# dataset settings
dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
### 使用 mmdetection3d 训练 Kitti 数据集 #### 准备工作 为了使用 `mmdetection3d` 训练 KITTI 数据集,需先完成环境搭建并准备好数据。具体操作包括安装依赖项以及下载和处理KITTI数据集。 #### 配置文件设置 在配置文件中定义模型架构和其他参数对于训练至关重要。针对KITTI数据集,在`configs/_base_/datasets/`目录下创建名为`kitti_detection.py`的文件来指定数据加载方式[^4]: ```python dataset_type = 'KittiDataset' data_root = 'data/kitti/' class_names = ['Pedestrian', 'Cyclist', 'Car'] input_modality = dict(use_lidar=False, use_camera=True) train_pipeline = [ ... ] test_pipeline = [ ... ] data = dict( samples_per_gpu=2, workers_per_gpu=2, train=dict( type='RepeatDataset', times=2, dataset=dict( type=dataset_type, data_root=data_root, ann_file=data_root + 'kitti_infos_train.pkl', split='training', pts_prefix='velodyne_reduced', pipeline=train_pipeline, modality=input_modality, classes=class_names, test_mode=False)), val=dict( type=dataset_type, data_root=data_root, ann_file=data_root + 'kitti_infos_val.pkl', split='training', pts_prefix='velodyne_reduced', pipeline=test_pipeline, modality=input_modality, classes=class_names, test_mode=True), test=dict( type=dataset_type, data_root=data_root, ann_file=data_root + 'kitti_infos_test.pkl', split='testing', pts_prefix='velodyne_reduced', pipeline=test_pipeline, modality=input_modality, classes=class_names, test_mode=True)) ``` #### 开始训练 当一切准备就绪后,可以通过命令行启动训练过程。这里以PointPillars为例说明如何执行训练任务: ```bash cd /path/to/mmdetection3d/ python tools/train.py configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-3class.py --work-dir work_dirs/pointpillars_kitti ``` 此命令会调用`train.py`脚本,并传入特定于PointPillars模型及其超参设定的配置文件路径作为输入参数[^2]。 #### 分布式训练选项 如果希望利用多GPU加速训练速度,则可以考虑采用分布式模式。这可通过修改上述命令中的Python模块实现: ```bash PYTHONPATH=$PWD:$PYTHONPATH \ torchrun --nproc_per_node=NUM_GPUS_TO_USE \ tools/train.py configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-3class.py \ --launcher pytorch --work-dir work_dirs/pointpillars_kitti_distributed ``` 其中`NUM_GPUS_TO_USE`应替换为实际可用的GPU数量[^3]。
评论 7
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值