Pytorch提供了两种模型的保存和加载方法。
一、首先定义一个名为Example的模型
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class Example(nn.Module):
def __init__(self):
super(Example, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 初始化模型
model_init = Examle()
保存的文件后缀一般使用.pt
或.pth
,将保存后的模型使用到测试模式时,需要使用model.eval()
,它表示将drop
和batch nromalization
层设置为测试模式。训练时候,需要通过设置mode.train()
转化为训练模式。
1. 第一种是只保存模型的参数。
使用这种方法保存模型,当测试时候我们需要自己导入模型的结构信息。
#保存模型参数
torch.save(model.state_dict(), PATH)
例:将上面定义的模型model_init
保存在文件夹ckp
,以model.pth
形式保存。
torch.save(model_init.state_dict(),'ckp/model.pth')
当测试时候,我们需要加载保存的模型,来进行测试。
通用方法:
model = ModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
加载我们刚刚保存的模型进行测试,当然我们需要使用定义的网络结构model_init
:
model_init .load_state_dict(torch.load('ckp/model.pth'))
2. 第二种是保存模型的结构和模型的参数
保存模型
torch.save(model, PATH)
例如:保存定义好的model_init
模型
torch.save(model_init , 'ckp/model.pth')
加载模型
model = torch.load(PATH)
model.eval()
例如:加载保存好的model_init
模型
model = torch.load('ckp/model.pth')
model.eval()
当然,第二种模型文件占用的内存要大于第一种方法。