使用本地文件创建resnet50模型

使用本地文件创建resnet50模型

1. 问题分析

最近用到目标检测的模型DETR,但是在创建模型的时候却遇到模型无法创建的问题。本文记录一下解决该问题的过程。

检查原因发现是在创建模型的过程中,需要联网下载。

即便是我将facebook/detr-resnet-50的所有文件下载到本地,然后在from_pretrained时候指定本地的路径,仍然遇到了连接hf下载模型的问题。由于一些众所周知的原因,hf无法直接访问了,这就导致下载遇到了点问题。

需要下载的文件是resnet50的backbone,是没有被包含在detr的模型文件中的。
transformers中的modeling_detr.py中并没有给resnet的本地文件留入参,这就带来了很多不便。即便如此,我们还是可以在modeling_detr.py手动创建backbone,以避免联网下载。

2. 问题解决

首先还是需要先下载resnet50的权重(需科学上网):
https://huggingface.co/timm/resnet50.a1_in1k/tree/main

将这些文件放在目录(记作path_a)中。
然后修改transfomers模块中的源码transformers/models/detr/modeling_detr.py
大约340行:

    def __init__(self, config):
        super().__init__()

        self.config = config

        if config
### ResNet50 权重文件 (.pth) 下载方法 PyTorch 提供了官方支持的预训练权重文件,可以通过 `torch.hub` 或者直接访问 PyTorch 的模型仓库来下载 ResNet50 的 `.pth` 文件。以下是具体实现方式: #### 方法一:通过 `torch.hub` 加载并保存权重 可以利用 `torch.hub.load` 函数加载预训练的 ResNet50 模型,并将其权重保存为本地 `.pth` 文件。 ```python import torch # 使用 torch.hub 加载预训练的 ResNet50 模型 model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True) # 获取模型的状态字典 (state_dict),即权重参数 state_dict = model.state_dict() # 将状态字典保存为 .pth 文件 torch.save(state_dict, './resnet50_pretrained.pth') ``` 上述代码会将 ResNet50 的预训练权重保存到当前目录下的 `resnet50_pretrained.pth` 文件中[^1]。 --- #### 方法二:手动指定存储路径 如果已经知道 `.pth` 文件的具体存储位置,则可以直接读取该文件。例如,在某些情况下,`.pth` 文件可能已经被缓存到了默认路径下。对于 PyTorch,默认的缓存路径通常是 `$TORCH_HOME/hub/checkpoints/`,其中 `$TORCH_HOME` 默认指向用户主目录下的 `.cache/torch` 文件夹。 因此,ResNet50 的预训练权重文件通常位于以下路径之一: - Linux/MacOS: `/home/<username>/.cache/torch/hub/checkpoints/resnet50-<hash>.pth` - Windows: `C:\Users\<Username>\.cache\torch\hub\checkpoints\resnet50-<hash>.pth` 如果没有找到对应的文件,可以通过运行一次 `torch.hub.load` 自动下载并缓存权重文件[^2]。 --- #### 方法三:直接从官网下载 除了通过 Python 脚本自动下载外,还可以直接从 PyTorch 官方网站或其他可信资源下载 `.pth` 文件。例如,ResNet50 的预训练权重可以从以下链接获取(需确认最新版本号): - [https://download.pytorch.org/models/resnet50-19c8e357.pth](https://download.pytorch.org/models/resnet50-19c8e357.pth)[^3] 下载完成后,可使用如下代码加载权重文件: ```python import torch # 加载已下载的 .pth 文件 path_to_pth_file = './resnet50-19c8e357.pth' weights = torch.load(path_to_pth_file) # 创建 ResNet50 模型实例 from torchvision.models import resnet50 model = resnet50() # 加载权重至模型 model.load_state_dict(weights) ``` --- ### 关于 `.pth` 文件的内容解析 ResNet50 的 `.pth` 文件是一个包含键值对的字典对象,记录了模型各层的参数名称及其对应张量数据。例如,执行以下代码可以查看其基本结构和大小: ```python parameter = torch.load('./resnet50-19c8e357.pth') print(type(parameter)) # 输出:<class 'dict'> print(len(parameter)) # 输出:267 表明有 267 组键值对 ``` 这表明 `.pth` 文件中共包含了 267 个参数项,涵盖了所有卷积核、偏置向量以及批量归一化的统计信息等。 --- ### 总结 无论是通过脚本自动化下载还是手动获取,都可以轻松获得 ResNet50 的预训练权重文件。推荐优先采用 `torch.hub` 方式以简化操作流程,同时确保使用的权重来自官方渠道,具备较高的可靠性。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值