mmdetection3d自定义数据集指南:标注格式与训练流程

mmdetection3d自定义数据集指南:标注格式与训练流程

【免费下载链接】mmdetection3d OpenMMLab's next-generation platform for general 3D object detection. 【免费下载链接】mmdetection3d 项目地址: https://gitcode.com/gh_mirrors/mm/mmdetection3d

1. 引言:3D检测中的数据挑战

在自动驾驶(Autonomous Driving)和机器人视觉(Robotics Vision)领域,3D目标检测(3D Object Detection)的性能高度依赖标注数据的质量和数量。然而,公开数据集如KITTI、NuScenes往往存在场景单一、标注成本高、覆盖范围有限等问题。根据OpenMMLab社区统计,超过68%的工业用户需要适配自定义场景数据(如室内仓储、矿区车辆、特殊机械臂等)。

本文将系统讲解如何基于mmdetection3d构建自定义数据集,解决三大核心痛点:

  • 标注格式设计与转换
  • 数据集类(Dataset Class)开发
  • 训练配置与性能验证

2. 3D标注格式设计规范

2.1 核心数据结构

3D检测标注需包含空间位置、尺寸、姿态等关键信息。推荐采用JSON行格式(每行一个样本),便于流式读取:

{
  "sample_id": "scene_0001_frame_0042",
  "timestamp": 1623456789.123,
  "cam_intrinsic": [  // 相机内参矩阵 (3x3)
    [1200.0, 0.0, 640.0],
    [0.0, 1200.0, 360.0],
    [0.0, 0.0, 1.0]
  ],
  "annos": [
    {
      "bbox_3d": [1.5, -0.8, 10.2, 4.7, 1.8, 1.5, -1.57],  // x,y,z,l,w,h,yaw
      "label": "truck",
      "difficulty": 1,
      "num_lidar_pts": 235,  // 点云覆盖数量(用于过滤低质量标注)
      "bbox_2d": [320, 240, 480, 360]  // 可选2D辅助标注
    }
  ],
  "calib": {  // 传感器外参(如激光雷达到相机)
    "lidar2cam": [
      [0.998, 0.012, -0.050, 0.3],
      [-0.010, 0.999, 0.020, 0.1],
      [0.050, -0.018, 0.998, 0.2],
      [0, 0, 0, 1]
    ]
  }
}

2.2 坐标系统约定

必须严格遵循右手坐标系(Right-Hand Coordinate System):

  • x轴:水平向右(相机视角)
  • y轴:垂直向下(相机视角)
  • z轴:沿光轴向前(相机视角)
  • 旋转角:绕y轴逆时针为正(yaw角)

⚠️ 常见错误:激光雷达(LiDAR)坐标系与相机坐标系混淆。建议在标注文件中明确声明传感器类型。

2.3 数据目录结构

推荐采用KITTI风格目录,便于复用现有数据加载逻辑:

custom_dataset/
├── training/
│   ├── image_2/          # 相机图像 (1280x720 PNG)
│   ├── velodyne/         # 激光雷达点云 (二进制格式)
│   ├── label_2/          # 3D标注文件 (JSON行格式)
│   └── calib/            # 校准参数 (JSON格式)
└── ImageSets/
    ├── train.txt         # 训练集样本ID列表
    └── val.txt           # 验证集样本ID列表

3. 数据集类开发指南

3.1 基类继承关系

mmdetection3d的数据集系统基于模块化设计,自定义数据集需继承BaseDataset并实现核心方法:

mermaid

3.2 核心方法实现

3.2.1 数据加载(load_data_list)
from mmdet3d.datasets import BaseDataset

