【pytorch】网络模型的保存与读取

Pytorch会把模型相关信息保存为一个字典结构的数据,以用于继续训练或者推理。

1. 模型保存

常见的模型保存方式(PyTorch的模型一般以.pt或者.pth文件格式保存。)

# 保存方式1,模型结构+参数
torch.save(model,"xxx.pth") #"xxx.pth"可为PATH

# 保存方式2,模型参数(官方推荐)
torch.save(model.state_dict(),"xxx.pth") 
#把model的状态保存成**字典**格式,只保存网络模型参数,对比较大的模型占用空间小

除了模型参数之外,torch还可以保存其他训练相关参数,例如学习率、优化器信息等如:

torch.save({'model':model.state_dict(),
			 'optimizer':optimizer.state_dict(),
			 'epoch':epoch_num
			 "global_step": step
			 },'xxx.pth') #'xxx.pth'也可以是PATH

也可以将字典单拎出来分开

state = {'model':model.state_dict(),
		'optimizer':optimizer.state_dict(),
		'epoch':epoch_num
		"global_step": step
			 }
torch.save(state, 'xxx.pth')
#或者
state_path = "./xx/xxx.pth"
torch.save(state, state_path)

2.模型加载

基本的加载方式如下:

# 加载单个数据字典
model.load_state_dict(torch.load("xxx.pth")) #"xxx.pth"可为PATH

加载用于推理的常规Checkpoint常包含其他训练相关参数,例如学习率、优化器信息等如:

#若保存方式为:
torch.save({'model':model.state_dict(),
			 'optimizer':optimizer.state_dict(),
			 'epoch':epoch_num
			 "global_step": step
			 }, checkpoint_path) #checkpoint_path可为文件路径如:xx/xxx.pth
##加载方式可以为:
checkpoint = torch.load(checkpoint_path)

model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']
global_step = checkpoint["global_step"]

此外,load提供了很多重载的功能,其可以把在GPU上训练的权重加载到CPU上跑:

torch.load('tensors.pt')
# 强制所有GPU张量加载到CPU中
torch.load('tensors.pt', map_location=lambda storage, loc: storage)  #或者model.load_state_dict(torch.load('model.pth', map_location='cpu'))
# 把所有的张量加载到GPU 1中
torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# 把张量从GPU 1 移动到 GPU 0
torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

参考:pytorch模型的保存和加载、checkpoint

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值