1、模型的保存方式一,直接输出模型的网络结构
陷阱:在用此种保存方法加载模型时,需要把该模型包含的类导入到py文件中,比较麻烦,否则无法加载,或者用from ...import...的方式导入。
import torch
import torchvision
Vgg16=torchvision.models.vgg16(pretrained=False,progress=True)
torch.save(Vgg16,'vgg16_dict.pth')
model=torch.load('vgg16_dict.pth') #模型的加载,输出模型网络结构
print(model)
2、模型保存方式二,把模型的状态保存成字典的形式,状态参数(官方推荐)
import torch
import torchvision
Vgg16=torchvision.models.vgg16(pretrained=False,progress=True)
torch.save(Vgg16.state_dict(),'vgg16_dict.pth')
model=torch.load('vgg16_dict.pth') #加载,全是参数
print(model)
通过第二种方式如何输出模型的结构呢?