class Custom3DDataset(BaseDataset):
    METAINFO = {
        'classes': ('truck', 'forklift', 'pallet'),  # 自定义类别
        'palette': [(255, 0, 0), (0, 255, 0), (0, 0, 255)]
    }

    def load_data_list(self):
        """加载标注文件并生成数据列表"""
        data_list = []
        with open(self.ann_file, 'r') as f:
            for line in f:
                data = json.loads(line)
                data_list.append({
                    'sample_id': data['sample_id'],
                    'lidar_path': os.path.join(
                        self.data_root, 'training/velodyne', 
                        f"{data['sample_id']}.pcd"
                    ),
                    'cam_path': os.path.join(
                        self.data_root, 'training/image_2', 
                        f"{data['sample_id']}.png"
                    ),
                    'ann_info': data['annos'],
                    'calib': data['calib']
                })
        return data_list
3.2.2 标注解析(parse_ann_info)
    def parse_ann_info(self, info: dict) -> dict:
        """解析标注信息为模型输入格式"""
        ann_info = super().parse_ann_info(info)
        gt_bboxes_3d = []
        gt_labels_3d = []
        
        for anno in info['ann_info']:
            # 将标注坐标转换为mmdet3d标准格式 (x,y,z,l,w,h,yaw)
            x, y, z, l, w, h, yaw = anno['bbox_3d']
            gt_bboxes_3d.append([x, y, z, l, w, h, yaw])
            
            # 类别映射(字符串转索引)
            cls_id = self.METAINFO['classes'].index(anno['label'])
            gt_labels_3d.append(cls_id)
        
        ann_info['gt_bboxes_3d'] = np.array(gt_bboxes_3d, dtype=np.float32)
        ann_info['gt_labels_3d'] = np.array(gt_labels_3d, dtype=np.int32)
        return ann_info

3.3 数据增强适配

为保证自定义数据集与内置数据增强兼容,需确保输出符合Det3DDataSample格式:

    def prepare_data(self, idx: int) -> dict:
        """准备模型输入数据"""
        data_dict = super().prepare_data(idx)
        
        # 添加点云强度特征(若使用)
        if 'intensity' in data_dict['points'].keys():
            data_dict['points'].tensor = np.hstack([
                data_dict['points'].tensor,
                data_dict['points'].intensity[:, None]
            ])
            
        return data_dict

4. 标注格式转换工具

4.1 从COCO格式转换

若已有2D标注,可使用以下脚本升级到3D格式:

# tools/dataset_converters/coco_to_3d.py
import json

def convert_coco_to_3d(coco_json, output_json):
    with open(coco_json, 'r') as f:
        coco_data = json.load(f)
    
    for img in coco_data['images']:
        # 假设3D信息存储在image的extra_fields中
        sample_id = img['id']
        calib = img['calib']
        # ... 转换逻辑 ...

if __name__ == '__main__':
    convert_coco_to_3d(
        'coco_annotations.json', 
        'custom_3d_annotations.jsonl'
    )

4.2 点云可视化验证

使用open3d工具验证标注质量:

import open3d as o3d
import numpy as np

def visualize_annotation(pcd_path, bbox_3d):
    pcd = o3d.io.read_point_cloud(pcd_path)
    bbox = o3d.geometry.OrientedBoundingBox(
        center=bbox_3d[:3],
        extent=bbox_3d[3:6],
        R=o3d.geometry.get_rotation_matrix_from_yxz([0, bbox_3d[6], 0])
    )
    bbox.color = [1, 0, 0]  # 红色边界框
    o3d.visualization.draw_geometries([pcd, bbox])

# 使用示例
visualize_annotation('sample.pcd', [1.5, -0.8, 10.2, 4.7, 1.8, 1.5, -1.57])

5. 训练配置与性能调优

5.1 配置文件编写

创建configs/custom/custom_pointpillars.py

_base_ = [
    '../_base_/datasets/custom-3d.py',
    '../_base_/models/pointpillars_hv_secfpn.py',
    '../_base_/schedules/cyclic-40e.py',
    '../_base_/default_runtime.py'
]

# 数据集配置
dataset_type = 'Custom3DDataset'
data_root = '/data/custom_dataset/'
class_names = ('truck', 'forklift', 'pallet')

train_dataloader = dict(
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='training/label_2/train.jsonl',
        pipeline=train_pipeline
    )
)

