pytorch载入模型时出现错误AttributeError Generator object has no attribute copy

本文详细解析了PyTorch中两种模型保存与加载的方法:一是仅保存模型参数,适用于相同模型结构;二是保存整个模型,包括网络结构和参数。通过实例对比,纠正了常见错误操作,并提供了正确的实现方式。
部署运行你感兴趣的模型镜像
错误原因

保存的网络模型,但是读取的时候按照读取参数的方式使用了model.load_state_dict(),如下

model = Generator(3, 3, 64)
model_weight_path = ''
model.load_state_dict(torch.load(mdoel_weight_path))
model.eval
正确方式
model_weight_path = "./generator.pkl"
model = torch.load(model_weight_path)
model.eval()
总结
  • 方法一
    torch.save(net.state_dict(),path)
    功能:保存训练完的网络的各层参数(即weights和bias)
    其中:net.state_dict()获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth)
    net2.load_state_dict(torch.load(path))
    功能:加载保存到path中的各层参数到神经网络
    注意:不可以直接为torch.load_state_dict(path),此函数不能直接接收字符串类型参数

  • 方法二
    torch.save(net,path)
    功能:保存训练完的整个网络模型(不止weights和bias)
    net2=torch.load(path)
    功能:加载保存到path中的整个神经网络

  • 附:基于pytorch的保存和加载模型参数的方法

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论 1
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值