因公司需要,需要将mmdetection2.5版本训练的模型迁移到mmdetection1.0,两者因环境(torch版本,mmcv版本,核心代码实现上的区别),直接更改config存在问题,本文记录迁移过程
1、mmdetection2.5 模型保存
两个版本mmdet的torch 版本不相同(其中2.5版本为1.7,1.0版本为1.1),因此在mmdetection2.5上保存的模型,在低版本的torch中无法正确读取,因此首先要需要解决这个问题,解决方案是只保存state_dict,重点是需要设定_use_new_zipfile_serialization为False,否则低版本torch无法读取
import os
import torch
from mmdet.apis import init_detector
config_path = 'config_v2_5.py'
model_path = 'model_v2_5.pth'
save_path = 'mmdet_demo/transfer.pth'
model = init_detector(config_path, model_path)
state_dict = model.state_dict()
torch.save(state_dict,,_use_new_zipfile_serialization=False
如上述脚本可实现高版本模型保存state_dict, 这样把weights迁移到对应的低版本
2、制作config
为了简单实现需求,对比了两个版本config的异同,在2.5版本的基础上,修改得到1.0版本的config, 主
要更改的地方有四个方面,如下所示:
-
RPN相关
因2.x版本相较1.x版本进行了较大改进,尤其在RPN的组成上,核心代码一致,组合方式发生了变化。

-
bbox_roi_extractor相关
第二个变化为ROI相关的组合方式

-
bbox_head相关

-
test_cfg相关
主要注意使用nms中的参数名称

3、实现转化
实现转化的流程:
1、利用1.0的config进行模型初始化,得到对应的结构state_dict
2、获取2.5版本保存的state_dict
3、遍历1.0获取的state_dict,保存有差异的key值,存入两个版本的txt文件,如下图所示:

4、遍历1.0的state_dict,对于不同情况进行分别赋值
具体代码如下所示:
import torch
import mmcv
from mmdet.models import build_detector
from mmcv.runner import load_checkpoint
import warnings
from mmdet.core import get_classes
from mmcv import Config
from mmdet.apis import init_detector
def</

本文详细记录了如何将基于MMDetection2.5训练的模型迁移到1.0版本的过程,包括模型状态保存、配置文件调整以及不同版本state_dict的转换步骤,涉及关键代码示例,适用于处理模型版本升级问题。
最低0.47元/天 解锁文章
2417

被折叠的 条评论
为什么被折叠?



