一、模型的保存
1、方法一:torch.save
保存了模型的结构加参数
import torch
from torch import nn
import torchvision
vgg16=torchvision.models.vgg16(weights=None)
#保存方式1 保存状态+结构
torch.save(vgg16,"vgg16_method1.path")
可以看到同目录下的模型文件:
二、模型读取
1、方法一、torch.load
import torch
#方式一对应于保存方式一:
model1=torch.load("vgg16_method1.path")
print(model1)
一、模型的保存
2、保存方式2
以字典形式保存了模型的状态(参数)
#保存方式2 保存状态(官方推荐)
torch.save(vgg16.state_dict(),"vgg16_method2.path")
从终端打开保存文件所在的文件夹
输入 dir
(DL) D:\PyCharm\Py_Projects\XiaoTuDui>dir
驱动器 D 中的卷是 academic
卷的序列号是 8621-1046
D:\PyCharm\Py_Projects\XiaoTuDui 的目录
2023/09/18 08:58 <DIR> .
2023/09/18 08:57 <DIR> ..
2023/09/08 14:52 <DIR> .pytest_cache
2023/09/07 09:44 <DIR> datasetP7
……//
2023/09/18 08:56 553,450,705 vgg16_method1.path
2023/09/18 08:56 553,441,041 vgg16_method2.path
29 个文件 1,106,909,735 字节
22 个目录 50,462,588,928 可用字节
可以看到vgg16_method1要比vgg16_method2
二、模型读取
2、方式二、读取方式2
#方式2,加载模型
model2=torch.load("vgg16_method2.path")
print(model2)
可以发现数据是字典形式
转换成模型结构:
#方式2,加载模型
#model2=torch.load("vgg16_method2.pth")
vgg16=torchvision.models.vgg16(weights=None)
vgg16.load_state_dict("vgg16_method2.pth")
print(vgg16)
三、陷阱
我们写一个神经网络并保存:
#陷阱
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.conv1=nn.Conv2d(3,64,kernel_size=3)
def forward(self,x):
x=self.conv1(x)
return x
tudui=Tudui()
torch.save(tudui,"tudui_method1.pth")
四、陷阱加载
#陷阱
model3=torch.load("tudui_method1.pth")
print(model3)
报错:不能得到Tudui属性,因为没有这个类
AttributeError: Can't get attribute 'Tudui' on <module '__main__' from 'D:\\PyCharm\\Py_Projects\\XiaoTuDui\\P27 模型读取.p
还是需要去把模型复制进来
只是不需要继承tudui=Tudui()
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.conv1=nn.Conv2d(3,64,kernel_size=3)
def forward(self,x):
x=self.conv1(x)
return x
# tudui=Tudui() //这里不需要写//
model3=torch.load("tudui_method1.pth")
print(model3)
成功!
五、改进
import torch
import torchvision
from torch import nn
from P27 模型的保存 import *
# #方式一对应于保存方式一:
# model1=torch.load("vgg16_method1.path")
# print(model1)
#
# #方式2,加载模型
# #model2=torch.load("vgg16_method2.pth")
# vgg16=torchvision.models.vgg16(weights=None)
# vgg16.load_state_dict("vgg16_method2.pth")
# print(vgg16)
#陷阱
# class Tudui(nn.Module):
# def __init__(self):
# super(Tudui, self).__init__()
# self.conv1=nn.Conv2d(3,64,kernel_size=3)
#
# def forward(self,x):
# x=self.conv1(x)
# return x
# # tudui=Tudui() //这里不需要写//
model3=torch.load("tudui_method1.pth")
print(model3)
保存模型.py的所有代码;
import torch
from torch import nn
import torchvision
vgg16=torchvision.models.vgg16(weights=None)
#保存方式1 保存状态+结构
torch.save(vgg16,"vgg16_method1.path")
#保存方式2 保存状态(官方推荐)
torch.save(vgg16.state_dict(),"vgg16_method2.path")
#陷阱
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.conv1=nn.Conv2d(3,64,kernel_size=3)
def forward(self,x):
x=self.conv1(x)
return x
tudui=Tudui()
torch.save(tudui,"tudui_method1.pth")
读取模型.py的所有代码;
import torch
import torchvision
from torch import nn
from P27 模型的保存 import *
# #方式一对应于保存方式一:
# model1=torch.load("vgg16_method1.path")
# print(model1)
#
# #方式2,加载模型
# #model2=torch.load("vgg16_method2.pth")
# vgg16=torchvision.models.vgg16(weights=None)
# vgg16.load_state_dict("vgg16_method2.pth")
# print(vgg16)
#陷阱
# class Tudui(nn.Module):
# def __init__(self):
# super(Tudui, self).__init__()
# self.conv1=nn.Conv2d(3,64,kernel_size=3)
#
# def forward(self,x):
# x=self.conv1(x)
# return x
# # tudui=Tudui() //这里不需要写//
model3=torch.load("tudui_method1.pth")
print(model3)