之前我们学习了如何去搭建训练神经网络,接下来就要学习如何将训练好的神经网络进行保存,在需要的时候进行提取呢?
保存神经网络有两种方法:
- 保存整个神经网络
- 保存神经网络的参数,不保存他的结构
提取也有两种方法:
- 直接进行提取
- 先创建一个跟提取的神经网络一模一样的神经网络,然后再将各个参数传入神经网络即可。
小demo:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import Variable
#fake data
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y=x.pow(2)+0.2*torch.rand(x.size())
x,y=Variable(x),Variable(y)
def save():
net=torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1)
)
optimizer=torch.optim.SGD(net.parameters(),lr=0.2)
loss_func=torch.nn.MSELoss()
for i in range(100):
prediction=net(x)
loss=loss_func(prediction,y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
#保存整个网络
torch.save(net,'net.pkl')
#保存网络的参数
torch.save(net.state_dict(),'net_params.pkl')
plt.subplot(131)
plt.title('net')
plt.scatter(x.data.numpy(),y.data.numpy())
plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
def restore_net():
#提取整个网络
net1=torch.load('net.pkl')
prediction=net1(x)
plt.subplot(132)
plt.title('net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
def restore_params():
#先创建一个跟 原来网络一样的网络结构
net2=torch.nn.Sequential(
torch.nn.Linear(1,10),
torch.nn.ReLU(),
torch.nn.Linear(10,1)
)
#将参数传入网络
net2.load_state_dict(torch.load('net_params.pkl'))
prediction=net2(x)
plt.subplot(133)
plt.title('net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.show()
save()
restore_net()
restore_params()