# 模型调整(适配小样本)
model = dict(
    bbox_head=dict(
        num_classes=3,
        anchor_generator=dict(
            ranges=[[0, -40, -0.6, 70.4, 40, -0.6]]  # 调整Anchor范围
        )
    )
)

# 学习率调整(小数据集)
param_scheduler = dict(
    type='CosineAnnealingLR',
    T_max=20,  # 减少周期
    eta_min=1e-5
)

5.2 关键超参数调优

针对自定义数据集的常见问题,调整以下参数:

问题场景推荐配置原理
样本数量<1kload_from='pointpillars_pretrain.pth'迁移学习初始化
类别不平衡sampler=dict(type='ClassBalancedSampler')类别均衡采样
标注噪声大bbox_coder=dict(encode_size=True)尺寸编码增强鲁棒性

5.3 评估指标选择

根据场景特性选择评估指标:

val_evaluator = dict(
    type='Custom3DEvaluator',
    metric=['mAP', 'NDS'],  # 平均精度与归一化检测分数
    iou_thr=[0.5, 0.25, 0.75]  # 多阈值评估
)

6. 完整训练流程

6.1 数据准备命令

# 1. 克隆仓库
git clone https://gitcode.com/gh_mirrors/mm/mmdetection3d.git
cd mmdetection3d

# 2. 安装依赖
pip install -r requirements/runtime.txt
pip install -v -e .

# 3. 生成数据集索引
python tools/create_data.py custom --root-path /data/custom_dataset \
    --out-dir /data/custom_dataset --extra-tag custom

6.2 启动训练

# 单卡训练(调试)
python tools/train.py configs/custom/custom_pointpillars.py

# 多卡训练(生产)
bash tools/dist_train.sh configs/custom/custom_pointpillars.py 8

6.3 可视化与调试

使用TensorBoard监控训练过程:

tensorboard --logdir work_dirs/custom_pointpillars/

关键监控指标:

  • 3D mAP@0.5(核心检测指标)
  • 点云召回率(Point Cloud Recall)
  • 学习率曲线(检查是否过拟合)

7. 常见问题解决方案

7.1 数据加载错误

症状FileNotFoundError: velodyne file not found
解决:检查load_data_list中的路径拼接逻辑,推荐使用pathlib

from pathlib import Path
lidar_path = Path(self.data_root) / 'training/velodyne' / f"{sample_id}.pcd"

7.2 模型不收敛

症状:loss停留在10+且mAP=0
解决:检查:

  1. 标注坐标是否符合右手坐标系
  2. gt_bboxes_3d格式是否为(x,y,z,l,w,h,yaw)
  3. 学习率是否适配 batch_size(推荐lr=0.001 * batch_size/16

7.3 点云与图像不同步

症状:检测框漂浮或偏移
解决:使用校准矩阵可视化验证:

# 验证点云投影到图像
def project_lidar_to_cam(points, calib):
    R = np.array(calib['lidar2cam'][:3,:3])
    T = np.array(calib['lidar2cam'][:3,3])
    return points @ R.T + T

8. 总结与进阶方向

本文详细介绍了mmdetection3d自定义数据集的全流程,包括标注格式设计、数据集类开发、训练配置与调试技巧。关键要点:

  1. 格式兼容性:优先采用JSON行格式,便于扩展与复用
  2. 模块化设计:通过继承BaseDataset减少重复代码
  3. 小样本策略:迁移学习+数据增强组合提升性能

进阶方向:

  • 半监督标注(Semi-supervised Annotation):结合mmssl
  • 动态点云(Dynamic Point Cloud):适配4D雷达数据
  • 多模态融合(Multi-modal Fusion):添加红外/毫米波雷达数据

通过灵活运用mmdetection3d的数据接口,开发者可快速将3D检测技术部署到自定义场景,加速从算法研究到工业落地的转化过程。

【免费下载链接】mmdetection3d OpenMMLab's next-generation platform for general 3D object detection. 【免费下载链接】mmdetection3d 项目地址: https://gitcode.com/gh_mirrors/mm/mmdetection3d

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值