Onnx 使用
Onnx 介绍
Onnx (Open Neural Network Exchange) 的本质是一种 Protobuf 格式文件,通常看到的 .onnx 文件其实就是通过 Protobuf 序列化储存的文件。onnx-ml.proto 通过 protoc (Protobuf 提供的编译程序) 编译得到 onnx-ml.pb.h 和 onnx-ml.pb.cc 或 onnx_ml_pb2.py,然后用 onnx_ml.pb.cc 和代码来操作 onnx 模型文件,实现增删改操作。onnx-ml.proto 则是描述 onnx 文件如何组成和结构,用于作为操控 onnx 的参照。但是这个.proto 文件(里面用的是 protobuf 语法)只是一个中间表示文件,不具备任何能力,即并不面向存储和传输(序列化,反序列化,读写)。所以需要用 protoc 编译 .proto 文件,是将 .proto 文件编译成不同语言的实现,得到 .cc 和 .py 文件这两个接口文件,这样不同语言中的数据就可以和自定义的结构化数据格式的数据进行交互。
onnx-ml.proto
查看 onnx-ml.proto 文件可以看到每个 proto 的不同的数据结构,message 就是结构化数据的关键字,用于描述数据的字段、类型和层次结构。之后加载了 onnx 模型后,就可以按照这个文件内部记录的数据结构来访问模型中的数据。
使用 repeated 就表示内部的属性是数组,使用 optional 就表示数据可选。NodeProto 中 input 就是一个数组,其中储存着 string,需要使用索引访问,name 就是一个 string。这些参数后面的数值表示每个属性特定的 id,这些 id 是不能冲突的,官方已经指定好了。
Onnx 结构
-
onnx.model: 这是一个具体的 ONNX 模型实例,它是一个
ModelProto
对象,可以通过onnx.helper.make_model
来构建。它包含了模型的所有信息,包括元数据、图(graph)、初始参数等。代码中对应ModelProto
,其中opset_import
是OperatorSetIdProto
数据结构的数组;graph
是GraphProto
数据结构,一个 model 对应一个 graph。 -
onnx.model.graph: 这是 onnx 模型的计算图,它是一个
GraphProto
对象,可以通过onnx.helper.make_graph
来构建。计算图是一种描述模型计算过程的图形结构,它由一系列的节点 (node) 组成,这些节点代表了模型中的各种操作 (如卷积、激活函数等)。代码中对应GraphProto
,其中node
是NodeProto
数据结构的数组;initializer
是一个TensorProto
数据结构的数组;sparse_initializer
是一个SparseTensorProto
数据结构的数组;input
、output
和value_info
是ValueInfoProto
数据结构的数组。 -
onnx.model.graph.node: 这是计算图中的一个节点,它是一个
NodeProto
对象,可以通过onnx.helper.make_node
。每个节点代表了一个操作,它有一定数量的输入和输出,这些输入和输出都是张量。节点还有一个操作类型 (如"Conv"、"Relu"等) 和一些特定的参数 (如卷积核的大小、步长、填充等)。代码中对应NodeProto
,其中input
和output
是string
类型数组;attribute
是一个AttributeProto
数据结构,用于定义 node 属性,常见用法是将 (key, value) 传入 Proto 中;op_type
是一个string
,需要对应 onnx 提供的 operators。 -
onnx.model.graph.initializer: 这是模型的初始化参数,它是一个
TensorProto
对象,可以通过onnx.helper.make_tensor
。每个TensorProto
对象都包含了一个张量的所有信息,包括数据类型、形状、数据等。模型的所有权重和偏置都包含在这个列表中。代码中对应TensorProto
,其中dims
是int64
类型数组;raw_data
是bytes
类型。 -
onnx.model.graph.input/output: 这是模型的输入/输出信息,它是一个
ValueInfoProto
对象,可以通过onnx.helper.make_value_info
或onnx.helper.make_tensor_value_info
。每个ValueInfoProto
对象描述了一个输入/输出的信息,包括名称、数据类型、形状等。主要是用于标记哪些节点是输入/输出。代码中对应ValueInfoProto
,其中type
是一个TypeProto
数据结构,内部定义了标准 onnx 数据类型。
ONNX 模型的计算图中的节点可以有多种类型,其中包括 Constant,表示一种特殊的操作,它的作用是在计算过程中提供一个常量张量。这种常量张量的值在模型训练过程中不会改变,因此被视为常量。例如,对于大小为 (bs, N, H, W, 2) 的 anchorgrid 张量(其中,bs 是批处理大小,N 是锚框(anchor box)的种类数量,H 和 W 分别代表特征图的高度和宽度,最后的** 2 **代表每个锚框的宽度和高度),它可以被存储在一个 Constant 类型的节点中。值得注意的是,当使用ONNX图形可视化工具(如Netron)时,Constant 类型的节点可能不会显示出来。这是因为这些节点并不涉及任何计算操作,只是提供了一个常量张量。
另外,还有一种类型为"Identity"的节点。"Identity"节点表示一种标识操作,它的输出和输入完全相同。这种节点通常用于在需要保持计算图结构完整性的情况下,将某些张量传递到计算图的下一层。换句话说,它不会改变传递给它的任何信息,可以被视为一个透明的或者无操作的节点。
Onnx 模型生成
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.onnx
import os
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(1, 1, 3, padding=1)
self.relu = nn.ReLU()
self.conv.weight.data.fill_(1)
self.conv.bias.data.fill_(0)
def forward(self, x):
x = self.conv(x)
x = self.relu