Onnx_Basic
一切开始前先检查自己的onnx版本
pip list | grep onnx
期待的4个主要的SDK是
onnx 1.14.0
onnx-graphsurgeon 0.3.27
onnxruntime 1.15.1
onnxsim 0.4.33
3 学会如何导出ONNX, 分析ONNX
3.1 简单复习一下pytorch
定义一个模型, 这个模型实质上是一个线性层 (nn.Linear)。线性层执行的操作是 y = x * W^T + b,其中 x 是输入,W 是权重,b 是偏置。
class Model(torch.nn.Module):
def __init__(self, in_features, out_features, weights, bias=False):
super().__init__()
self.linear = nn.Linear(in_features, out_features, bias)
with torch.no_grad():
self.linear.weight.copy_(weights)
def forward(self, x):
x = self.linear(x)
return x
定义一个infer的case, 权重的形状通常为 (out_features, in_features),这里复习一下矩阵相乘, 这里的in_features(X)的shape是[4], 而我们希望模型输出的是[3], 那么y = x * W^T + b可以知道W^T需要是[4, 3], nn.Linear会帮我们转置, 所以这里的W的shape是[3, 4]

def infer():
in_features = torch.tensor([1, 2, 3, 4], dtype=torch.float32)
weights = torch.tensor([
[1, 2, 3, 4],
[2, 3, 4, 5],
[3, 4, 5, 6]
],dtype=torch.float32)
model = Model(4, 3, weights)
x = model(in_features)
print("result is: ", x)
3.2 torch.export.onnx参数
这里就是infer完了之后export onnx, 重点看一下这里的参数,
-
model (torch.nn.Module): 需要导出的PyTorch模型,它应该是
torch.nn.Module的一个实例。 -
args (tuple or Tensor): 一个元组,其中包含传递给模型的输入张量,用于确定ONNX图的结构。在您的代码中,您传递了一个包含一个张量的元组,这指示您的模型接受单个输入。
-
f (str): 要保存导出模型的文件路径。在您的代码中,该模型将被保存到“…/models/example.onnx”路径。
-
input_names (list of str): 输入节点的名字的列表。这些名字可以用于标识ONNX图中的输入节点。在您的代码中,您有一个名为“input0”的输入。
-
output_names (list of str): 输出节点的名字的列表。这些名字可以用于标识ONNX图中的输出节点。在您的代码中,您有一个名为“output0”的输出。
-
opset_version (int): 用于导出模型的ONNX操作集版本。
def export_onnx():
input = torch.zeros(1, 1, 1, 4)
weights = torch.tensor([
[1, 2, 3, 4],
[2, 3, 4, 5],
[3, 4, 5, 6]
],dtype=torch.float32)
model = Model(4, 3, weights)
model.eval() #添加eval防止权重继续更新
# pytorch导出onnx的方式,参数有很多,也可以支持动态size
# 我们先做一些最基本的导出,从netron学习一下导出的onnx都有那些东西
torch.onnx.export(
model = model,
args = (input,),
f = "../models/example.onnx",
input_names = ["input0"],
output_names = ["output0"],
opset_version = 12)
print("Finished onnx export")
当然可以。以下是torch.onnx.export函数中参数的解释:
torch.onnx.export(
model = model,
args = (input,),
f = "../models/example.onnx",
input_names = ["input0"],
output_names = ["output0"],
opset_version = 12)
3.3 多个输出头

模型的定义上就要有多个
class Model(torch.nn.Module):
def __init__(self, in_features, out_features, weights1, weights2, bias=False):
super().__init__()
self.linear1 = nn.Linear(in_features, out_features, bias)
self.linear2 = nn.Linear(in_features, out_features, bias)
with torch.no_grad():
self.linear1.weight.copy_(weights1)
self.linear2.weight.copy_(weights2)
def forward(self, x):
x1 = self.linear1(x)
x2 = self.linear2(x)
return x1, x2
输出的时候只要更改output_names的参数就可以了
def export_onnx():
input = torch.zeros(1, 1, 1, 4)
weights1 = torch.tensor([
[1, 2, 3, 4],
[2, 3, 4, 5],
[3, 4, 5, 6]
],dtype=torch.float32)
weights2 = torch.tensor([
[2, 3, 4, 5],
[3, 4, 5, 6],
[4, 5, 6, 7]
],dtype=torch.float32)
model = Model(4, 3, weights1, weights2)
model.eval() #添加eval防止权重继续更新
# pytorch导出onnx的方式,参数有很多,也可以支持动态size
torch.onnx.export(
model = model,
args = (input,),
f = "../models/example_two_head.onnx",
input_names = ["input0"]

本文围绕ONNX展开,先介绍导出ONNX所需检查的版本及主要SDK。接着详细阐述导出ONNX的方法,包括复习PyTorch、分析torch.export.onnx参数等。还提及多个输出头、动态形状等情况,以及onnxsim工具的使用。最后通过YOLOv5和Swin - Transformer两个实战案例,展示ONNX的优化与导出。
最低0.47元/天 解锁文章
1690

被折叠的 条评论
为什么被折叠?



