

# generate.py
import onnx_graphsurgeon as gs
import numpy as np
import onnx
# Build and save a small ONNX graph that evaluates Y = x0 + (a * x1 + b).
shape = (1, 3, 224, 224)

# Graph inputs: two float32 tensors sharing the same shape.
x0 = gs.Variable("x0", dtype=np.float32, shape=shape)
x1 = gs.Variable("x1", dtype=np.float32, shape=shape)

# Constant operands, each backed by its own array of ones.
a = gs.Constant(name="a", values=np.ones(shape, dtype=np.float32))
b = gs.Constant(name="b", values=np.ones(shape, dtype=np.float32))

# Intermediate results; their dtype and shape are left unspecified.
mul_out = gs.Variable("mul_out")
add_out = gs.Variable("add_out")

# Graph output.
Y = gs.Variable("Y", dtype=np.float32, shape=shape)

# Nodes are appended in dataflow order: each one only consumes tensors
# produced by an earlier node, a graph input, or a constant.
nodes = []
# mul_out = a * x1
nodes.append(gs.Node(op="Mul", inputs=[a, x1], outputs=[mul_out]))
# add_out = mul_out + b
nodes.append(gs.Node(op="Add", inputs=[mul_out, b], outputs=[add_out]))
# Y = x0 + add_out
nodes.append(gs.Node(op="Add", inputs=[x0, add_out], outputs=[Y]))

graph = gs.Graph(nodes=nodes, inputs=[x0, x1], outputs=[Y])

# Convert the graph to an ONNX model and write it to disk.
onnx_model = gs.export_onnx(graph)
onnx.save(onnx_model, "model.onnx")
# modify.py
import onnx_graphsurgeon as gs
import numpy as np
import onnx
# Load the model written by generate.py and rewrite it:
#   Y = x0 + (a * x1 + b)   ->   identity_out = Identity(LeakyRelu(a * x1))
graph = gs.import_onnx(onnx.load("model.onnx"))

# 1. Remove the `b` input of the add node
# The first "Add" in the node list is the one generate.py created to compute
# `mul_out + b` (this relies on the import keeping the saved node order).
first_add = [node for node in graph.nodes if node.op == "Add"][0]
first_add.inputs = [inp for inp in first_add.inputs if inp.name != "b"]

# 2. Change the Add to a LeakyRelu
# With `b` removed the node has a single input, which is what LeakyRelu takes.
first_add.op = "LeakyRelu"
first_add.attrs["alpha"] = 0.02

# 3. Add an identity after the add node
identity_out = gs.Variable("identity_out", dtype=np.float32)
identity = gs.Node(op="Identity", inputs=first_add.outputs, outputs=[identity_out])
graph.nodes.append(identity)

# 4. Modify the graph output to be the identity output
graph.outputs = [identity_out]

# Keep only `x1` as a graph input; `x0` fed only the final Add, which no
# longer contributes to the new output.
# `x1` is used as-is because generate.py already declared it as a float32
# variable with a shape.
# NOTE(review): the previous `to_variable(dtype=np.float32)` call passed no
# `shape`, and that parameter appears to default to an empty list, which would
# replace x1's declared shape -- confirm against the installed
# onnx_graphsurgeon version.
graph.inputs = [graph.tensors()["x1"]]

# 5. Remove unused nodes/tensors, and topologically sort the graph
# ONNX requires nodes to be topologically sorted to be considered valid.
# Therefore, you should only need to sort the graph when you have added new nodes out-of-order.
# In this case, the identity node is already in the correct spot (it is the last node,
# and was appended to the end of the list), but to be on the safer side, we can sort anyway.
graph.cleanup().toposort()

onnx.save(gs.export_onnx(graph), "modified.onnx")
该博客介绍了如何使用onnx_graphsurgeon库修改ONNX模型。首先,它展示了如何创建一个计算 Y = x0 + (a * x1 + b) 的简单模型;然后移除第一个Add节点的输入 b,并将该节点替换为LeakyRelu操作;最后,在其后添加一个Identity节点,将图的输出更新为该节点的输出,并把图的输入调整为仅保留 x1,确保模型仍能正确运行。这个过程展示了ONNX模型的动态调整和优化。
4617

被折叠的 条评论
为什么被折叠?



