我的AI之路(45)--使用自己的数据集训练CenterNet

本文详细介绍如何在Pytorch环境下配置并训练CenterNet模型,包括环境搭建、模型编译、数据集准备、训练过程及常见错误解决方法。

更新说明:

        作者的源码:  https://github.com/xingyizhou/CenterNet是基于pytorch0.4.1的(CUDA最高版本只能使用到CUDA9),如果想使用pytorch1.0以上版本以支持使用CUDA10.0或以上版本,可以在下载作者的源码后,其他步骤照做,只是把DCNv2的源码用https://github.com/CharlesShang/DCNv2这里的源码替换掉就可以了,已在Pytorch1.2+RTX2080TI+CUDA10.0环境里训练了多个模型没任何问题:


cd CenterNet/src/lib/models/networks
rm -r DCNv2
git clone https://github.com/CharlesShang/DCNv2.git
cd DCNv2
python setup.py build develop    

 

[原文]

     CenterNet是anchor-free类型网络,具有识别精度高且速度快的特点,根据作者的论文中列出的数据来看,指标综合考虑来看比较牛了:

     最后那个CenterNet-HG,也就是backbone使用的Hourglass-104网络的AP值只比FSAF低一点了(但是FSAF目前貌似还没有源码放出来),比YOLO序列和RCNN序列都强很多,虽然FPS自有7.8,但是对一般实时性要求不是很高的视频检测也够用了,所以拿来试试。

     首先下载作者的源码:  git clone https://github.com/xingyizhou/CenterNet.git,根据安装说明:https://github.com/xingyizhou/CenterNet/blob/master/readme/INSTALL.md,环境和工具软件是:

      Ubuntu 16.04, with Anaconda Python 3.6 and PyTorch v0.4.1

他这个源码是使用的pytorch0.41版写的,由于pytorch0.41支持的CUDA最高版本是CUDA9,不支持我们的服务器上目前安装的CUDA10.0或CUDA10.1,我先是试了一下使用conda创建隔离环境后安装支持CUDA10的pytorch1.3或pytorch1.0.0,然后跑了一下,结果报错,说是CenterNet中有API是不支持的了(后面再说),但在公共服务器上又不好随便乱装CUDA(安装过CUDA的应该知道它的厉害,很能折腾人,装得不对服务器登录进不去、黑屏之类的问题让人三思),于是想到还是使用docker最好,首先到hub.docker.com上拉取个pytorch0.4.1+CUDA9.0的devel版镜像:

     docker pull pytorch/pytorch:0.4.1-cuda9-cudnn7-devel

然后运行创建实例(进入容器内部后默认的初始路径是/workspace,所以把下载了CenterNet源码的目录work_pytorch映射到/workspace,并预留端口12000的映射,以备后面有需要时对模型做server端封装调用,并带上ipc=host参数,以防止做多GPU分布式训练的过程中出现共享内存不足的错误):

      nvidia-docker run --ipc=host -d -it --name pytorch0.41 -v /home/fychen/AI/work_pytorch:/workspace -p 12000:12000 pytorch/pytorch:0.4.1-cuda9-cudnn7-devel bash

进入容器后,执行下面的修改(容器内的pytorch安装在/opt/conda路径下)把torch.nn.functional.py里1254行的torch.backends.cuddn.enabled改为False:

      sed -i "1254s/torch\.backends\.cudnn\.enabled/False/g" /opt/conda/lib/python3.6/site-packages/torch/nn/functional.py

然后,依次执行下面的命令安装pycocotools:

     git clone https://github.com/cocodataset/cocoapi.git
     cd cocoapi/PythonAPI
     pip install cython
     make
     python setup.py install --user

  再依次执行下面的命令完成CenterNet下面的部分代码的编译:    

     cd /workspace/CenterNet
     pip install -r requirements.txt

     cd src/lib/models/networks/DCNv2
     ./make.sh

     cd /workspace/CenterNet/src/lib/external
     make

再安装一些跑CenterNet需要的支持包(不安装这些包会报错):

      apt-get install libglib2.0-dev  libsm6  libxrender1  libxext6

 

然后下载对应的预训练模型,我要使用的backbone是hour-glass,模型训练后用来做物体检测,根据https://github.com/xingyizhou/CenterNet/blob/master/readme/MODEL_ZOO.md :

      下载第一行的ctdet_coco_hg模型即可,点击右边的model链接下载模型文件ctdet_coco_hg.pth,这里是从dr

### 使用自定义数据集训练 CenterNet 模型 #### 准备环境和依赖项 要使用自定义数据集训练 CenterNet 模型,首先需要设置好开发环境并安装必要的库。假设已经有一个 Python 环境,在终端执行如下命令来创建虚拟环境并激活它: ```bash conda create -n centernet_env python=3.8 conda activate centernet_env ``` 接着安装 PyTorch 和 torchvision 库以及其他所需的包。 ```bash pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 pip install git+https://github.com/facebookresearch/fvcore.git pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.9/index.html ``` #### 下载预训练权重 下载官方提供的预训练模型参数作为初始化起点可以加速收敛过程,并有助于提高最终性能。对于 CenterNet 来说,可以从公开资源获取预训练权重文件 `CenterNet2_R50_1x.pth` 并加载到项目中[^2]。 ```python import torch pretrained_weights = torch.load('models/CenterNet2_R50_1x.pth') ``` #### 数据集准备 构建适合于 CenterNet数据集非常重要。通常情况下,这涉及到收集图片及其相应的标签信息(即物体的位置)。这些标签应该按照特定格式保存下来,比如 COCO JSON 文件格式或简单的文本文件描述边界框位置等细节。如果现有的标注不符合预期,则可能还需要转换工具帮助调整至所需样式。 针对 CenterNet 训练数据集应当遵循一定的命名约定与路径规划,确保程序能够自动识别哪些图像是用于训练而哪些又是验证用途;同时也要注意检查所有图像尺寸是否统一,因为不一致可能导致错误发生。 #### 修改配置文件 根据实际需求编辑配置文件中的超参数设定,如批量大小(batch size)、学习率(learning rate),以及迭代次数(iterations)等等。此外还需指定输入图片分辨率(input resolution)使用的骨干网络(backbone network)类型和其他任何影响算法行为的关键选项。 #### 开始训练流程 完成上述准备工作之后就可以调用 Detectron2 提供的功能接口启动正式的训练环节了。下面给出一段简化版代码片段展示如何操作: ```python from detectron2.engine import DefaultTrainer, default_argument_parser, launch from detectron2.config import get_cfg from detectron2.data.datasets import register_coco_instances if __name__ == "__main__": parser = default_argument_parser() args = parser.parse_args() cfg = get_cfg() cfg.merge_from_file("configs/COCO-Detection/centernet_r50_fpn_1x.yaml") # 加载默认配置 # 注册新的数据集 register_coco_instances("my_dataset_train", {}, "path/to/json/train.json", "path/to/images/") register_coco_instances("my_dataset_val", {}, "path/to/json/val.json", "path/to/images/") cfg.DATASETS.TRAIN = ("my_dataset_train", ) cfg.DATASETS.TEST = ("my_dataset_val", ) trainer = DefaultTrainer(cfg) trainer.resume_or_load(resume=False) trainer.train() ``` 通过以上步骤即可基于个人定制化的数据源对 CenterNet 进行有效的再训练工作。当然具体实施过程中还可能会遇到各种各样的挑战和技术难题,建议参考官方文档或其他社区资料寻求解决方案。
评论 20
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Arnold-FY-Chen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值