pytorch载入模型时出现错误AttributeError Generator object has no attribute copy

错误原因

保存的网络模型,但是读取的时候按照读取参数的方式使用了model.load_state_dict(),如下

model = Generator(3, 3, 64)
model_weight_path = ''
model.load_state_dict(torch.load(mdoel_weight_path))
model.eval
正确方式
model_weight_path = "./generator.pkl"
model = torch.load(model_weight_path)
model.eval()
总结
  • 方法一
    torch.save(net.state_dict(),path)
    功能:保存训练完的网络的各层参数(即weights和bias)
    其中:net.state_dict()获取各层参数,path是文件存放路径(通常保存文件格式为.pt或.pth)
    net2.load_state_dict(torch.load(path))
    功能:加载保存到path中的各层参数到神经网络
    注意:不可以直接为torch.load_state_dict(path),此函数不能直接接收字符串类型参数

  • 方法二
    torch.save(net,path)
    功能:保存训练完的整个网络模型(不止weights和bias)
    net2=torch.load(path)
    功能:加载保存到path中的整个神经网络

  • 附:基于pytorch的保存和加载模型参数的方法

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值