深度学习的onnx模型插入新节点构建新模型

在这里插入图片描述

import numpy as np
import onnx
import onnxruntime
import onnxruntime.backend as backend

model = onnx.load('test.onnx')
node = model.graph.node
graph = model.graph
 
# 1.2搜索目标节点
# for i in range(len(node)):
#     if node[i].op_type == 'Conv':
#         node_rise = node[i]
#         if node_rise.output[0] == '203':
#             print(i)
# print(node[159])

new_node_0 = onnx.helper.make_node(
    "Mul",
    inputs=["input_image","1"],
    outputs=["mutiply"],
)

mutiply_node = onnx.helper.make_node(
    "Constant",
    inputs=[],
    outputs=["1"],
    value=onnx.helper.make_tensor('value', onnx.TensorProto.FLOAT, [], [2.0])
)

new_node_1 = onnx.helper.make_node(
    "Add",
    inputs=["mutiply","2"],
    outputs=["add"],
)

add_node = onnx.helper.make_node(
    "Constant",
    inputs=[],
    outputs=["2"],
    value=onnx.helper.make_tensor('value', onnx.TensorProto.FLOAT, [], [-1.0])
)

#删除老节点 
old_squeeze_node = model.graph.node[0]
old_squeeze_node.input[0] = "add"
model.graph.node.remove(old_squeeze_node)

graph.node.insert(0, mutiply_node)
graph.node.insert(1, new_node_0)
graph.node.insert(2, add_node)
graph.node.insert(3, new_node_1)
graph.node.insert(4, old_squeeze_node)
onnx.checker.check_model(model)
onnx.save(model, 'out.onnx')

# session = onnxruntime.InferenceSession("out.onnx", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
# out = session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: np.ones([1, 1, 128, 128], dtype=np.float32)})[0]
# print(out)

print(onnxruntime.get_device())
rt = backend.prepare(model, "CPU")
out = rt.run(np.ones([1, 1, 128, 128], dtype=np.float32))
print(out)

在这里插入图片描述

第二种使用可供训练的初始化参数

import numpy as np
import onnx
import onnxruntime
import onnxruntime.backend as backend

model = onnx.load('test.onnx')
node = model.graph.node
graph = model.graph
 
# 1.2搜索目标节点
# for i in range(len(node)):
#     if node[i].op_type == 'Conv':
#         node_rise = node[i]
#         if node_rise.output[0] == '203':
#             print(i)
# print(node[159])

mutiply_node = onnx.helper.make_tensor(name='1',
                                      data_type=onnx.TensorProto.FLOAT,
                                      dims= [1],
                                      vals = np.array([2.0], dtype=np.float32)
                                        )

graph.initializer.append(mutiply_node)

new_node_0 = onnx.helper.make_node(
    "Mul",
    inputs=["input_image","1"],
    outputs=["mutiply"],
)

add_node = onnx.helper.make_tensor(name='2',
                                      data_type=onnx.TensorProto.FLOAT,
                                      dims= [1],
                                      vals = np.array([-1.], dtype=np.float32)
                                        )

graph.initializer.append(add_node)

new_node_1 = onnx.helper.make_node(
    "Add",
    inputs=["mutiply","2"],
    outputs=["add"],
)

#删除老节点 
old_squeeze_node = model.graph.node[0]
old_squeeze_node.input[0] = "add"
model.graph.node.remove(old_squeeze_node)

graph.node.insert(0, new_node_0)
graph.node.insert(1, new_node_1)
graph.node.insert(2, old_squeeze_node)
onnx.checker.check_model(model)
onnx.save(model, 'out.onnx')

# session = onnxruntime.InferenceSession("out.onnx", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
# out = session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: np.ones([1, 1, 128, 128], dtype=np.float32)})[0]
# print(out)

print(onnxruntime.get_device())
rt = backend.prepare(model, "CPU")
out = rt.run(np.ones([1, 1, 128, 128], dtype=np.float32))
print(out)

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

誓天断发

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值