pytorch中模型和参数的存储

本文详细介绍了PyTorch中torch.nn.Module模型的state_dict属性,该属性是Python字典对象,用于存储模型的可学习参数,如权重和偏差。同时,也讲解了torch.optim优化器的state_dict属性,包含了优化器的状态信息和超参数。文章通过一个分类器模型实例,演示了如何访问和打印模型及优化器的state_dict。
部署运行你感兴趣的模型镜像

 

在PyTorch中, torch.nn.Module 模型的可学习参数(即权重和偏差)包含在模型的参数中,
(使用 c可以进行访问)。
state_dict 是Python字典对象,它将每一层映射到其参数张量。注意,只有具有可学习参数的层(如卷积层,线性层等)的模型才具有 state_dict 这一项。目标优化 torch.optim 也有 state_dict 属性,它包含有关优化器的状态信息,以及使用的超参数。
因为state_dict的对象是Python字典,所以它们可以很容易的保存、更新、修改和恢复,为PyTorch模型和优化器添加了大量模块。
下面通过从简单模型训练一个分类器中来了解一下 state_dict 的使用

 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
class theModelClass(nn.Module):
    def __init__(self):
        super(theModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6,16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = theModelClass()

# 初始化 优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 打印模型的状态字典
print("Model's state_dict:" )
for param_tensor in model.state_dict():
    print(param_tensor, '\t', model.state_dict()[param_tensor].size())

# 打印优化器的转台字典
print(" Optimizer's state_dict; ")
for var_name in optimizer.state_dict():
    print(var_name, '\t', optimizer.state_dict()[var_name])

# 输出


"""
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])

conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])

fc1.weight torch.Size([120, 400])  
fc1.bias torch.Size([120])

fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])

fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])

Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay':
0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072,
4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

"""

 

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

NCNN是一个为移动端优化的高性能神经网络推理框架,而PyTorch是一个广泛使用的深度学习训练推理框架,二者在参数存储格式上存在显著差异。 #### 数据组织方式 - **PyTorch**:PyTorch使用Python对象来组织模型参数,每个层的参数(如权重偏置)被存储为`torch.Tensor`对象,这些对象可以方便地进行操作计算。模型的所有参数可以通过`model.state_dict()`方法获取,它返回一个字典,字典的键是参数的名称,值是对应的`torch.Tensor`对象。例如: ```python import torch import torch.nn as nn # 定义一个简单的线性层 class SimpleModel(nn.Module): def __init__(self): super(SimpleModel, self).__init__() self.linear = nn.Linear(10, 1) def forward(self, x): return self.linear(x) model = SimpleModel() state_dict = model.state_dict() for key, value in state_dict.items(): print(key, value.shape) ``` - **NCNN**:NCNN使用二进制文件来存储模型参数,这些文件通常包含两个部分:模型结构参数数据参数数据以二进制格式存储,按照层的顺序依次排列。NCNN的模型文件通常有两个,一个是`.param`文件,用于存储模型结构信息;另一个是`.bin`文件,用于存储参数数据。 #### 数据类型精度 - **PyTorch**:PyTorch支持多种数据类型,如`torch.float32`、`torch.float64`等。在训练推理过程中,可以根据需要选择合适的数据类型。默认情况下,PyTorch使用`torch.float32`。 - **NCNN**:NCNN主要使用单精度浮点数(`float32`)来存储参数,以在保证一定精度的同时,减少内存占用计算量,适合在移动端设备上运行。 #### 存储顺序 - **PyTorch**:PyTorch的张量存储顺序遵循NumPy的约定,即C风格的连续存储。对于多维张量,数据在内存中按行优先的顺序存储。 - **NCNN**:NCNN的参数存储顺序与具体的层操作有关,但通常会考虑到内存访问的效率计算的优化。在一些情况下,NCNN可能会对参数进行重新排列,以提高内存访问的局部性。 #### 跨平台兼容性 - **PyTorch**:PyTorch模型参数存储格式与PythonPyTorch库紧密相关,在不同的操作系统硬件平台上,只要安装了兼容版本的PyTorch,就可以方便地加载使用模型。 - **NCNN**:NCNN的二进制文件格式具有较好的跨平台兼容性,可以在不同的操作系统(如Android、iOS、Linux等)硬件平台(如ARM、x86等)上使用,只要有对应的NCNN库支持。
评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值