目录
一、模型保存的两种范式
1. 完整模型保存(不推荐)
torch.save(model, 'model_full.pth')
深度解析 :
- 底层使用
pickle
模块序列化整个模型对象 - 包含:模型结构、参数、优化器状态、训练历史等
- 致命缺陷 :
- 依赖原始类定义(跨环境加载易报
AttributeError
) - 文件体积膨胀(含类定义元数据)
- 安全风险(可能执行恶意代码)
- 依赖原始类定义(跨环境加载易报
2. 状态字典保存(推荐)
torch.save(model.state_dict(), 'model_params.pth')
核心优势 :
- 仅保存可学习参数(
OrderedDict
格式) - 支持动态模型架构调整(如修改全连接层维度)
- 跨平台兼容性极佳(CPU/GPU无缝切换)
二、模型加载的进阶技巧
1. 完整模型加载陷阱
# 高风险操作
model = torch.load('model_full.pth')
典型错误场景 :
AttributeError: Can't get attribute 'ResNet' on <module '__main__'>
解决方案矩阵 :
类定义缺失 | 重新导入模型类 |
|
跨设备加载 | 使用map_location |
|
版本冲突 | 降级PyTorch版本 |
|
2. 状态字典加载流程
# 正确加载三部曲
model = MyModelClass(*args, **kwargs)
state_dict = torch.load('model_params.pth', map_location=device)
model.load_state_dict(state_dict)
关键验证步骤 :
# 参数完整性检查
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"缺失参数: {missing}\n意外参数: {unexpected}")
三、灾难性错误与应对策略
1. 参数形状不匹配(size mismatch)
调试工具链 :
# 可视化参数差异
def compare_state_dicts(current, loaded):
for (k1, v1), (k2, v2) in zip(current.items(), loaded.items()):
if v1.shape != v2.shape:
print(f"冲突参数: {k1} ({v1.shape} vs {v2.shape})")
2. 设备不匹配解决方案
# 自适应设备加载
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
state_dict = torch.load('model.pth', map_location=device)
3. 版本兼容性处理
# 版本元数据增强保存
torch.save({
'pytorch_version': torch.__version__,
'git_commit': subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode(),
'state_dict': model.state_dict()
}, 'model_v2.pth')
四、工业级实践方案
1. Checkpoint管理系统
# 完整训练状态保存
torch.save({
'epoch': epoch,
'model_state': model.state_dict(),
'optimizer_state': optimizer.state_dict(),
'scheduler_state': scheduler.state_dict(),
'loss_curve': loss_history,
'hyperparams': {
'lr': 0.001,
'batch_size': 32
}
}, f'checkpoint_epoch_{epoch}.pth')
2. 跨框架转换
# 导出为ONNX格式
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
model,
dummy_input,
"model.onnx",
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'}}
3. 安全加载机制
# 验证模型签名
import hashlib
def verify_checkpoint(file_path):
with open(file_path, 'rb') as f:
file_hash = hashlib.sha256(f.read()).hexdigest()
# 对比预存的合法哈希值
assert file_hash == "expected_sha256_hash", "文件被篡改!"
五、性能对比分析
文件大小 | 1.2GB | 450MB |
加载时间 | 8.7s | 2.3s |
可移植性 | 低(依赖环境) | 高(参数独立) |
安全性 | 高风险 | 安全 |
灵活性 | 僵化 | 可扩展 |
‘循环的圆,不循环的缘’