最近使用onnx时,想把所有的节点的信息和权重参数显示出来,找了一下没找到类似的源码,官方介绍的pythonAPI都是些什么加载,保存,转换之类的,没有详细介绍怎么使用onnx分析模型的,只好自己写一个。
其实很简单,我只列些最基本的,具体分析还得看个人的需要,
import onnx
if __name__ == '__main__':
model_path = r'F:\model_float32.onnx'
model = onnx.load(model_path)
nodes = model.graph.node
nodnum = len(nodes) # 205
for nid in range(nodnum):
if (nodes[nid].output[0] == 'stride_32'):
print('Found stride_32: index = ', nid)
else:
print(nodes[nid].output)
inits = model.graph.initializer
ininum = len(inits) # 124
for iid in range(ininum):
el = inits[iid]
print('name:', el.name, ' dtype:', el.data_type, ' dim:', el.dims) # el.raw_data for weights and biases
print(model.graph.output) # display all the output nodes