PyTorch学习总结(三)——ONNX

本文介绍了ONNX的用途,它是AI模型的一种开源格式,促进不同框架间的互操作性。重点讲解了PyTorch中torch.onnx模块如何将模型导出到ONNX,并提供了AlexNet从PyTorch转到Caffe2的示例。讨论了导出的局限性,如基于跟踪的exporter对动态模型的支持不足,以及ONNX支持的操作符列表。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

这里写图片描述


1.什么是ONNX

Open Neural Network Exchange (ONNX)是开放生态系统的第一步,它使人工智能开发人员可以在项目的发展过程中选择合适的工具;ONNX为AI models提供了一种开源格式。它定义了一个可以扩展的计算图模型,同时也定义了内置操作符和标准数据类型。最初我们关注的是推理(评估)所需的能力。

Caffe2, PyTorch, Microsoft Cognitive Toolkit, Apache MXNet 和其他工具都在对ONNX进行支持。在不同的框架之间实现互操作性,并简化从研究到产品化的过程,将提高人工智能社区的创新速度。

2.torch.onnx

本文我们将主要介绍PyTorch中自带的torch.onnx模块。该模块包含将模型导出到ONNX IR格式的函数。这些模型可以被ONNX库加载,然后将它们转换成可在其他深度学习框架上运行的模型。

3 End-to-end AlexNet from PyTorch to Caffe2

这里有一个简单的脚本,它将torchvision中预训练的AlexNet导出为ONNX。它运行一个简单的推断,然后将生成的跟踪模型保存到alexnet.proto:

from torch.autograd import Variable
import torch.onnx
import torchvision

dummy_input = Variable(torch.randn(10, 3, 224, 224)).cuda()
model = torchvision.models.alexnet(pretrained=True).cuda()
torch.onnx.export(model, dummy_input, "alexnet.proto", verbose=True)

alexnet.proto是一个二进制的protobuf文件,它包含您导出的模型的网络结构和参数(在这里,模型是AlexNet)。关键参数verbose=True使exporter可以打印出一种人类可读的网络表示:

# All parameters are encoded explicitly as inputs.  By convention,
# learned parameters (ala nn.Module.state_dict) are first, and the
# actual inputs are last.
graph(%1 : Float(64, 3, 11, 11)
      %2 : Fl
### 如何将PyTorch模型导出为ONNX格式 为了将PyTorch模型转换为ONNX格式,可以遵循特定的过程。此过程涉及准备环境、加载预训练模型以及使用`torch.onnx.export()`函数完成实际的转换工作。 #### 准备环境并导入必要的库 在开始之前,确保安装了所需的Python包,如`torch`和`torchvision`。这些工具提供了创建和操作神经网络所需的功能[^1]。 ```python import torch import torchvision ``` #### 创建或加载目标模型实例 对于本例而言,选择了ResNet18作为演示对象——这是一种经典的卷积神经网络架构,在图像分类任务上表现良好。通过设置参数`pretrained=True`可以从官方仓库下载预先训练好的权重;如果不需要初始化权重,则省略该选项即可[^2]。 ```python model = torchvision.models.resnet18(pretrained=True) ``` #### 构建虚拟输入数据 为了让导出器知道如何处理不同形状的数据流,需提供一个代表性的样本张量给定维度大小(batch size, channels, height, width),这里假设批量大小为1,通道数为3(RGB图片),高度宽度均为224像素。 ```python dummy_input = torch.randn(1, 3, 224, 224) ``` #### 执行导出命令 最后一步就是调用`torch.onnx.export()`来进行最终转化。除了上述提到的对象外,还需要指明保存路径名及所使用的OPSet版本号。OPSet指的是运算符集合的标准版本,较高的数值意味着更广泛的支持度和更好的兼容性[^3]。 ```python torch.onnx.export( model, dummy_input, "resnet18.onnx", opset_version=11 ) ``` 以上步骤展示了怎样利用PyTorch内置功能轻松实现从框架内部结构到开放标准中间表达形式之间的转变。值得注意的是,当涉及到更为复杂的场景时(例如存在多个输入/输出节点或多分支逻辑的情况),可能还需额外配置一些高级选项来满足具体需求[^5]。
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值