pytorch保存和加载模型的两种方式

pytorch中保存和加载模型是绑在一起的。
这里我需要注意一下不同的保存方式对应不同的读取方式,两者各有利弊。

首先说说pytorch.save()这个函数,可以参考官网:pytroch.save
简而言之,这个函数可以保存任意的东西,比如tensor或者模型,或者仅仅是模型的参数。
如果将保存对象局限在模型上,通常来说我们有两种方式直接保存所有的模型,只保存模型中的参数(模型结构就保存了)。以下分别说说两种不同的方式。

为了说明,我们先建立一个简单的模型。

import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, in_c, out_c, ngf=64):
        super(Generator, self).__init__()
        model = []
        model += [
            nn.Conv2d(in_c, ngf, 3, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(ngf),
            nn.Conv2d(ngf, out_c, 3, 2, 1)
        ]
        self.model = nn.Sequential(*model)
        
    def forward(self, x):
        return self.model(x)

netG = Generator(3, 3)
input = torch.zeros(10, 3, 256, 256)
output = netG(input)

直接保存所有模型并读取

直接使用简单粗暴的方式保存:

torch.save(netG, 'netG.pt')

对应的,我们可以这样读取模型

netC = torch.load('netG.pt')
input = torch.zeros(10, 3, 256, 256)
output = netC(input)

正常情况如下(警告先忽略):在这里插入图片描述

只保存模型中的参数并读取

我们说模型的参数保存在网络的state_dict中,使用这个就可以读取网络的参数了。

torch.save({'netG': netG.state_dict()}, 'model_test.pt')

对应的加载模型的方式如下:

netD = Generator(3, 3)
state_dict = torch.load('model_test.pt')
netD.load_state_dict(state_dict['netG'])
input = torch.zeros(10, 3, 256, 256)
output = netD(input)

总结

我们可以看到第一种方法可以直接保存模型,加载模型的时候直接把读取的模型给一个参数就行。而第二种方法则只是保存参数,在读取模型参数前要先定义一个模型(模型必须与原模型相同的构造),然后对这个模型导入参数。虽然麻烦,但是可以同时保存多个模型的参数,而第一种方法则不能,而且第一种方法有时不能保证模型的相同性(你读取的模型并不是你想要的)。

总的来说,我们一般来选择第二种来保存和读取
退一步讲,如何保存模型决定了如何读取模型

### Flink 大数据处理优化技巧与最佳实践 #### 调优原则与方法概述 对于Flink SQL作业中的大状态导致的反压问题,调优的核心在于减少状态大小以及提高状态访问效率。通过合理配置参数和调整逻辑设计可以有效缓解此类瓶颈[^1]。 #### 参数设置建议 针对不同版本下的具体特性差异,在实施任何性能改进措施前应当充分理解当前使用的Flink版本特点及其局限性;同时也要考虑特定应用场景的需求特征来定制化解决方案。这包括但不限于并行度设定、内存分配策略等方面的选择[^2]。 #### 数据流模式优化 采用广播变量机制可作为一种有效的手段用于降低主数据流转过程中所需维护的状态量级。当存在一对多关系的数据集间需频繁交互时,将较小规模的一方作为广播状态保存下来供另一方查询匹配使用不失为明智之举。此方式特别适用于维表Join操作中,其中一方变动相对较少但又必须保持最新记录的情况[^3]。 ```sql -- 创建临时视图以支持后续JOIN操作 CREATE TEMPORARY VIEW dim_table AS SELECT * FROM kafka_source; -- 定义Temporal Table Function以便获取指定时间点上的历史快照 CREATE FUNCTION hist_dim_table AS 'com.example.HistoricalDimTableFunction'; -- 执行带有时态条件约束的JOIN语句 SELECT o.order_id, d.product_name FROM orders o LEFT JOIN LATERAL TABLE(hist_dim_table(o.event_time)) AS d ON o.product_id = d.id; ``` 上述代码片段展示了如何利用Flink SQL实现基于时间戳的历史维度表连接功能,从而确保每次都能准确捕捉到事件发生瞬间对应的最恰当的产品名称信息。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值