本篇博客主要介绍如何在PyTorch中保存和加载模型训练的结果。
对训练结果进行保存,有两种方式,一种是保存整个网络,另一种是保存训练好的参数,相对而言,第二种方式具有更高的效率。
下面是示例代码:
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
# 生成假数据
# torch.unsqueeze() 的作用是将一维变二维,torch只能处理二维的数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # x data (tensor), shape(100, 1)
# 0.2 * torch.rand(x.size())增加噪点
y = x.pow(2) + 0.2 + 0.2 * torch.rand(x.size())
# 将Tensor转换为torch
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
# 保存神经网络
def save():
# 搭建神经网络
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
# 优化器:随