ONNX 输入batch修改
导出的onnx模型分为静态和动态输入两种,但一般用户会在导出后进行onnxsim操作,导致某些非全卷积的模型修改batch失败,比如transformer类其中reshape的attr属性会固定,修改相当麻烦,需要从源头重新导出是最佳选择,本文讨论在没有源头的情况下,尽量完成修改batch的修改,并固化每个op的输入输出tensor为正确shape。
依赖
- onnx
- onnxsim
import sys
import onnx
from onnx import shape_inference, helper
from onnxsim import simplify
filepath = sys.argv[1]
batch = int(sys.argv[2])
opset_version = int(sys.argv[3])
model = onnx.load(filepath)
# 创建一个新的模型图
new_graph = helper.make_graph(
nodes=model.graph.node, # 保留所有节点
name="new_graph", # 保留原图名称
inputs=model.graph.input, # 保留输入
outputs=model.graph.output, # 保留输出
initializer=model.graph.initializer, # 保留初始化器
value_info=None
)
# 构造一个新的模型
if not any(opset.domain == "" for opset in model.opset_import):
model.opset_import.append(helper.make_opsetid(domain="", version=opset_version))
new_model = helper.make_model(new_graph, producer_name=model.producer_name, opset_imports=model.opset_import)
for idx, _input in enumerate(new_model.graph.input):
_input.type.tensor_type.shape.dim[0].dim_value = batch
for idx, _output in enumerate(new_model.graph.output):
_output.type.tensor_type.shape.dim[0].dim_value = batch
model_sim, success = simplify(new_model)
if not success:
print("simplify failed")
onnx.save(new_model, filepath.replace(".onnx", f"_b{batch}.onnx"))
else:
onnx.save(model_sim, filepath.replace(".onnx", f"_b{batch}.onnx"))