【mmcls】mmdet中使用mmcls的网络及预训练模型

本文介绍了如何在mmdetection中使用mmcls的预训练模型,包括查找预训练模型、获取网络参数和权重、确定特征输出维度,并提供了在mmdetection配置文件中应用这些模型的步骤。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

        mmcls现在叫mmpretrain,以前叫mmclassification,这里为了统一称为mmcls。在基于MM框架的下游任务,例如检测(mmdetection)中可以使用mmcls中的backbone进行特征提取,但这就需要知道网络的参数以及输出特征的维度。本文简单介绍了在mmdetection中使用mmcls中backbone的方法。mmdetection中需要配置backbone、模型权重及neck的特征维度等信息。

1 查找mmcls预训练模型

        查找mmcls支持的网络的方法有多种:

  1. 在mmpretrain的README中;
  2. 在modelzoo种查找模型库统计 — MMClassification 1.0.0rc6 文档 (mmpretrain.readthedocs.io)
  3. 直接看repo的configs目录下的列表

2 获取网络参数(配置)及预训练权重

        找到网络后还需要找到网络参数及预训练权重。以replknet为例,获取网络参数可以直接看mmpretrain/configs/replknet中的配置文件,例如replknet-31B_32xb64_in1k.py,但配置文件可能并没有直接写模型配置信息,而是依赖其他配置文件,如下图中的replknet-31B_in1k.py

         继续找到上述配置文件,可以看到网络配置:

        预训练权重可以在mmpretrain/configs/replknet下的README中找到,例如:

 https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_in21k-pre_3rdparty_in1k_20221118-54ed5c46.pth

        预训练权重也可以在modelzoo中查找:

 

3 获取特征输出维度

        首先在modelzoo中查到已有模型的名称,然后使用mmcls.get_model获取模型,输出指定层的特征维度。

import torch
from mmcls import get_model, inference_model

inputs = torch.rand(16, 3, 224, 224)

# 构建模型
model_name = 'replknet-31B_in21k-pre_3rdparty_in1k'
model = get_model(model_name, pretrained=False, backbone=dict(out_indices=(0, 1, 2, 3)))
# model = get_model(model_name, pretrained=False, backbone=dict(out_scales=(0, 1, 2, 3)))  # mvitv2

feats = model.extract_feat(inputs)
for feat in feats:
    print(feat.shape)

        可以看到输出为 [128, 256, 512, 1024]:

torch.Size([16, 128])
torch.Size([16, 256])
torch.Size([16, 512])
torch.Size([16, 1024])

4 mmdetection中使用

        在mmdetection中修改配置文件中backbone,预训练权重和neck中的in_channels等信息。同时应该注意网络的优化器配置的参数。

checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/replknet/replknet-31B_in21k-pre_3rdparty_in1k-384px_20221118-76c92b24.pth'  # noqa

model = dict(
    backbone=dict(
        _delete_=True,
        type='mmcls.RepLKNet',
        arch='31B',
        out_indices=[0, 1, 2, 3],
        init_cfg=dict(
            type='Pretrained', checkpoint=checkpoint_file,
            prefix='backbone.')),
    neck=dict(
        _delete_=True,
        type='mmdet.FPN',
        in_channels=[128, 256, 512, 1024],
        out_channels=256,
        num_outs=5))

MMClsMMDetection Classification)是一个基于PyTorch的图像分类库,它通常用于预训练模型并在ImageNet数据集上进行微调,然后应用于其他下游任务。MMPretrain则是其更广泛的多模态预训练框架的一部分,支持多种视觉模型。 如果你想使用MMCls来评估VGG19网络的分类性能,并找10张图片来做实验,你需要按照以下步骤操作: 1. **安装依赖**:首先确保你已经安装了`mmcls`和相关的依赖库,如`torchvision`等。 ```bash pip install mmcv mmdet mmclassification torchvision ``` 2. **加载预训练模型**:下载或从pretrained模型库中获取预训练的VGG19模型,由于VGG19不是MMDetection的默认模型,你可能需要自行下载模型权重。 ```python from mmcls.models import build_classifier model = build_classifier('vgg19') # 假设vgg19模型已适配到mmcls框架 model.load_from_pretrained('your_vgg19_weights_path') ``` 3. **准备测试数据**:准备10张标注好类别的图片,可以是JPEG、PNG或其他支持的格式。将它们放在一个文件夹下,并准备好对应的标签列表。 4. **图像预处理**:对测试图片进行必要的预处理,比如调整尺寸、归一化等,以便输入到模型中。 ```python from mmcls.datasets import ImageDataset test_dataset = ImageDataset(img_prefix='path/to/images', ann_file='path/to/labels.txt') ``` 5. **评估性能**:创建一个DataLoader,然后使用模型在测试数据上进行预测并计算准确率等指标。 ```python data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False) model.eval() predictions = [] for img, _, target in data_loader: outputs = model(img) # 前向传播并得到预测结果 predictions.append(outputs.argmax(dim=1).cpu().numpy()) # 获取每个样本的预测类别 # 集成所有预测结果 all_predictions = np.concatenate(predictions) # 计算准确率或其他指标 accuracy = (all_predictions == test_labels).mean() # 假设test_labels存储了实际的标签 ```
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值