MMPreTrain模型推理与特征提取完全指南
前言
MMPreTrain作为开源深度学习项目,提供了丰富的预训练模型和便捷的推理接口。本文将详细介绍如何使用MMPreTrain进行模型推理、特征提取等操作,帮助开发者快速上手这一强大的计算机视觉工具库。
模型查询与获取
查询可用模型
MMPreTrain内置了大量预训练模型,我们可以通过list_models()
函数查看所有可用模型:
from mmpretrain import list_models
# 列出所有模型
all_models = list_models()
print(all_models[:5]) # 打印前5个模型
# 使用通配符查询特定模型
convnext_models = list_models("*convnext-b*21k")
对于特定任务,我们还可以通过对应推理器的list_models()
方法查询支持的模型:
from mmpretrain import ImageCaptionInferencer
caption_models = ImageCaptionInferencer.list_models()
获取模型实例
获取模型实例是使用MMPreTrain的第一步,get_model()
函数提供了灵活的模型获取方式:
from mmpretrain import get_model
# 基本用法 - 获取不带预训练权重的模型
model = get_model("convnext-base_in21k-pre_3rdparty_in1k")
# 加载默认预训练权重
pretrained_model = get_model("convnext-base_in21k-pre_3rdparty_in1k", pretrained=True)
# 自定义模型结构 - 修改分类头
custom_model = get_model("convnext-base_in21k-pre_3rdparty_in1k",
head=dict(num_classes=10))
# 获取无头部的特征提取模型
feature_model = get_model("resnet18_8xb32_in1k",
head=None, neck=None,
backbone=dict(out_indices=(1, 2, 3)))
获取的模型是标准的PyTorch模块,可以直接用于推理:
import torch
x = torch.rand((1, 3, 224, 224))
y = model(x)
图像推理实践
快速单图推理
对于简单的单图推理任务,可以使用inference_model()
快捷函数:
from mmpretrain import inference_model
image_path = "demo/demo.JPEG"
result = inference_model('resnet50_8xb32_in1k', image_path, show=True)
print(f"预测类别: {result['pred_class']}")
print(f"预测分数: {result['pred_score']:.4f}")
批量推理与高级配置
对于更复杂的场景,建议使用专门的推理器(Inferencer):
from mmpretrain import ImageClassificationInferencer
# 初始化推理器
inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
# 单图推理
result = inferencer(image_path)[0] # 注意返回的是列表
# 批量推理
image_list = ['demo/demo.JPEG', 'demo/bird.JPEG'] * 16
results = inferencer(image_list, batch_size=8) # 支持批量处理
推理器支持丰富的配置选项:
# 自定义配置和权重
config = 'configs/resnet/resnet50_8xb32_in1k.py'
checkpoint = 'resnet50_8xb32_in1k_20210831-ea4938fc.pth'
inferencer = ImageClassificationInferencer(
model=config,
pretrained=checkpoint,
device='cuda'
)
推理结果解析
MMPreTrain的推理结果通常包含以下信息:
{
"pred_label": 65, # 预测类别ID
"pred_score": 0.6649, # 预测最高分数
"pred_class": "sea snake", # 预测类别名称
"pred_scores": [..., 0.6649, ...] # 所有类别分数
}
特征提取实战
MMPreTrain提供了专门的FeatureExtractor
用于从图像中提取特征:
from mmpretrain import FeatureExtractor, get_model
# 获取多尺度特征输出的模型
model = get_model('resnet50_8xb32_in1k',
backbone=dict(out_indices=(0, 1, 2, 3)))
# 创建特征提取器
extractor = FeatureExtractor(model)
# 提取特征
features = extractor(image_path)[0] # 返回多尺度特征列表
# 查看各层特征维度
for i, feat in enumerate(features):
print(f"第{i}层特征维度: {feat.shape}")
与直接使用model.extract_feat()
不同,FeatureExtractor
直接处理图像文件,而非张量,更加方便实用。
多任务推理支持
MMPreTrain支持多种视觉任务,每种任务都有对应的推理器:
- 图像分类:
ImageClassificationInferencer
- 图像检索:
ImageRetrievalInferencer
- 图像描述生成:
ImageCaptionInferencer
- 视觉问答:
VisualQuestionAnsweringInferencer
- 视觉定位:
VisualGroundingInferencer
- 文本到图像检索:
TextToImageRetrievalInferencer
- 图像到文本检索:
ImageToTextRetrievalInferencer
- 视觉推理:
NLVRInferencer
可视化交互界面
MMPreTrain还提供了基于Gradio的可视化交互界面,方便非开发者用户体验模型效果。安装Gradio后即可启动:
pip install -U gradio
python projects/gradio_demo/launch.py
该界面集成了所有支持的视觉任务,用户可以通过简单的点击操作体验模型效果。
总结
MMPreTrain提供了从模型查询、获取到推理、特征提取的完整工具链。通过本文介绍的各种API,开发者可以:
- 快速查询和获取所需模型
- 进行单图或批量推理
- 提取图像特征用于下游任务
- 构建可视化演示界面
无论是研究还是生产环境,MMPreTrain都能提供高效、灵活的计算机视觉解决方案。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考