PyTorch模型保存与加载完全指南 - 基于pytorchTutorial项目实践
前言
在深度学习项目开发过程中,模型保存与加载是至关重要的环节。本文将基于pytorchTutorial项目中的实践案例,全面讲解PyTorch框架下模型保存与加载的各种方法及其适用场景。
模型保存的基本方法
PyTorch提供了两种主要的模型保存方式:
1. 完整模型保存(简易方式)
torch.save(model, PATH) # 保存整个模型
model = torch.load(PATH) # 加载整个模型
model.eval()
特点:
- 代码简单直观
- 保存了模型结构和参数
- 需要原始模型类定义在可访问的位置
缺点:
- 保存文件较大
- 对模型定义的修改可能导致加载失败
- 不推荐在生产环境中使用
2. 状态字典保存(推荐方式)
torch.save(model.state_dict(), PATH) # 只保存状态字典
# 加载时需要重新创建模型结构
model = Model(*args, **kwargs) # 先创建相同结构的模型
model.load_state_dict(torch.load(PATH)) # 再加载参数
model.eval()
特点:
- 只保存模型参数,文件较小
- 更灵活,可以加载到不同结构的模型中
- 是PyTorch官方推荐的方式
实战代码解析
示例模型定义
class Model(nn.Module):
def __init__(self, n_input_features):
super(Model, self).__init__()
self.linear = nn.Linear(n_input_features, 1)
def forward(self, x):
y_pred = torch.sigmoid(self.linear(x))
return y_pred
这是一个简单的逻辑回归模型,包含一个线性层和sigmoid激活函数。
完整模型保存与加载
model = Model(n_input_features=6)
FILE = "model.pth"
torch.save(model, FILE) # 保存完整模型
loaded_model = torch.load(FILE) # 加载完整模型
loaded_model.eval()
状态字典保存与加载
FILE = "model.pth"
torch.save(model.state_dict(), FILE) # 只保存状态字典
loaded_model = Model(n_input_features=6) # 重新创建模型
loaded_model.load_state_dict(torch.load(FILE)) # 加载状态字典
loaded_model.eval()
检查点(Checkpoint)保存
在实际训练过程中,我们通常需要保存训练状态,以便从中断处恢复训练。这包括:
- 模型参数
- 优化器状态
- 当前epoch数
# 创建检查点
checkpoint = {
"epoch": 90,
"model_state": model.state_dict(),
"optim_state": optimizer.state_dict()
}
torch.save(checkpoint, "checkpoint.pth")
# 加载检查点
model = Model(n_input_features=6)
optimizer = torch.optim.SGD(model.parameters(), lr=0)
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint['model_state'])
optimizer.load_state_dict(checkpoint['optim_state'])
epoch = checkpoint['epoch']
设备相关注意事项
当模型在不同设备(CPU/GPU)间迁移时,需要特别注意:
1. GPU保存,CPU加载
# 保存
device = torch.device("cuda")
model.to(device)
torch.save(model.state_dict(), PATH)
# 加载
device = torch.device('cpu')
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
2. GPU保存,GPU加载
device = torch.device("cuda")
model.to(device)
torch.save(model.state_dict(), PATH)
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
3. CPU保存,GPU加载
torch.save(model.state_dict(), PATH)
device = torch.device("cuda")
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.to(device)
重要实践建议
-
模式切换:加载模型后,务必根据用途调用
model.eval()(评估模式)或model.train()(训练模式)eval()模式会关闭Dropout和BatchNorm层的训练时行为- 忘记切换模式会导致不一致的推理结果
-
输入数据设备:当使用GPU模型时,确保输入数据也在GPU上:
input = input.to(device) -
版本兼容性:PyTorch版本差异可能导致模型加载问题,建议在相同版本环境下保存和加载模型
总结
本文详细介绍了PyTorch中模型保存与加载的各种方法,从基本的完整模型保存到推荐的状态字典方式,再到训练检查点的保存与恢复,最后讲解了不同设备间的模型迁移注意事项。掌握这些技术对于深度学习项目的开发和部署至关重要,能够帮助开发者更高效地管理模型训练过程,实现训练中断恢复,以及模型的服务化部署。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



