ONNX 输入batch修改

ONNX 输入batch修改

导出的onnx模型分为静态和动态输入两种,但一般用户会在导出后进行onnxsim操作,导致某些非全卷积的模型修改batch失败,比如transformer类其中reshape的attr属性会固定,修改相当麻烦,需要从源头重新导出是最佳选择,本文讨论在没有源头的情况下,尽量完成修改batch的修改,并固化每个op的输入输出tensor为正确shape。

依赖

  • onnx
  • onnxsim
import sys
import onnx
from onnx import shape_inference, helper
from onnxsim import simplify


filepath = sys.argv[1]
batch = int(sys.argv[2])
opset_version = int(sys.argv[3])

model = onnx.load(filepath)

# 创建一个新的模型图
new_graph = helper.make_graph(
    nodes=model.graph.node,                # 保留所有节点
    name="new_graph",                      # 保留原图名称
    inputs=model.graph.input,              # 保留输入
    outputs=model.graph.output,            # 保留输出
    initializer=model.graph.initializer,   # 保留初始化器
    value_info=None
)

# 构造一个新的模型
if not any(opset.domain == "" for opset in model.opset_import):
    model.opset_import.append(helper.make_opsetid(domain="", version=opset_version))
    
new_model = helper.make_model(new_graph, producer_name=model.producer_name, opset_imports=model.opset_import)
for idx, _input in enumerate(new_model.graph.input):
    _input.type.tensor_type.shape.dim[0].dim_value = batch
for idx, _output in enumerate(new_model.graph.output):
    _output.type.tensor_type.shape.dim[0].dim_value = batch
    
model_sim, success = simplify(new_model)
if not success:
    print("simplify failed")
    onnx.save(new_model, filepath.replace(".onnx", f"_b{batch}.onnx"))
else:
    onnx.save(model_sim, filepath.replace(".onnx", f"_b{batch}.onnx"))


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

上单之光

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值