最近要使用一下,姑且记录一下
安装
pip install timm
作为特征提取
import torch
import timm
m = timm.create_model('resnet50', pretrained=True, num_classes=0)
o = m(torch.randn(2, 3, 224, 224))
print(f'Pooled shape: {o.shape}')
VIT
import timm
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.eval()