CenterNet(Objects as points)开源代码:https://github.com/xingyizhou/CenterNet
源码的dataset结构如下:
datasets
|
|---dataset # 解析各数据集(CenterNet共用了下面的数据集)
|---coco.py # Coco数据集
|---coco_hp.py # Coco human pose
|---kitti.py # kitti
|---pascal.py # PascalVOC
|
|---sample # 针对不同的网络 提取所需数据
|---ctdet.py # CenterNet
|---ddd.py # 3D Detection
|---exdet.py # ExtremeNet
|---multi_pose.py #
|
|---data_factory.py # 整合dataset和sample,构建完整的pipeline
该结构这样设计的目的是拆分和精细化每一个步骤,看过论文的知道,CenterNet可以很好地在目标检测、3D检测、人体姿态等任务上迁移,所以作者这样设计datasets更方便我们随意结合,同时,如果我们想使用自己的数据集也会很方便。
下文是详细解释,只想直接用懒得细看请移步:https://blog.youkuaiyun.com/weixin_43509263/article/details/100799415
我的任务是目标检测,采用Ccco数据集,使用CenterNet,所以简化文件结构,保留如下:
datasets
|---dataset
|---coco.py
|---sample
|---ctdet.py # CenterNet
|---data_factory.py
'''
实际上,一般构建Dataset我们都会继承torch.utils.data.Dataset,
一般都会重写__init__ 、__getitem__ 和 __len__ 三个函数,
这里,__init__、__len__在dataset实现,而 __getitem__在sample中
'''
- dataset中coco.py解析coco数据集:
"""
对coco数据集进行解析
def __init__(self, opt, split): 解析数据集中各属性
def __len__(self): 返回样本数
def run_eval(self, results, save_dir): eval接口
\-- def save_results(self, results, save_dir): 保存结果
\-- def convert_eval_format(self, all_bboxes): 将自己的结果 转换成coco要求的验证格式
"""
import pycocotools.coco as coco
import pycocotools.cocoeval as COCOeval
import numpy as np
import json
import os
import torch.utils.data as data
class COCO(data.Dataset):
num_classes = 80
default_resolution = [512, 512]
mean = np.array([0.40789654, 0.44719302, 0.47026115],
dtype=np.float32).reshape(1, 1, 3)
std = np.array([0.28863828, 0.27408164, 0.27809835],
dtype=np.float32).reshape(1, 1, 3)
def __init__(self, opt, split):
'''
:param opt: opt是传入的参数对象,在opt.py中
:param split: train\val\test
'''
super(COCO, self).__init__()
## self.data_dir、img_dir、annot_dir
self.data_dir = os.path.join(opt.data_dir, 'coco')
self.img_dir = os.path.join(self.data_dir, '{}2017'.format(split))
if split == 'test':
self.annot_path = os.path.join(
self.data_dir, 'annotations', 'image_info_test-dec2017.json')
else:
self.annot_path = os.path.join(
self.data_dir, 'annotations',
'instances_{}2017.json').format(split)
''' ???????????????????/ '''
self.max_objs = 128
# 类别名 加上__background__共81个
self.class_name = [
'__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
'bus', 'train