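Table of Contents
3.1 Add kitti.py to mmdet/datasets with the following content
3.2 Modify mmdet/datasets/__init__.py (changes are marked with comments)
3.3 Add kitti_detection.py to configs/_base_/datasets with the following content
3.4 Modify class_names.py in the mmdet/core/evaluation folder
3.5 Modify mmdet/core/evaluation/__init__.py (changes are marked with comments)
5.1 Modify inference.py in the mmdet/apis folder
5.2 Visualize with visualization.py (see below). (Note: visualization.py must be placed in the mmdetection directory.)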
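Preface: the official docs provide a demo that converts KITTI annotations into mmdetection's middle format, but it puts data loading, config changes, training, testing, and visualization all in one place, which I found awkward to work with. So I turned the official example into a new, standalone dataset: kitti.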
1. Environment Setup
mmdet 2.7.0
mmcv 1.2.1
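Clone the source code from the repository and create a data folder in its root directory. Set up the environment by following get_started.md; the details are not repeated here.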
2. Preparing the KITTI Dataset
Arrange the data in the following folder structure.
mmdetection
├── mmdet
├── tools
├── configs
├── data
│ ├── kitti
│ │ ├── training
│ │ │ ├── image_2
│ │ │ ├── label_2
│ │ ├── train.txt
│ │ ├── val.txt
│ │ ├── trainval.txt
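The tree assumes train.txt, val.txt and trainval.txt already exist. KittiDataset.load_annotations reads each of them as one image id per line (the image file name without .png, e.g. 000123). If you need to create them yourself, here is a minimal sketch that assumes a random split of training/image_2 with a 20% validation ratio and a fixed seed; both values are assumptions, and write_kitti_splits is a hypothetical helper, not part of mmdetection.

```python
import random
from pathlib import Path


def write_kitti_splits(kitti_root, val_ratio=0.2, seed=0):
    """Hypothetical helper: write train.txt, val.txt and trainval.txt.

    Each file lists one image id per line (the .png file stem), which is
    the format KittiDataset.load_annotations expects.
    """
    root = Path(kitti_root)
    image_dir = root / 'training' / 'image_2'
    ids = sorted(p.stem for p in image_dir.glob('*.png'))
    if not ids:
        raise FileNotFoundError(f'no .png images found in {image_dir}')

    # Shuffle a copy with a fixed seed so the split is reproducible.
    shuffled = ids[:]
    random.Random(seed).shuffle(shuffled)
    num_val = int(len(shuffled) * val_ratio)
    val_ids = sorted(shuffled[:num_val])
    train_ids = sorted(shuffled[num_val:])

    # trainval.txt is simply every image in image_2.
    splits = {'train.txt': train_ids, 'val.txt': val_ids, 'trainval.txt': ids}
    for name, split_ids in splits.items():
        text = '\n'.join(split_ids)
        (root / name).write_text(text + '\n' if text else '')
    return train_ids, val_ids
```

Run from the mmdetection root, write_kitti_splits('data/kitti') places the three files next to training/, matching the tree above. Sorting the ids within each split keeps the files stable across runs.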
3. Files to Modify in the Repository
3.1 Add kitti.py to mmdet/datasets with the following content
import os.path as osp

import mmcv
import numpy as np

from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset


@DATASETS.register_module()
class KittiDataset(CustomDataset):

    CLASSES = ('Car', 'Pedestrian', 'Cyclist')

    def load_annotations(self, ann_file):
        cat2label = {k: i for i, k in enumerate(self.CLASSES)}
        # load the image id list (one id per line) from the split file
        image_list = mmcv.list_from_file(ann_file)

        data_infos = []
        # convert annotations to the middle format
        for image_id in image_list:
            filename = osp.join(self.img_prefix, f'{image_id}.png')
            image = mmcv.imread(filename)
            height, width = image.shape[:2]

            data_info = dict(
                filename=f'{image_id}.png', width=width, height=height)

            # load annotations from label_2/<image_id>.txt
            label_prefix = self.img_prefix.replace('image_2', 'label_2')
            lines = mmcv.list_from_file(
                osp.join(label_prefix, f'{image_id}.txt'))

            content = [line.strip().split(' ') for line in lines]
            bbox_names = [x[0] for x in content]
            # columns 4-7: 2D box (left, top, right, bottom) in pixels
            bboxes = [[float(info) for info in x[4:8]] for x in content]

            gt_bboxes = []
            gt_labels = []
            gt_bboxes_ignore = []
            gt_labels_ignore = []

            # keep classes listed in CLASSES; every other type ('DontCare',
            # 'Van', 'Truck', ...) goes to the ignore set
            for bbox_name, bbox in zip(bbox_names, bboxes):
                if bbox_name in cat2label:
                    gt_labels.append(cat2label[bbox_name])
                    gt_bboxes.append(bbox)
                else:
                    gt_labels_ignore.append(-1)
                    gt_bboxes_ignore.append(bbox)

            # np.int64 instead of np.long (np.long is missing in NumPy 1.24-1.26)
            data_anno = dict(
                bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
                labels=np.array(gt_labels, dtype=np.int64),
                bboxes_ignore=np.array(gt_bboxes_ignore,
                                       dtype=np.float32).reshape(-1, 4),
                labels_ignore=np.array(gt_labels_ignore, dtype=np.int64))

            data_info.update(ann=data_anno)
            data_infos.append(data_info)

        return data_infos
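The label-parsing step inside load_annotations does not depend on mmcv or on the images, so it can be pulled out and checked on its own. The sketch below is that step as a plain function; parse_kitti_label_lines is a hypothetical name, not an mmdet API. It follows the same rule as the class above: types listed in CLASSES become ground truth with labels 0 to 2, and every other KITTI type goes to the ignore arrays with label -1.

```python
import numpy as np

CLASSES = ('Car', 'Pedestrian', 'Cyclist')


def parse_kitti_label_lines(lines, classes=CLASSES):
    """Hypothetical helper: KITTI label_2 lines -> middle-format arrays.

    Column 0 of a KITTI label line is the object type; columns 4-7 are the
    2D box (left, top, right, bottom) in pixels.
    """
    cat2label = {name: i for i, name in enumerate(classes)}
    bboxes, labels, bboxes_ignore = [], [], []
    for line in lines:
        fields = line.strip().split(' ')
        if not fields[0]:
            continue  # skip blank lines
        box = [float(v) for v in fields[4:8]]
        if fields[0] in cat2label:
            bboxes.append(box)
            labels.append(cat2label[fields[0]])
        else:
            bboxes_ignore.append(box)  # 'DontCare', 'Van', 'Truck', ...
    return dict(
        bboxes=np.array(bboxes, dtype=np.float32).reshape(-1, 4),
        labels=np.array(labels, dtype=np.int64),
        bboxes_ignore=np.array(bboxes_ignore, dtype=np.float32).reshape(-1, 4),
        labels_ignore=np.full(len(bboxes_ignore), -1, dtype=np.int64))
```

Feeding it the lines of one label_2/*.txt file gives a dict with the same keys and dtypes that load_annotations stores under ann, which makes it easy to spot-check a few label files before training.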
3.2 Modify mmdet/datasets/__init__.py (changes are marked with comments)
from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
from .cityscapes import CityscapesDataset
from .coco import CocoDataset
from .custom import CustomDataset
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
                               RepeatDataset)
from .deepfashion import DeepFashionDataset
from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset
from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler
from .utils import replace_ImageToTensor
from .voc import VOCDataset
from .wider_face import WIDERFaceDataset
from .xml_style import XMLDataset
from .kitti import KittiDataset  # newly added
__all__ = [
    # 'KittiDataset' below is newly added
    'KittiDataset', 'CustomDataset', 'XMLDataset', 'CocoDataset',
    'DeepFashionDataset', 'VOCDataset', 'CityscapesDataset', 'LVISDataset',
    'LVISV05Dataset', 'LVISV1Dataset', 'GroupSampler',
    'DistributedGroupSampler', 'DistributedSampler', 'build_dataloader',
    'ConcatDataset', 'RepeatDataset', 'ClassBalancedDataset',
    'WIDERFaceDataset', 'DATASETS', 'PIPELINES', 'build_dataset',
    'replace_ImageToTensor'
]
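Why the import line matters: @DATASETS.register_module() records the class in the registry only when kitti.py is actually executed, and a config refers to the dataset purely by the string in dataset_type. The sketch below is a simplified stand-in for mmcv's Registry and build_from_cfg, written from scratch for illustration (it is not mmcv's code), that shows the mechanism: if the module is never imported, the decorator never runs and the name lookup fails.

```python
class Registry:
    """Toy name -> class registry (a simplified stand-in, not mmcv's)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        def _register(cls):
            if cls.__name__ in self._module_dict:
                raise KeyError(f'{cls.__name__} already in {self.name}')
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def get(self, key):
        return self._module_dict.get(key)


def build_from_cfg(cfg, registry):
    """Instantiate the class named by cfg['type'] with the remaining keys."""
    args = dict(cfg)
    obj_type = args.pop('type')
    obj_cls = registry.get(obj_type)
    if obj_cls is None:
        raise KeyError(f'{obj_type} is not in the {registry.name} registry')
    return obj_cls(**args)


DATASETS = Registry('dataset')


# Registration happens when this class statement runs, i.e. when the module
# defining it is imported; hence the import in mmdet/datasets/__init__.py.
@DATASETS.register_module()
class KittiDataset:
    def __init__(self, ann_file, img_prefix):
        self.ann_file = ann_file
        self.img_prefix = img_prefix
```

With this toy registry, build_from_cfg(dict(type='KittiDataset', ann_file=..., img_prefix=...), DATASETS) returns a KittiDataset, while an unregistered type name raises KeyError, which mirrors the "is not in the dataset registry" style of error you would hit if the import in __init__.py were missing.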
3.3 Add kitti_detection.py to configs/_base_/datasets with the following content
# dataset settings
dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=