timm 和 torchvision 中的 resnet50

该博客展示了如何使用timm和torchvision库分别加载预训练的ResNet50模型,并将它们导出为ONNX格式。通过比较导出的ONNX模型,作者发现权重是相同的。这为在不同框架间转移模型提供了一种验证方法。
部署运行你感兴趣的模型镜像

从 timm 和 torchvision 分别加载 resnet50 预训练模型,

import torch
def export_onnx(model_saved, onnx_save_name, input_name='img', output_name='logits'):
    dummy_input = torch.randn(1, 3, 224, 224)
    dynamic_axes = dict()
    dynamic_axes[input_name] = {0:"batch_size"}
    dynamic_axes[output_name] = {0:"batch_size"}
    torch.onnx.export(model_saved, dummy_input, onnx_save_name,
        input_names=[input_name], output_names=[output_name],
        export_params=True, verbose=False, opset_version=12,
        dynamic_axes=dynamic_axes)
    
if __name__ == '__main__':
    import torchvision
    net = torchvision.models.resnet50(pretrained=True)
    export_onnx(net, './resnet50_torchvision.onnx')
    
    import timm
    net = timm.create_model('resnet50', pretrained=True)
    export_onnx(net, './resnet50_timm.onnx')

从 onnx 看,权重是一样的。

 

您可能感兴趣的与本文相关的镜像

PyTorch 2.7

PyTorch 2.7

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

### 使用 `timm` 下载数据集的方法 `timm` 是一个专注于计算机视觉的高效工具库,虽然其主要功能集中在模型加载训练上[^3],但它也提供了一些用于处理数据集的实用接口。以下是关于如何使用 `timm` 加载或下载数据集的具体说明。 #### 数据集加载方式 `timm` 提供了 `IterableImageDataset` 类来支持大规模图像数据集的加载[^1]。需要注意的是,该类并不像传统的 PyTorch 数据集那样支持随机访问(即不支持通过索引获取样本)。如果尝试对其实例化对象执行类似 `dataset[0]` 的操作,则会抛出错误。 要正确使用 `IterableImageDataset`,可以通过迭代器的方式逐批读取数据: ```python from timm.data import IterableImageDataset from timm.data.parsers.parser_image_folder import ParserImageFolder root = './path_to_your_dataset' # 替换为实际的数据集路径 parser = ParserImageFolder(root) dataset = IterableImageDataset(root=root, parser=parser) for image, label in dataset: print(image.shape, label) ``` 上述代码展示了如何利用 `IterableImageDataset` 解析器 (`ParserImageFolder`) 来遍历文件夹中的图像数据集。 #### 预训练权重存储位置 当使用 `timm` 或其他 PyTorch 工具时,预训练模型的权重通常会被缓存在本地目录中。这些缓存的位置因操作系统不同而有所差异[^2]: - **Windows**: `C:\Users\[用户名]\.cache\torch\hub\checkpoints` - **Linux**: `~/.cache/torch/hub/checkpoints` - **macOS**: `~/Library/Caches/torch/hub/checkpoints` 尽管 `timm` 主要是用来加载预训练模型而非直接管理数据集下载,但了解这一机制有助于理解整个工作流。 #### 结合 ImageNet 进行推理的例子 下面是一个完整的例子,展示如何结合 `timm` 中的 ResNet-18 模型以及自定义数据集完成图像分类任务[^4]: ```python import torch from torchvision import transforms from PIL import Image from timm.models import create_model from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD # 创建模型并加载预训练参数 model = create_model('resnet18', pretrained=True) model.eval() # 定义转换函数 transform = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]) # 打开测试图片 image_path = 'test.jpg' with open(image_path, 'rb') as f: img = Image.open(f).convert('RGB') # 转换输入张量 input_tensor = transform(img).unsqueeze(0) # 添加批次维度 # 推理过程 with torch.no_grad(): output = model(input_tensor) probabilities = torch.nn.functional.softmax(output[0], dim=0) top5_prob, top5_catid = torch.topk(probabilities, 5) print(top5_prob, top5_catid) ``` 此脚本实现了从加载模型到预测 Top-5 标签的过程。 ---
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值