CenterNet 训练自己的数据集

github地址:https://github.com/Duankaiwen/CenterNet

论文:https://arxiv.org/abs/1904.08189

1、在github上下载,配置好环境。

2、准备数据

把自己的数据转化为coco的格式,网上有很多工具可以下载使用。我的数据是yolo格式的,需要的话可以提供我的yolo to coco的代码。yolo to coco数据转化小工具:https://github.com/surserrr/yolo_to_coco (随手star啦谢谢~~

数据分成train和val。 图片文件夹名字改为trainval2014和minival2014,放到CenterNet-master/data/coco/images中; json文件名为instances_trainval2014.json和instances_minival2014.json,放到CenterNet-master/data/coco/annotations中。

注意:如果你在训练之前,用coco数据集测试了模型,那么把CenterNet-master/cache/coco_minival2014.pkl删掉!如果没有训练过就忽略。(因为你在第一次运行的时候,代码会把coco数据的instances转化为它要用的格式,下一次用的时候就会直接读取。如果你没删掉,但你训练自己的数据集的时候,模型在val的时候会自动读取已经存在的coco的数据)

3、修改参数。

我数据集的类别只有1类,GPU1个。

必须修改的参数:

    1)models/CenterNet-52.py或者models/CenterNet-104

### 使用自定义数据集训练 CenterNet 模型 #### 准备环境和依赖项 要使用自定义数据集训练 CenterNet 模型,首先需要设置好开发环境并安装必要的库。假设已经有一个 Python 环境,在终端执行如下命令来创建虚拟环境并激活它: ```bash conda create -n centernet_env python=3.8 conda activate centernet_env ``` 接着安装 PyTorch 和 torchvision 库以及其他所需的包。 ```bash pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 pip install git+https://github.com/facebookresearch/fvcore.git pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.9/index.html ``` #### 下载预训练权重 下载官方提供的预训练模型参数作为初始化起点可以加速收敛过程,并有助于提高最终性能。对于 CenterNet 来说,可以从公开资源获取预训练权重文件 `CenterNet2_R50_1x.pth` 并加载到项目中[^2]。 ```python import torch pretrained_weights = torch.load('models/CenterNet2_R50_1x.pth') ``` #### 数据集准备 构建适合于 CenterNet数据集非常重要。通常情况下,这涉及到收集图片及其相应的标签信息(即物体的位置)。这些标签应该按照特定格式保存下来,比如 COCO JSON 文件格式或简单的文本文件描述边界框位置等细节。如果现有的标注不符合预期,则可能还需要转换工具帮助调整至所需样式。 针对 CenterNet 训练数据集应当遵循一定的命名约定与路径规划,确保程序能够自动识别哪些图像是用于训练而哪些又是验证用途;同时也要注意检查所有图像尺寸是否统一,因为不一致可能导致错误发生。 #### 修改配置文件 根据实际需求编辑配置文件中的超参数设定,如批量大小(batch size)、学习率(learning rate),以及迭代次数(iterations)等等。此外还需指定输入图片分辨率(input resolution)、使用的骨干网络(backbone network)类型和其他任何影响算法行为的关键选项。 #### 开始训练流程 完成上述准备工作之后就可以调用 Detectron2 提供的功能接口启动正式的训练环节了。下面给出一段简化版代码片段展示如何操作: ```python from detectron2.engine import DefaultTrainer, default_argument_parser, launch from detectron2.config import get_cfg from detectron2.data.datasets import register_coco_instances if __name__ == "__main__": parser = default_argument_parser() args = parser.parse_args() cfg = get_cfg() cfg.merge_from_file("configs/COCO-Detection/centernet_r50_fpn_1x.yaml") # 加载默认配置 # 注册新的数据集 register_coco_instances("my_dataset_train", {}, "path/to/json/train.json", "path/to/images/") register_coco_instances("my_dataset_val", {}, "path/to/json/val.json", "path/to/images/") cfg.DATASETS.TRAIN = ("my_dataset_train", ) cfg.DATASETS.TEST = ("my_dataset_val", ) trainer = DefaultTrainer(cfg) trainer.resume_or_load(resume=False) trainer.train() ``` 通过以上步骤即可基于个人定制化的数据源对 CenterNet 进行有效的再训练工作。当然具体实施过程中还可能会遇到各种各样的挑战和技术难题,建议参考官方文档或其他社区资料寻求解决方案。
评论 49
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值