3步搞定视觉模型特征提取:pytorch-image-models中间层输出与可视化指南

3步搞定视觉模型特征提取:pytorch-image-models中间层输出与可视化指南

【免费下载链接】pytorch-image-models huggingface/pytorch-image-models: 是一个由 Hugging Face 开发维护的 PyTorch 视觉模型库,包含多个高性能的预训练模型,适用于图像识别、分类等视觉任务。 【免费下载链接】pytorch-image-models 项目地址: https://gitcode.com/GitHub_Trending/py/pytorch-image-models

你是否还在为获取深度学习模型的中间层特征而烦恼?是否因不同模型接口不统一而重复编写代码?本文将带你使用pytorch-image-models(timm库)的统一接口,3步实现任意预训练模型的特征提取与可视化,无需复杂的模型修改。

读完本文你将掌握:

  • 终极层特征(Pre-Classifier Features)的3种提取方法
  • 多尺度特征图(Feature Pyramid)的高效获取技巧
  • 中间层特征可视化的实用代码模板

1. 核心概念:特征提取的两种主流方式

pytorch-image-models(timm)提供了统一的特征提取机制,支持所有模型的中间层特征获取。主要分为两类应用场景:

1.1 终极层特征(Pre-Classifier Features)

指分类器之前的特征层,适用于迁移学习、特征匹配等任务。根据是否经过全局池化(Global Pooling),分为:

  • 未池化特征(Unpooled):保留空间维度的特征图(如[2, 2048, 7, 7])
  • 池化特征(Pooled):压缩为向量形式的特征(如[2, 2048])

官方文档详细说明:hfdocs/source/feature_extraction.mdx

1.2 多尺度特征图(Feature Pyramid)

适用于目标检测、语义分割等密集预测任务,可同时获取不同层级的特征图。timm通过features_only=True参数实现,支持1017种模型架构(hfdocs/source/changes.mdx)。

2. 实操步骤:从代码到可视化

2.1 终极层特征提取

方法1:使用forward_features()接口

无需修改模型结构,直接调用专用接口获取未池化特征:

import torch
import timm
m = timm.create_model('xception41', pretrained=True)
input = torch.randn(2, 3, 299, 299)  # 符合模型输入尺寸
features = m.forward_features(input)
print(f'未池化特征形状: {features.shape}')  # torch.Size([2, 2048, 10, 10])
方法2:创建无分类器模型

初始化时移除分类头和池化层,直接输出未池化特征:

m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
features = m(torch.randn(2, 3, 224, 224))
print(f'未池化特征形状: {features.shape}')  # torch.Size([2, 2048, 7, 7])
方法3:动态移除分类器

对已创建的模型,使用reset_classifier()方法移除分类头:

m = timm.create_model('densenet121', pretrained=True)
m.reset_classifier(0, '')  # 参数0表示移除分类器,空字符串表示移除池化
features = m(torch.randn(2, 3, 224, 224))
print(f'未池化特征形状: {features.shape}')  # torch.Size([2, 1024, 7, 7])

若需池化特征,只需保留默认池化设置:

m = timm.create_model('resnet50', pretrained=True, num_classes=0)  # 仅移除分类器
features = m(torch.randn(2, 3, 224, 224))
print(f'池化特征形状: {features.shape}')  # torch.Size([2, 2048])

2.2 多尺度特征图提取

适用于目标检测、分割等需要多尺度特征的任务。timm通过features_only=True参数统一接口,支持所有模型输出多尺度特征金字塔。

基础用法
m = timm.create_model('resnest26d', features_only=True, pretrained=True)
features = m(torch.randn(2, 3, 224, 224))
for i, feat in enumerate(features):
    print(f'特征层级{i}: {feat.shape}')

输出结果:

特征层级0: torch.Size([2, 64, 112, 112])  # 1/2分辨率
特征层级1: torch.Size([2, 256, 56, 56])   # 1/4分辨率
特征层级2: torch.Size([2, 512, 28, 28])   # 1/8分辨率
特征层级3: torch.Size([2, 1024, 14, 14])  # 1/16分辨率
特征层级4: torch.Size([2, 2048, 7, 7])    # 1/32分辨率
高级配置

