import onnx
import onnxruntime as rt
import os
import numpy as np
import argparse
class fix_dim_tools:
def __init__(self, model_path, inputs_shape, inputs_dtype):
assert os.path.exists(model_path), "{} not exists".format(model_path)
if inputs_dtype is None:
print('inputs_dtype is not define, use float for all inputs node')
inputs_dtype = ['float']*len(inputs_shape)
else:
assert len(inputs_shape)==len(inputs_dtype), "inputs shape list should have same length as inputs_dtype"
model = onnx.load(model_path)
self.model = model
self.model_path = model_path
self.inputs_shape = inputs_shape
self.inputs_dtype = inputs_dtype
self.inputs_shape_dict = {
}
self.inputs_type_dict = {
}
self.outputs_shape_dict = {
}
def check_dynamic_input(self):
# check dynamic input and get real input shape
inputs_number = len(self.model.graph.input)
assert inputs_number==len(self.inputs_shape),"model has {} inputs, but {} inputs_shape was given, not match".format(inputs_number,len(self.inputs_shape)
onnx动态模型转静态模型
最新推荐文章于 2024-11-13 10:42:38 发布