3步搞定视觉模型特征提取:pytorch-image-models中间层输出与可视化指南
你是否还在为获取深度学习模型的中间层特征而烦恼?是否因不同模型接口不统一而重复编写代码?本文将带你使用pytorch-image-models(timm库)的统一接口,3步实现任意预训练模型的特征提取与可视化,无需复杂的模型修改。
读完本文你将掌握:
- 终极层特征(Pre-Classifier Features)的3种提取方法
- 多尺度特征图(Feature Pyramid)的高效获取技巧
- 中间层特征可视化的实用代码模板
1. 核心概念:特征提取的两种主流方式
pytorch-image-models(timm)提供了统一的特征提取机制,支持所有模型的中间层特征获取。主要分为两类应用场景:
1.1 终极层特征(Pre-Classifier Features)
指分类器之前的特征层,适用于迁移学习、特征匹配等任务。根据是否经过全局池化(Global Pooling),分为:
- 未池化特征(Unpooled):保留空间维度的特征图(如[2, 2048, 7, 7])
- 池化特征(Pooled):压缩为向量形式的特征(如[2, 2048])
官方文档详细说明:hfdocs/source/feature_extraction.mdx
1.2 多尺度特征图(Feature Pyramid)
适用于目标检测、语义分割等密集预测任务,可同时获取不同层级的特征图。timm通过features_only=True参数实现,支持1017种模型架构(hfdocs/source/changes.mdx)。
2. 实操步骤:从代码到可视化
2.1 终极层特征提取
方法1:使用forward_features()接口
无需修改模型结构,直接调用专用接口获取未池化特征:
import torch
import timm
m = timm.create_model('xception41', pretrained=True)
input = torch.randn(2, 3, 299, 299) # 符合模型输入尺寸
features = m.forward_features(input)
print(f'未池化特征形状: {features.shape}') # torch.Size([2, 2048, 10, 10])
方法2:创建无分类器模型
初始化时移除分类头和池化层,直接输出未池化特征:
m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
features = m(torch.randn(2, 3, 224, 224))
print(f'未池化特征形状: {features.shape}') # torch.Size([2, 2048, 7, 7])
方法3:动态移除分类器
对已创建的模型,使用reset_classifier()方法移除分类头:
m = timm.create_model('densenet121', pretrained=True)
m.reset_classifier(0, '') # 参数0表示移除分类器,空字符串表示移除池化
features = m(torch.randn(2, 3, 224, 224))
print(f'未池化特征形状: {features.shape}') # torch.Size([2, 1024, 7, 7])
若需池化特征,只需保留默认池化设置:
m = timm.create_model('resnet50', pretrained=True, num_classes=0) # 仅移除分类器
features = m(torch.randn(2, 3, 224, 224))
print(f'池化特征形状: {features.shape}') # torch.Size([2, 2048])
2.2 多尺度特征图提取
适用于目标检测、分割等需要多尺度特征的任务。timm通过features_only=True参数统一接口,支持所有模型输出多尺度特征金字塔。
基础用法
m = timm.create_model('resnest26d', features_only=True, pretrained=True)
features = m(torch.randn(2, 3, 224, 224))
for i, feat in enumerate(features):
print(f'特征层级{i}: {feat.shape}')
输出结果:
特征层级0: torch.Size([2, 64, 112, 112]) # 1/2分辨率
特征层级1: torch.Size([2, 256, 56, 56]) # 1/4分辨率
特征层级2: torch.Size([2, 512, 28, 28]) # 1/8分辨率
特征层级3: torch.Size([2, 1024, 14, 14]) # 1/16分辨率
特征层级4: torch.Size([2, 2048, 7, 7]) # 1/32分辨率
高级配置
通过out_indices参数选择特定层级,output_stride参数控制输出步长:
m = timm.create_model(
'regnety_032',
features_only=True,
pretrained=True,
out_indices=(1, 3), # 选择第2和第4个特征层
output_stride=16 # 限制最大步长为16
)
print(f'特征通道数: {m.feature_info.channels()}') # [72, 576]
features = m(torch.randn(2, 3, 224, 224))
for feat in features:
print(f'特征形状: {feat.shape}')
2.3 特征可视化实现
以下是完整的特征提取与可视化代码,以ResNet50为例可视化第3层级特征图:
import torch
import timm
import matplotlib.pyplot as plt
import numpy as np
# 1. 创建特征提取模型
model = timm.create_model(
'resnet50',
features_only=True,
pretrained=True,
out_indices=(3,) # 选择第4个特征层(1/16分辨率)
)
model.eval()
# 2. 预处理输入图像
from PIL import Image
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
# 加载图像并转换
img = Image.open("test_image.jpg").convert('RGB')
input_tensor = transform(img).unsqueeze(0)
# 3. 提取特征
with torch.no_grad():
features = model(input_tensor)[0] # 取第一个(也是唯一一个)特征层
# 4. 可视化前16个通道
features = features.squeeze(0).cpu().numpy()
fig, axes = plt.subplots(4, 4, figsize=(12, 12))
for i, ax in enumerate(axes.flat):
if i < features.shape[0]:
ax.imshow(features[i], cmap='viridis')
ax.set_title(f'Channel {i}')
ax.axis('off')
plt.tight_layout()
plt.savefig('feature_visualization.png')
3. 进阶技巧与最佳实践
3.1 特征信息查询
所有特征提取模型都有feature_info属性,可查询特征通道数、步长等信息:
m = timm.create_model('regnety_032', features_only=True, pretrained=True)
print(f'特征通道数: {m.feature_info.channels()}') # [32, 72, 216, 576, 1512]
print(f'特征步长: {m.feature_info.reduction()}') # [2, 4, 8, 16, 32]
3.2 模型剪枝优化
使用prune_intermediate_layers()方法移除不需要的层,减少计算量:
model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
print('原始参数: ', sum(p.numel() for p in model.parameters())) # 38880232
# 剪枝最后一个块、分类头和归一化层
indices = model.prune_intermediate_layers(indices=(-2,), prune_head=True, prune_norm=True)
print('剪枝后参数: ', sum(p.numel() for p in model.parameters())) # 35212800
# 仅返回指定中间层特征
intermediates = model.forward_intermediates(input_tensor, indices=indices, intermediates_only=True)
3.3 性能对比
不同模型的特征提取速度对比(基于RTX 3090,批次大小16):
| 模型 | 输入尺寸 | 特征提取时间(ms) | 特征维度 |
|---|---|---|---|
| resnet50 | 224x224 | 12.3 | 2048 |
| efficientnet_b0 | 224x224 | 8.7 | 1280 |
| vit_base_patch16_224 | 224x224 | 15.6 | 768 |
| convnext_tiny | 224x224 | 10.2 | 768 |
数据来源:results/benchmark-infer-amp-nchw-pt240-cu124-rtx3090.csv
总结与展望
本文介绍了pytorch-image-models的特征提取功能,通过统一接口实现了3种终极层特征提取方法和多尺度特征图获取,并提供了可视化代码模板。关键优势:
- 接口统一:所有模型使用相同参数(
features_only=True)提取特征 - 零代码修改:无需手动修改模型结构,通过API直接获取特征
- 丰富元信息:
feature_info属性提供完整的特征描述
timm库持续更新,目前1017种模型中的1017种已支持特征提取(hfdocs/source/changes.mdx)。下一步可探索特征融合、自注意力可视化等高级应用。
收藏本文,下次需要特征提取时直接取用代码模板!关注获取更多计算机视觉实战技巧。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



