pytorch模型保存与加载

存储模型的文件格式:

  • pkl
  • pth

两者并无较大区别

模型保存:

# 函数原型:
torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

# 保存方式1,保存模型的结构和参数
torch.save(model,'model.pkl')

# 保存方式2,只保存模型的参数
torch.save(model.state_dict(),'model.pkl')

(主要用前两个参数)

  • obj:要保存的对象。
  • f:保存对象的文件路径或文件对象。
  • pickle_module:用于序列化对象的 pickle 模块,默认为 Python 的 pickle 模块。
  • pickle_protocol:pickle 协议的版本,默认为默认协议。
  • _use_new_zipfile_serialization:是否使用新的 zipfile 序列化,默认为 True。

模型读取:

torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

  • f:要加载的文件路径或文件对象。
  • map_location:指定如何重新映射存储位置。可以是字符串(如 ‘cpu’ 或 ‘cuda:0’),也可以是一个函数。
  • pickle_module:用于反序列化对象的 pickle 模块,默认为 Python 的 pickle 模块。
  • **pickle_load_args:传递给 pickle_module.load() 的其他关键字参数。
# 读取方式1,对应保存方式1,读取模型的结构和参数
model = torch.load('model.pkl')

# 读取方式2,对应保存方式2,预先设定结构,然后读取参数
model = torchvision.models.vgg16(pretrained=False)
model.load_state_dict(torch.load("model.pkl"))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值