保存和提取神经网络

本文介绍了使用PyTorch保存和加载神经网络的方法,包括保存整个网络模型和仅保存网络参数两种方式,并通过一个简单的实例展示了如何实现这些操作。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

之前我们学习了如何去搭建训练神经网络,接下来就要学习如何将训练好的神经网络进行保存,在需要的时候进行提取呢?

保存神经网络有两种方法:

  • 保存整个神经网络
  • 保存神经网络的参数,不保存他的结构

提取也有两种方法:

  • 直接进行提取
  • 先创建一个跟提取的神经网络一模一样的神经网络,然后再将各个参数传入神经网络即可。

 小demo:

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import  Variable

#fake data
x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y=x.pow(2)+0.2*torch.rand(x.size())

x,y=Variable(x),Variable(y)
def save():
    net=torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10,1)
    )

    optimizer=torch.optim.SGD(net.parameters(),lr=0.2)
    loss_func=torch.nn.MSELoss()

    for i in range(100):
        prediction=net(x)
        loss=loss_func(prediction,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    #保存整个网络
    torch.save(net,'net.pkl')
    #保存网络的参数
    torch.save(net.state_dict(),'net_params.pkl')
    plt.subplot(131)
    plt.title('net')
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

def restore_net():
    #提取整个网络
    net1=torch.load('net.pkl')
    prediction=net1(x)
    plt.subplot(132)
    plt.title('net1')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)

def restore_params():
    #先创建一个跟 原来网络一样的网络结构
    net2=torch.nn.Sequential(
        torch.nn.Linear(1,10),
        torch.nn.ReLU(),
        torch.nn.Linear(10,1)
    )
    #将参数传入网络
    net2.load_state_dict(torch.load('net_params.pkl'))
    prediction=net2(x)
    plt.subplot(133)
    plt.title('net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
    plt.show()
save()
restore_net()
restore_params()

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值