网络模型的保存
import torch # 导入torch库
import torchvision # 导入torchvision库
from torch import nn # 导入torch中的nn模块(神经网络模块)
# 创建一个不经过预训练的VGG16模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,将模型结构和模型参数一起保存为一个文件
torch.save(vgg16, "vgg16_method1.pth")
# 保存方式2,只保存模型参数(官方推荐的方式)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
# 创建一个自定义的模型Yang
class Yang(nn.Module):
def __init__(self):
super(Yang, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return x
# 创建一个Yang对象
yang = Yang()
# 将Yang模型保存为文件
torch.save(yang, "yang_method1.pth")
以上代码演示了如何在PyTorch中保存模型的两种方式。
详细解释如下:
- 首先,我们导入了torch和torchvision库,并从torch中导入nn模块。
- 接着,我们使用`torchvision.models.vgg16(pretrained=False)`创建了一个不经过预训练的VGG16模型,并将其赋值给变量`vgg16`。
- 然后,我们使用`torch.save(vgg16, "vgg16_method1.pth")`将VGG16模型以保存方式1保存为一个文件,其中包含了模型的结构和参数。
- 接着,我们使用`torch.save(vgg16.state_dict(), "vgg16_method2.pth")`将VGG16模型的参数以保存方式2保存为一个文件,这是官方推荐的保存模型的方式。
- 然后,我们创建了一个自定义的模型`Yang`,其中包含了一个卷积层`conv1`。
- 在接下来的代码中,我们创建了一个`Yang`对象,并使用`torch.save(yang, "yang_method1.pth")`将其保存为文件。然而,需要注意的是,直接使用`torch.save`保存自定义模型时会陷入一个陷阱,这种方式保存的模型文件通常会比较大,因为它会保存整个模型的结构和参数,而不仅仅是参数。
请注意,保存模型时需要注意选择适当的保存方式,根据实际情况选择保存模型的方式,保存模型的方式1包括模型的结构和参数,而方式2只保存了模型的参数。同时,为了避免陷阱,最好只保存自定义模型的参数,而不是整个模型。
网络模型的读取
这段代码是一个PyTorch程序,主要涉及加载预训练的VGG16模型以及另外一个自定义模型的加载。下面是对代码的注释解释:
import torch
from model_save import * # 导入自定义模型保存相关的代码
import torchvision
from torch import nn
# 从文件加载方式1保存的模型
model = torch.load("vgg16_method1.pth")
# 从文件加载方式2保存的模型
vgg16 = torchvision.models.vgg16(pretrained=False) # 使用预训练参数初始化VGG16模型
vgg16.load_state_dict(torch.load("vgg16_method2.pth")) # 加载自定义保存的模型参数
# 陷阱1 - 注释掉的部分
# 定义了一个自定义的神经网络模型类 Yang
# class Yang(nn.Module):
# def __init__(self):
# super(Yang, self).__init__()
# self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
#
# def forward(self, x):
# x = self.conv1(x)
# return x
# 从文件加载方式1保存的 Yang 模型
model = torch.load('yang_method1.pth')
print(model) # 打印加载的模型
主要内容包括:
1. 代码开始导入了必要的库和模块。
2. 方式1:使用 `torch.load` 函数从文件加载以方式1保存的模型("vgg16_method1.pth")。这里的 `model` 是预训练的VGG16模型。
3. 方式2:首先创建了一个预训练的VGG16模型,然后使用 `load_state_dict` 方法从文件加载以方式2保存的模型参数("vgg16_method2.pth")。
4. 注释掉的部分:包含了一个自定义的神经网络模型类 `Yang`,但是在这段代码中未使用。
5. 从文件加载了以方式1保存的自定义 `Yang` 模型("yang_method1.pth")。
6. 最后,打印加载的模型。
需要注意的是,模型的保存和加载方式可以影响到加载后的模型状态。在方式1中,整个模型对象被保存,因此可以直接加载,而方式2中,只保存了模型的参数,需要预先创建模型结构并加载参数。在加载模型时,确保保存和加载方式一致以防止错误。