


# generate.py
import onnx_graphsurgeon as gs
import numpy as np
import onnx
# Build a small ONNX model that computes:
#     output = input + ((a + b) + d)
# where a, b and d are constant (1, 3) float32 tensors of ones. Since the
# subexpression (a + b) + d depends only on constants, it can be evaluated
# ahead of time (constant folding).
shape = (1, 3)

# Graph input. Bound to `input_tensor` rather than `input` so the builtin
# `input()` is not shadowed for the rest of the module; the ONNX tensor
# name is still "input".
input_tensor = gs.Variable("input", shape=shape, dtype=np.float32)

# Constant operands, all ones of the same shape as the input.
a = gs.Constant("a", values=np.ones(shape=shape, dtype=np.float32))
b = gs.Constant("b", values=np.ones(shape=shape, dtype=np.float32))
d = gs.Constant("d", values=np.ones(shape=shape, dtype=np.float32))

# Intermediate results; shape and dtype are deliberately left unspecified.
c = gs.Variable("c")
e = gs.Variable("e")

# Graph output.
output = gs.Variable("output", shape=shape, dtype=np.float32)

nodes = [
    # c = a + b
    gs.Node("Add", inputs=[a, b], outputs=[c]),
    # e = c + d
    gs.Node("Add", inputs=[c, d], outputs=[e]),
    # output = input + e
    gs.Node("Add", inputs=[input_tensor, e], outputs=[output]),
]

graph = gs.Graph(nodes=nodes, inputs=[input_tensor], outputs=[output])
onnx.save(gs.export_onnx(graph), "model.onnx")
# fold.py
import onnx_graphsurgeon as gs
import onnx
# Print the library's own documentation for Graph.fold_constants().
print(f"Graph.fold_constants Help:\n{gs.Graph.fold_constants.__doc__}")

# Load the previously saved model into a graphsurgeon Graph.
graph = gs.import_onnx(onnx.load("model.onnx"))

# fold_constants() uses ONNX Runtime to evaluate every expression that can be
# computed before runtime and rewires the consuming nodes to read the
# resulting constant tensors instead. The bypassed nodes are NOT removed by
# that step, so cleanup() is chained afterwards to drop the now-unused nodes.
graph.fold_constants().cleanup()

onnx.save(gs.export_onnx(graph), "folded.onnx")
这篇博客介绍了如何使用 onnx_graphsurgeon 库对 ONNX 模型进行常量折叠和节点清理。首先，创建了一个计算 `input + ((a + b) + d)` 的模型；然后通过 `fold_constants()` 函数将可以在运行前求值的表达式（这里是只依赖常量的 `(a + b) + d`）替换为常量张量；再用 `cleanup()` 移除不再使用的节点；最后保存优化后的模型。
2157

被折叠的 条评论
为什么被折叠?



