pytorch神经网络模型的保存与读取
#model_save
vgg16 = torchvision.models.vgg16(pretrained = False)
#保存方式1
torch.save(vgg16, "vgg16_method1.pth") #保存模型和参数
#方式2, 官方推荐,更小
torch.save(vgg16.state_dict(),"vgg16_method2.pth") #参数保存为字典,模型不再保存
#陷阱
class Model_test(nn,Module):
def __init__():
super(Model_test,self).__init__()
self.conv1 = nn.Conv2d(3,64,kernel_size = 3)
def forward(self, x):
x = self.conv1(x)
return x
#model_load
#方式1,加载是模型
vgg16 = torch.load("vgg16_method1.pth")
print(model)
#方式2,加载模型
vgg16 = torchvision.models.vgg16(pretrained = False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16 )
#陷阱(需要copy过来)
class Model_test(nn,Module):
def __init__():
super(Model_test,self).__init__()
self.conv1 = nn.Conv2d(3,64,kernel_size = 3)
def forward(self, x):
x = self.conv1(x)
return x
model = torch.load(model_test_method1.pth)