import numpy as np
import onnx
import torch
from onnx import TensorProto, helper, numpy_helper
class Model(torch.nn.Module):
    """Toy CNN used to demonstrate ONNX export and sub-graph extraction.

    Topology: ``convs1`` feeds two parallel branches (``convs2`` and
    ``convs3``). Their outputs are summed element-wise and passed through
    ``convs4``.

    Every layer is an unpadded 3x3 ``Conv2d`` with 3 input and 3 output
    channels, so each layer shrinks H and W by 2. The longest path has
    8 layers (3 + 2 + 3), so an ``(N, 3, H, W)`` input yields
    ``(N, 3, H - 16, W - 16)``. For example, 20x20 becomes 4x4.
    """

    def __init__(self):
        super().__init__()
        # Shared trunk: three 3x3 convolutions.
        self.convs1 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3),
                                          torch.nn.Conv2d(3, 3, 3),
                                          torch.nn.Conv2d(3, 3, 3))
        # Two parallel two-layer branches. They must produce identical
        # shapes so their outputs can be added in forward().
        self.convs2 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3),
                                          torch.nn.Conv2d(3, 3, 3))
        self.convs3 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3),
                                          torch.nn.Conv2d(3, 3, 3))
        # Head applied to the merged branches: three 3x3 convolutions.
        self.convs4 = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3),
                                          torch.nn.Conv2d(3, 3, 3),
                                          torch.nn.Conv2d(3, 3, 3))

    def forward(self, x):
        """Run trunk -> (branch2 + branch3) -> head and return the result."""
        x = self.convs1(x)
        x1 = self.convs2(x)
        x2 = self.convs3(x)
        x = x1 + x2
        x = self.convs4(x)
        return x


# Export the toy model to ONNX by tracing it with a single 1x3x20x20 dummy
# input. The name `dummy_input` avoids shadowing the builtin `input`.
model = Model()
model.eval()  # explicit inference mode for export
dummy_input = torch.randn(1, 3, 20, 20)
torch.onnx.export(model, dummy_input, 'whole_model.onnx')
# Cut a sub-model out of an existing ONNX model (from SUBGRAPH_INPUT up to
# SUBGRAPH_OUTPUT), append a 1x1 Conv on that output, and make the Conv's
# result the model's only output.
SOURCE_MODEL_PATH = '/home/zhengwei/my_jupyter/lab8-lpr/Models/lprnet.onnx'
PARTIAL_MODEL_PATH = 'partial_model.onnx'
MODIFIED_MODEL_PATH = 'tmp.onnx'
SUBGRAPH_INPUT = 'input.1'
SUBGRAPH_OUTPUT = '/Div_output_0'

onnx.utils.extract_model(SOURCE_MODEL_PATH, PARTIAL_MODEL_PATH,
                         [SUBGRAPH_INPUT], [SUBGRAPH_OUTPUT])
model = onnx.load(PARTIAL_MODEL_PATH)

# Find the Div node that produces the extracted output tensor, rather than
# whichever Div happens to come last in the graph. Fail with a clear
# message instead of an AttributeError on None further down.
div_node = next(
    (node for node in model.graph.node
     if node.op_type == "Div" and SUBGRAPH_OUTPUT in node.output),
    None,
)
if div_node is None:
    raise ValueError(
        f"no Div node in {PARTIAL_MODEL_PATH!r} produces {SUBGRAPH_OUTPUT!r}"
    )

# Weight for a 1x1 Conv with one output channel, laid out (out, in, kH, kW).
# Cast to float32 so the initializer matches its FLOAT element type exactly.
# NOTE(review): assumes the Div output has 64 channels; confirm against the
# source model.
conv_w_weight = np.random.random_sample((1, 64, 1, 1)).astype(np.float32)
conv_w = numpy_helper.from_array(conv_w_weight, name="conv_w")
conv_node = helper.make_node(
    "Conv",
    inputs=[SUBGRAPH_OUTPUT, "conv_w"],
    outputs=["conv_node_output"],
    kernel_shape=[1, 1],
    # strides, dilations, pads and group keep their ONNX defaults
    # (1, 1, 0 and 1 respectively).
)
model.graph.initializer.append(conv_w)
# Appending after the Div keeps the node list topologically sorted.
model.graph.node.append(conv_node)

# NOTE(review): [1, 1, 4, 18] assumes the Div output is (1, 64, 4, 18);
# confirm against the source model.
new_output = helper.make_tensor_value_info(
    'conv_node_output', TensorProto.FLOAT, [1, 1, 4, 18])
# Replace the extracted output(s) so the Conv result is the only output.
model.graph.ClearField("output")
model.graph.output.append(new_output)

# Validate the edited graph before writing it out.
onnx.checker.check_model(model)
onnx.save(model, MODIFIED_MODEL_PATH)
# ONNX: extract a sub-model and modify the model's output.
# (Source page note: latest recommended article published 2024-10-17 16:43:54.)