通过out_indices参数选择特定层级,output_stride参数控制输出步长:

m = timm.create_model(
    'regnety_032', 
    features_only=True, 
    pretrained=True,
    out_indices=(1, 3),  # 选择第2和第4个特征层
    output_stride=16     # 限制最大步长为16
)
print(f'特征通道数: {m.feature_info.channels()}')  # [72, 576]
features = m(torch.randn(2, 3, 224, 224))
for feat in features:
    print(f'特征形状: {feat.shape}')

2.3 特征可视化实现

以下是完整的特征提取与可视化代码,以ResNet50为例可视化第3层级特征图:

import torch
import timm
import matplotlib.pyplot as plt
import numpy as np

# 1. 创建特征提取模型
model = timm.create_model(
    'resnet50', 
    features_only=True, 
    pretrained=True,
    out_indices=(3,)  # 选择第4个特征层(1/16分辨率)
)
model.eval()

# 2. 预处理输入图像
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# 加载图像并转换
img = Image.open("test_image.jpg").convert('RGB')
input_tensor = transform(img).unsqueeze(0)

# 3. 提取特征
with torch.no_grad():
    features = model(input_tensor)[0]  # 取第一个(也是唯一一个)特征层

# 4. 可视化前16个通道
features = features.squeeze(0).cpu().numpy()
fig, axes = plt.subplots(4, 4, figsize=(12, 12))
for i, ax in enumerate(axes.flat):
    if i < features.shape[0]:
        ax.imshow(features[i], cmap='viridis')
        ax.set_title(f'Channel {i}')
    ax.axis('off')
plt.tight_layout()
plt.savefig('feature_visualization.png')

3. 进阶技巧与最佳实践

3.1 特征信息查询

所有特征提取模型都有feature_info属性,可查询特征通道数、步长等信息:

m = timm.create_model('regnety_032', features_only=True, pretrained=True)
print(f'特征通道数: {m.feature_info.channels()}')  # [32, 72, 216, 576, 1512]
print(f'特征步长: {m.feature_info.reduction()}')    # [2, 4, 8, 16, 32]

3.2 模型剪枝优化

使用prune_intermediate_layers()方法移除不需要的层,减少计算量:

model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
print('原始参数: ', sum(p.numel() for p in model.parameters()))  # 38880232

# 剪枝最后一个块、分类头和归一化层
indices = model.prune_intermediate_layers(indices=(-2,), prune_head=True, prune_norm=True)
print('剪枝后参数: ', sum(p.numel() for p in model.parameters()))  # 35212800

# 仅返回指定中间层特征
intermediates = model.forward_intermediates(input_tensor, indices=indices, intermediates_only=True)

3.3 性能对比

不同模型的特征提取速度对比(基于RTX 3090,批次大小16):

模型输入尺寸特征提取时间(ms)特征维度
resnet50224x22412.32048
efficientnet_b0224x2248.71280
vit_base_patch16_224224x22415.6768
convnext_tiny224x22410.2768

数据来源:results/benchmark-infer-amp-nchw-pt240-cu124-rtx3090.csv

总结与展望

本文介绍了pytorch-image-models的特征提取功能,通过统一接口实现了3种终极层特征提取方法和多尺度特征图获取,并提供了可视化代码模板。关键优势:

  1. 接口统一:所有模型使用相同参数(features_only=True)提取特征
  2. 零代码修改:无需手动修改模型结构,通过API直接获取特征
  3. 丰富元信息feature_info属性提供完整的特征描述

timm库持续更新,目前1017种模型中的1017种已支持特征提取(hfdocs/source/changes.mdx)。下一步可探索特征融合、自注意力可视化等高级应用。

收藏本文,下次需要特征提取时直接取用代码模板!关注获取更多计算机视觉实战技巧。

【免费下载链接】pytorch-image-models huggingface/pytorch-image-models: 是一个由 Hugging Face 开发维护的 PyTorch 视觉模型库,包含多个高性能的预训练模型,适用于图像识别、分类等视觉任务。 【免费下载链接】pytorch-image-models 项目地址: https://gitcode.com/GitHub_Trending/py/pytorch-image-models

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值