26.模型文件的读写

您已经很好地概述了PyTorch中模型保存和加载的几种方法。下面我将用代码示例来详细解释这些方法,并给出一些建议。

方法一:保存和加载模型参数(推荐)

import torch  
import torch.nn as nn  
  
# 假设我们有一个简单的MLP模型  
class MLP(nn.Module):  
    def __init__(self):  
        super(MLP, self).__init__()  
        self.fc1 = nn.Linear(10, 20)  
        self.fc2 = nn.Linear(20, 1)  
  
    def forward(self, x):  
        x = torch.relu(self.fc1(x))  
        x = self.fc2(x)  
        return x  
  
# 实例化模型  
model = MLP()  
  
# 保存模型参数  
torch.save(model.state_dict(), 'model/MLP_state_dict.pth')  
  
# 加载模型参数  
model_loaded = MLP()  # 需要重新实例化模型结构  
model_loaded.load_state_dict(torch.load('model/MLP_state_dict.pth'))  
  
# 现在model_loaded和之前训练的model有相同的参数

方法二:保存和加载整个模型(不推荐)

# 保存整个模型  
torch.save(model, 'model/MLP_whole_model.pth')  
  
# 加载整个模型  
model_loaded = torch.load('model/MLP_whole_model.pth')  
  
# 注意:这种方法不常用,因为它依赖于模型定义的确切位置,并且不灵活

方法三:保存和加载检查点(Checkpoint,推荐)

检查点通常用于保存训练过程中的状态,包括模型参数、优化器状态、训练步数等。

# 假设我们有一个优化器和一个训练步数  
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  
epoch = 0  
  
# 在某个训练步骤后保存检查点  
checkpoint = {  
    'model': model.state_dict(),  
    'optimizer': optimizer.state_dict(),  
    'epoch': epoch,  
    'loss': some_loss_value,  # 如果需要的话  
}  
torch.save(checkpoint, 'model/checkpoint.pth')  
  
# 加载检查点  
checkpoint = torch.load('model/checkpoint.pth')  
model = MLP()  
model.load_state_dict(checkpoint['model'])  
optimizer.load_state_dict(checkpoint['optimizer'])  
epoch = checkpoint['epoch']  
loss = checkpoint['loss']  
  
# 现在可以继续从上一个检查点开始训练

注意事项

  • 在保存和加载模型时,请确保模型架构在保存和加载时都是一致的。
  • 对于方法三,你可能还需要在代码中恢复任何其他你可能需要的状态(例如损失值、学习率调度器等)。
  • 使用.pt.pth作为模型文件的扩展名是PyTorch的约定,但这不是强制的。
  • 当保存和加载模型时,请确保PyTorch的版本是一致的,或者至少是兼容的,以避免可能的错误。
  • 在生产环境中部署模型时,通常只保存和加载模型参数(方法一),因为这样可以更灵活地与不同的前端或后端集成。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值