pytorch中保存和加载模型是绑在一起的。
这里我需要注意一下不同的保存方式对应不同的读取方式,两者各有利弊。
首先说说pytorch.save()这个函数,可以参考官网:pytroch.save。
简而言之,这个函数可以保存任意的东西,比如tensor或者模型,或者仅仅是模型的参数。
如果将保存对象局限在模型上,通常来说我们有两种方式:直接保存所有的模型,只保存模型中的参数(模型结构就保存了)。以下分别说说两种不同的方式。
为了说明,我们先建立一个简单的模型。
import torch
import torch.nn as nn
class Generator(nn.Module):
def __init__(self, in_c, out_c, ngf=64):
super(Generator, self).__init__()
model = []
model += [
nn.Conv2d(in_c, ngf, 3, 2, 1),
nn.ReLU(),
nn.BatchNorm2d(ngf),
nn.Conv2d(ngf, out_c, 3, 2, 1)
]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
netG = Generator(3, 3)
input = torch.zeros(10, 3, 256, 256)
output = netG(input)
直接保存所有模型并读取
直接使用简单粗暴的方式保存:
torch.save(netG, 'netG.pt')
对应的,我们可以这样读取模型
netC = torch.load('netG.pt')
input = torch.zeros(10, 3, 256, 256)
output = netC(input)
正常情况如下(警告先忽略):
只保存模型中的参数并读取
我们说模型的参数保存在网络的state_dict中,使用这个就可以读取网络的参数了。
torch.save({'netG': netG.state_dict()}, 'model_test.pt')
对应的加载模型的方式如下:
netD = Generator(3, 3)
state_dict = torch.load('model_test.pt')
netD.load_state_dict(state_dict['netG'])
input = torch.zeros(10, 3, 256, 256)
output = netD(input)
总结
我们可以看到第一种方法可以直接保存模型,加载模型的时候直接把读取的模型给一个参数就行。而第二种方法则只是保存参数,在读取模型参数前要先定义一个模型(模型必须与原模型相同的构造),然后对这个模型导入参数。虽然麻烦,但是可以同时保存多个模型的参数,而第一种方法则不能,而且第一种方法有时不能保证模型的相同性(你读取的模型并不是你想要的)。
总的来说,我们一般来选择第二种来保存和读取。
退一步讲,如何保存模型决定了如何读取模型。