小土堆:Pytorch深度学习:网络模型的保存与读取

文章介绍了在PyTorch中使用torch.save和torch.load进行模型结构和参数保存的两种方式,以及如何加载预训练的VGG16模型和自定义模型。强调了保存方式2(仅保存参数)的官方推荐和加载时保持一致性的重要性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

网络模型的保存

import torch # 导入torch库
import torchvision # 导入torchvision库
from torch import nn # 导入torch中的nn模块(神经网络模块)

# 创建一个不经过预训练的VGG16模型
vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存方式1,将模型结构和模型参数一起保存为一个文件
torch.save(vgg16, "vgg16_method1.pth")

# 保存方式2,只保存模型参数(官方推荐的方式)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

# 创建一个自定义的模型Yang
class Yang(nn.Module):
def __init__(self):
super(Yang, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

def forward(self, x):
x = self.conv1(x)
return x

# 创建一个Yang对象
yang = Yang()

# 将Yang模型保存为文件
torch.save(yang, "yang_method1.pth")

以上代码演示了如何在PyTorch中保存模型的两种方式。

详细解释如下:
- 首先,我们导入了torch和torchvision库,并从torch中导入nn模块。
- 接着,我们使用`torchvision.models.vgg16(pretrained=False)`创建了一个不经过预训练的VGG16模型,并将其赋值给变量`vgg16`。
- 然后,我们使用`torch.save(vgg16, "vgg16_method1.pth")`将VGG16模型以保存方式1保存为一个文件,其中包含了模型的结构和参数。
- 接着,我们使用`torch.save(vgg16.state_dict(), "vgg16_method2.pth")`将VGG16模型的参数以保存方式2保存为一个文件,这是官方推荐的保存模型的方式。
- 然后,我们创建了一个自定义的模型`Yang`,其中包含了一个卷积层`conv1`。
- 在接下来的代码中,我们创建了一个`Yang`对象,并使用`torch.save(yang, "yang_method1.pth")`将其保存为文件。然而,需要注意的是,直接使用`torch.save`保存自定义模型时会陷入一个陷阱,这种方式保存的模型文件通常会比较大,因为它会保存整个模型的结构和参数,而不仅仅是参数。

请注意,保存模型时需要注意选择适当的保存方式,根据实际情况选择保存模型的方式,保存模型的方式1包括模型的结构和参数,而方式2只保存了模型的参数。同时,为了避免陷阱,最好只保存自定义模型的参数,而不是整个模型。

网络模型的读取

这段代码是一个PyTorch程序,主要涉及加载预训练的VGG16模型以及另外一个自定义模型的加载。下面是对代码的注释解释:

import torch
from model_save import *  # 导入自定义模型保存相关的代码
import torchvision
from torch import nn

# 从文件加载方式1保存的模型
model = torch.load("vgg16_method1.pth")

# 从文件加载方式2保存的模型
vgg16 = torchvision.models.vgg16(pretrained=False)  # 使用预训练参数初始化VGG16模型
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))  # 加载自定义保存的模型参数

# 陷阱1 - 注释掉的部分
# 定义了一个自定义的神经网络模型类 Yang
# class Yang(nn.Module):
#     def __init__(self):
#         super(Yang, self).__init__()
#         self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
#
#     def forward(self, x):
#         x = self.conv1(x)
#         return x

# 从文件加载方式1保存的 Yang 模型
model = torch.load('yang_method1.pth')

print(model)  # 打印加载的模型

主要内容包括:

1. 代码开始导入了必要的库和模块。

2. 方式1:使用 `torch.load` 函数从文件加载以方式1保存的模型("vgg16_method1.pth")。这里的 `model` 是预训练的VGG16模型。

3. 方式2:首先创建了一个预训练的VGG16模型,然后使用 `load_state_dict` 方法从文件加载以方式2保存的模型参数("vgg16_method2.pth")。

4. 注释掉的部分:包含了一个自定义的神经网络模型类 `Yang`,但是在这段代码中未使用。

5. 从文件加载了以方式1保存的自定义 `Yang` 模型("yang_method1.pth")。

6. 最后,打印加载的模型。

需要注意的是,模型的保存和加载方式可以影响到加载后的模型状态。在方式1中,整个模型对象被保存,因此可以直接加载,而方式2中,只保存了模型的参数,需要预先创建模型结构并加载参数。在加载模型时,确保保存和加载方式一致以防止错误。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值