PyTorch模型保存与加载完全指南 - 基于pytorchTutorial项目实践

PyTorch模型保存与加载完全指南 - 基于pytorchTutorial项目实践

【免费下载链接】pytorchTutorial PyTorch Tutorials from my YouTube channel 【免费下载链接】pytorchTutorial 项目地址: https://gitcode.com/gh_mirrors/py/pytorchTutorial

前言

在深度学习项目开发过程中,模型保存与加载是至关重要的环节。本文将基于pytorchTutorial项目中的实践案例,全面讲解PyTorch框架下模型保存与加载的各种方法及其适用场景。

模型保存的基本方法

PyTorch提供了两种主要的模型保存方式:

1. 完整模型保存(简易方式)

torch.save(model, PATH)  # 保存整个模型
model = torch.load(PATH)  # 加载整个模型
model.eval()

特点:

  • 代码简单直观
  • 保存了模型结构和参数
  • 需要原始模型类定义在可访问的位置

缺点:

  • 保存文件较大
  • 对模型定义的修改可能导致加载失败
  • 不推荐在生产环境中使用

2. 状态字典保存(推荐方式)

torch.save(model.state_dict(), PATH)  # 只保存状态字典

# 加载时需要重新创建模型结构
model = Model(*args, **kwargs)  # 先创建相同结构的模型
model.load_state_dict(torch.load(PATH))  # 再加载参数
model.eval()

特点:

  • 只保存模型参数,文件较小
  • 更灵活,可以加载到不同结构的模型中
  • 是PyTorch官方推荐的方式

实战代码解析

示例模型定义

class Model(nn.Module):
    def __init__(self, n_input_features):
        super(Model, self).__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

这是一个简单的逻辑回归模型,包含一个线性层和sigmoid激活函数。

完整模型保存与加载

model = Model(n_input_features=6)
FILE = "model.pth"
torch.save(model, FILE)  # 保存完整模型

loaded_model = torch.load(FILE)  # 加载完整模型
loaded_model.eval()

状态字典保存与加载

FILE = "model.pth"
torch.save(model.state_dict(), FILE)  # 只保存状态字典

loaded_model = Model(n_input_features=6)  # 重新创建模型
loaded_model.load_state_dict(torch.load(FILE))  # 加载状态字典
loaded_model.eval()

检查点(Checkpoint)保存

在实际训练过程中,我们通常需要保存训练状态,以便从中断处恢复训练。这包括:

  • 模型参数
  • 优化器状态
  • 当前epoch数
# 创建检查点
checkpoint = {
    "epoch": 90,
    "model_state": model.state_dict(),
    "optim_state": optimizer.state_dict()
}
torch.save(checkpoint, "checkpoint.pth")

# 加载检查点
model = Model(n_input_features=6)
optimizer = torch.optim.SGD(model.parameters(), lr=0)

checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optim_state'])
epoch = checkpoint['epoch']

设备相关注意事项

当模型在不同设备(CPU/GPU)间迁移时,需要特别注意:

1. GPU保存,CPU加载

# 保存
device = torch.device("cuda")
model.to(device)
torch.save(model.state_dict(), PATH)

# 加载
device = torch.device('cpu')
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))

2. GPU保存,GPU加载

device = torch.device("cuda")
model.to(device)
torch.save(model.state_dict(), PATH)

model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

3. CPU保存,GPU加载

torch.save(model.state_dict(), PATH)

device = torch.device("cuda")
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.to(device)

重要实践建议

  1. 模式切换:加载模型后,务必根据用途调用model.eval()(评估模式)或model.train()(训练模式)

    • eval()模式会关闭Dropout和BatchNorm层的训练时行为
    • 忘记切换模式会导致不一致的推理结果
  2. 输入数据设备:当使用GPU模型时,确保输入数据也在GPU上:

    input = input.to(device)
    
  3. 版本兼容性:PyTorch版本差异可能导致模型加载问题,建议在相同版本环境下保存和加载模型

总结

本文详细介绍了PyTorch中模型保存与加载的各种方法,从基本的完整模型保存到推荐的状态字典方式,再到训练检查点的保存与恢复,最后讲解了不同设备间的模型迁移注意事项。掌握这些技术对于深度学习项目的开发和部署至关重要,能够帮助开发者更高效地管理模型训练过程,实现训练中断恢复,以及模型的服务化部署。

【免费下载链接】pytorchTutorial PyTorch Tutorials from my YouTube channel 【免费下载链接】pytorchTutorial 项目地址: https://gitcode.com/gh_mirrors/py/pytorchTutorial

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值