实践教程 | Pytorch 模型的保存与迁移
在本篇文章中,笔者首先介绍了模型复用的几种典型场景;然后介绍了如何查看Pytorch模型中的相关参数信息;接着介绍了如何载入模型、如何进行追加训练以及进行模型的迁移学习等。
1 引言
各位朋友大家好,欢迎来到月来客栈。今天要和大家介绍的内容是如何在Pytorch框架中对模型进行保存和载入、以及模型的迁移和再训练。一般来说,最常见的场景就是模型完成训练后的推断过程。一个网络模型在完成训练后通常都需要对新样本进行预测,此时就只需要构建模型的前向传播过程,然后载入已训练好的参数初始化网络即可。
第2个场景就是模型的再训练过程。一个模型在一批数据上训练完成之后需要将其保存到本地,并且可能过了一段时间后又收集到了一批新的数据,因此这个时候就需要将之前的模型载入进行在新数据上进行增量训练(或者是在整个数据上进行全量训练)。
第3个应用场景就是模型的迁移学习。这个时候就是将别人已经训练好的预模型拿过来,作为你自己网络模型参数的一部分进行初始化。例如:你自己在Bert模型的基础上加了几个全连接层来做分类任务,那么你就需要将原始BERT模型中的参数载入并以此来初始化你的网络中的Bert部分的权重参数。
在接下来的这篇文章中,笔者就以上述3个场景为例来介绍如何利用Pytorch框架来完成上述过程。
2 模型的保存与复用
在Pytorch中,我们可以通过torch.save()和torch.load()来完成上述场景中的主要步骤。下面,笔者将以之前介绍的LeNet5网络模型为例来分别进行介绍。不过在这之前,我们先来看看Pytorch中模型参数的保存形式。
2.1 查看网络模型参数
(1)查看参数
首先定义好LeNet5的网络模型结构,如下代码所示:
class LeNet5(nn.Module):
def __init__(self, ):
super(LeNet5, self).__init__()
self.conv = nn.Sequential( # [n,1,28,28]
nn.Conv2d(1, 6, 5, padding=2), # in_channels, out_channels, kernel_size
nn.ReLU(), # [n,6,24,24]
nn.MaxPool2d(2, 2), # kernel_size, stride [n,6,14,14]
nn.Conv2d(6, 16, 5), # [n,16,10,10]
nn.ReLU(),
nn.MaxPool2d(2, 2)) # [n,16,5,5]
self.fc = nn.Sequential(
nn.Flatten(),
nn.Linear(16 * 5 * 5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, 10))
def forward(self, img):
output = self.conv(img)
output = self.fc(output)
return output
在定义好LeNet5这个网络结构的类之后,只要我们完成了这个类的实例化操作,那么网络中对应的权重参数也都完成了初始化的工作,即有了一个初始值。同时,我们可以通过如下方式来访问:
Print model’s state_dict
print(“Model’s state_dict:”)
for param_tensor in model.state_dict():
print(param_tensor, “\t”, model.state_dict()[param_tensor].size())
其输出的结果为:
conv.0.weight torch.Size([6, 1, 5, 5])
conv.0.bias torch.Size([6])
conv.3.weight torch.Size([16, 6, 5, 5])
可以发现,网络模型中的参数model.state_dict()其实是以字典的形式(实质上是collections模块中的OrderedDict)保存下来的:
print(model.state_dict().keys())
odict_keys([‘conv.0.weight’, ‘conv.0.bias’, ‘conv.3.weight’,
‘conv.3.bias’, ‘fc.1.weight’, ‘fc.1.bias’, ‘fc.3.weight’, ‘fc.3.bias’,
‘fc.5.weight’, ‘fc.5.bias’])
(2)自定义参数前缀
同时,这里值得注意的地方有两点:①参数名中的fc和conv前缀是根据你在上面定义nn.Sequential()时的名字所确定的;②参数名中的数字表示每个Sequential()中网络层所在的位置。例如将网络结构定义成如下形式:
class LeNet5(nn.Module):
def __init__(self, ):
super(LeNet5, self).__init__()
self.moon = nn.Sequential( # [n,1,28,28]
nn.Conv2d(1, 6, 5, padding=2), # in_channels, out_channels, kernel_size
nn.ReLU(), # [n,6,24,24]
nn.MaxPool2d