Pytorch保存和加载模型的两种方式

本文介绍了Pytorch中两种常见的模型保存方式:只保存模型参数和保存完整模型,并对比了两种方式的优缺点,适用于复杂模型的存储和复用。

与Tensorflow、Keras等框架一样,Pytorch也提供了两种保存模型的方式,这两种方式都是通过调用pickle序列化方法实现的:

1、只保存模型参数

2、保存完整模型

下面我们依次对这两种方式进行实现,以以下多层感知机模型为例:

def create_net():
    net = nn.Sequential()
    net.add_module('linear1', nn.Linear(15, 20))
    net.add_module('relu1', nn.ReLU())
    net.add_module('linear2', nn.Linear(20, 15))
    net.add_module('relu2', nn.ReLU())
    net.add_module('linear3', nn.Linear(15, 1))
    net.add_module('sigmoid', nn.Sigmoid())
    return net


net = create_net()

模型的基本信息如下:

 

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                   [-1, 20]             320
              ReLU-2                   [-1, 20]               0
            Linear-3                   [-1, 15]             315
              ReLU-4                   [-1, 15]               0
            Linear-5                    [-1, 1]              16
           Sigmoid-6                    [-1, 1]               0
================================================================
Total params: 651
Trainable params: 651
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.000057
Forward/backward pass size (MB): 0.000549
Params size (MB): 0.002483
Estimated Total Size (MB): 0.003090
----------------------------------------------------------------
None

1、只保存模型参数

保存模型参数:

torch.save(net.state_dict(), 'net_parameter.pkl')

 加载模型参数:

# 定义模型结构
net_clone = create_net()
# 加载模型参数
net_clone.load_state_dict(torch.load('net_parameter.pkl'))

 而后我们就可以微调模型或调用模型进行预测。

2、保存完整模型

保存完整模型指既保存模型结构又保存模型参数:

torch.save(net, 'net_model.pkl')

加载模型:

net_loaded = torch.load('net_model.pkl')

从上图我们可以看到,只保存模型参数的文件大小为4kb,而保存整个模型的文件大小为12kb,对于一些复杂的模型而言,保存整个模型可能需要占用大量的存储空间。所以推荐保存模型参数,在实际预测时再重新定义模型结构即可。

<think>好的,我现在需要回答用户关于PyTorch保存加载模型的方法及最佳实践的问题。首先,我应该回顾用户提供的引用内容,确保答案准确且符合引用中的信息。 用户提供了五个引用,其中引用1提到保存整个模型保存参数的优缺点,引用23分别展示了保存state_dict整个模型的代码示例。引用4强调保存state_dict是最佳实践,而引用5则展示了实际训练测试中加载模型的例子。 首先,我需要整理PyTorch保存模型两种主要方法:保存整个模型保存参数。然后比较它们的优缺点,引用14都提到保存参数更优,因为文件更小且更灵活。接下来,需要给出具体的代码示例,这部分可以参考引用2、35中的代码。例如,引用2展示了保存state_dict的代码,引用3保存整个模型,引用5则演示了加载模型进行预测的流程。 另外,需要注意最佳实践,比如在加载模型时可能需要先重建模型结构,再加载参数,这点在引用45中都有涉及。同时,用户可能关心如何继续训练或部署模型,需要提到这些场景下的注意事项,如确保模型结构参数一致,使用eval模式等。 还要检查是否有遗漏的信息,例如模型转换到设备(如GPU)的问题,这在引用5的代码中提到了将模型移动到设备上。此外,保存优化器状态训练状态对于继续训练很重要,但用户的问题可能更关注基础方法,所以需要简要提及高级用法。 最后,生成相关问题时要围绕保存加载方法、最佳实践、选择依据以及实际应用中的问题,确保问题有针对性,帮助用户深入理解。 总结步骤: 1. 介绍两种保存方法及其区别。 2. 分别给出代码示例,引用对应的引用来源。 3. 讨论最佳实践,如优先使用state_dict。 4. 提及继续训练部署时的注意事项。 5. 生成相关问题,覆盖方法对比、选择原因、优化实际应用。</think>### PyTorch模型保存加载方法及最佳实践 #### 一、两种核心保存方法 1. **保存整个模型** 使用`torch.save(model, 'model.pth')`直接序列化模型对象,包含网络结构参数[^3]。 **优点**:代码简单,加载时无需重建模型结构。 **缺点**:文件体积大,依赖原始类定义,移植性差[^1]。 2. **仅保存模型参数** 通过`model.state_dict()`获取参数字典后保存: ```python torch.save(model.state_dict(), 'model_weights.pth') # 引用[2] ``` **优点**:文件小巧(仅参数),移植性强(可跨代码版本使用)[^1][^4]。 **缺点**:加载时需要先实例化模型结构。 #### 二、加载模型的对应方法 | 保存方式 | 加载代码 | 适用场景 | |-----------------|--------------------------------------------------------------------------|--------------------| | 完整模型 | `model = torch.load('model.pth')` | 快速实验 | | 仅参数 | `model.load_state_dict(torch.load('model_weights.pth'))` | 生产部署、迁移学习 | #### 三、最佳实践指南 1. **优先选择`state_dict`方式** 官方推荐保存参数而非完整模型,因其更灵活且避免代码耦合[^4]。例如: ```python # 保存 torch.save({ 'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict() # 可同时保存优化器状态 }, 'checkpoint.pth') # 加载 checkpoint = torch.load('checkpoint.pth') model.load_state_dict(checkpoint['model_state']) optimizer.load_state_dict(checkpoint['optimizer_state']) ``` 2. **设备兼容性处理** 加载时指定`map_location`参数确保跨设备兼容: ```python model.load_state_dict(torch.load('model.pth', map_location='cuda:0')) ``` 3. **部署时的注意事项** - 加载后调用`model.eval()`关闭dropout等训练模式层 - 使用`torch.jit.trace`或`torch.jit.script`转换为TorchScript增强部署兼容性 #### 四、典型错误规避 1. **结构不匹配** ```python # 错误:加载参数前未正确初始化模型 new_model = SimpleModel(input_size=20) # 结构改变导致参数维度不匹配 new_model.load_state_dict(torch.load('old_model_weights.pth')) # 报错 ``` 2. **未处理设备映射** 在CPU训练的模型直接加载到GPU环境可能导致张量设备不一致。 #### 五、性能优化技巧 1. **半精度存储** ```python torch.save(model.half().state_dict(), 'model_fp16.pth') # 减少50%存储空间 ``` 2. **模型剪枝后保存** 通过移除冗余参数提升推理速度,剪枝后的参数需重新保存。 相关问题
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值