1、先前准备:
import torch
import torchvision
from torch import nn
vgg16 = torchvision.models.vgg16(pretrained=False)
2、保存方式1:
# 1、
torch.save(vgg16, 'vgg16_method1.pth') # 不仅保存了模型结构,也保存了参数
结果如下:
这种方式不仅保存了网络模型的结构,也保存了网络模型的参数。
3、加载方式1:
import torch
# from 21、model_save import *
# 方式1 -> 保存方式1 , 加载模型
import torchvision
from torch import nn
model = torch.load('vgg16_method1.pth')
# print(model)
4、保存方式2:
# 2、(官方推荐)
torch.save(vgg16.state_dict(), 'vgg16_method2.pth') # 只保存网络模型的参数
结果如下:
该方法只保存模型参数。
5、加载方式2:
# 方式2 -> 保存方式2 , 加载模型
# 要新建网络结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load('vgg16_method2.pth'))
print(vgg16)
6、方式1的陷阱:
自己定义网络时:
# 陷阱
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.conv1 = nn.Conv2d(3,64,kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return x
tudui = Tudui()
torch.save(tudui, 'tudui_method1.pth')
加载时,需要先导入网络结构:
# 陷阱
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.conv1 = nn.Conv2d(3,64,kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return x
# tudui = Tudui()
model = torch.load('tudui_method1.pth')
print(model)