深度学习5 神经网络、参数初始化

一、模型保存与加载

 1、序列化方式

        保存方式:torch.save(model, "model.pkl")

        打开方式:model = torch.load("model.pkl", map_location="cpu")

​
import torch
import torch.nn as nn


class MyModle(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyModle, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        output = self.fc3(x)
        return output

model = MyModle(input_size=128, output_size=32)
# 序列化方式保存模型对象
torch.save(model, "model.pkl")


# 注意设备问题
model = torch.load("model.pkl", map_location="cpu")
print(model)

​

2、保存模型参数

        设置需要保存的模型参数:

save_dict = {

        "init_params": {

            "input_size": 128,  # 输入特征数

            "output_size": 32,  # 输出特征数

        },

        "accuracy": 0.99,  # 模型准确率

        "model_state_dict": model.state_dict(),

        "optimizer_state_dict": optimizer.state_dict(),

    }

保存模型参数:torch.save(save_dict, "名称.pth"),一般使用 pth 作为后缀

创建新模型时调用保存的模型参数:

加载模型参数:torch.load("名称.pth")

input_size =  save_dict["init_params"]["input_size"]

import torch
import torch.nn as nn
import torch.optim as optim

class MyModle(nn.Module):
    def __init__(self, input_size, output_size):
        super(MyModle, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_size)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        output = self.fc3(x)
        return output

save_dict = torch.load("模型参数保存名称.pth")
model = MyModle(
    input_size=save_dict["init_params"]["input_size"],
    output_size=save_dict["init_params"]["output_size"],
)
# 初始化模型参数
model.load_state_dict(save_dict["model_state_dict"])
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 初始化优化器参
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值