import torch
import torch.onnx
import onnxruntime as ort
import numpy as np
from onnxruntime_extensions import (
onnx_op, PyCustomOpDef as PyOp, make_onnx_model, PyOrtFunction, get_library_path as _get_library_path)
torch.ops.load_library("/home/pliu3/AI-preprocess/Picasso-onnx-debug/onnx/build/libmesh_decimation.so")
def test_Inversion():
# 定义自定义算子
class SSimplify(torch.autograd.Function):
@staticmethod
def forward(ctx, vertexIn, faceIn, geometryIn, nvIn_cumsum, mfIn_cumsum, nv2Remove):
vertexOut, faceOut, isDegenerate, repOut, mapOut, nvOut, mfOut = torch.ops.mesh_decimate.simplify(vertexIn, faceIn, geometryIn, nvIn_cumsum, mfIn_cumsum, nv2Remove)
return vertexOut, faceOut, isDegenerate, repOut, mapOut, nvOut, mfOut
@staticmethod
def symbolic(g, vertexIn, faceIn, geometryIn, nvIn_cumsum, mfIn_cumsum, nv2Remove):
# 关键:域名必须与运行时注册一致("custom")
node = g.op("ai.onnx.contrib::Simplify", vertexIn, faceIn, geometryIn, nvIn_cumsum, mfIn_cumsum, nv2Remove, outputs=7)
node[0].setType(vertexIn.type())
node[1].setType(faceIn.type())
node[2].setType(faceIn.type())
node[3].setType(faceIn.type())
node[4].setType(faceIn.type())
node[5].setType(faceIn.type())
node[6].setType(faceIn.type())
return node
# 封装为模型
class CustomModel(torch.nn.Module):
def forward(self, vertexIn, faceIn, geometryIn, nvIn_cumsum, mfIn_cumsum, nv2Remove):
return SSimplify.apply(vertexIn, faceIn, geometryIn, nvIn_cumsum, mfIn_cumsum, nv2Remove)
# 导出 ONNX 模型
model = CustomModel()
vertexIn = torch.load("../vertexIn.pt")
faceIn = torch.load("../faceIn.pt")
geometryIn = torch.load("../geometry.pt")
nvIn_cumsum = torch.load("../nvIn_cumsum.pt")
mfIn_cumsum = torch.load("../mfIn_cumsum.pt")
nv2Remove = torch.load("../nv2Remove.pt")
#vertexOut, faceOut, isDegenerate, repOut, mapOut, nvOut, mfOut
torch.onnx.export(
model,
(vertexIn , faceIn, geometryIn, nvIn_cumsum, mfIn_cumsum, nv2Remove),
"simplify.onnx",
input_names=["vertexIn", "faceIn", "geometryIn", "nvIn_cumsum", "mfIn_cumsum", "nv2Remove"],
output_names=["vertexOut", "faceOut", "isDegenerate", "repOut", "mapOut", "nvOut", "mfOut"],
opset_version=17
)
@onnx_op(op_type="Simplify", domain="ai.onnx.contrib",
inputs=[PyOp.dt_float, PyOp.dt_int32, PyOp.dt_float,
PyOp.dt_int32, PyOp.dt_int32, PyOp.dt_int32],
outputs=[PyOp.dt_float, PyOp.dt_int32, PyOp.dt_int32, PyOp.dt_int32, PyOp.dt_int32, PyOp.dt_int32, PyOp.dt_int32])
def fimplify_forward(vertexIn , faceIn, geometryIn, nvIn_cumsum, mfIn_cumsum, nv2Remove): # onnxruntime的自定义算子接受的是numpy的输入 因此需要将numpy转为torch
vertexIn_tensor = torch.tensor(vertexIn)
faceIn_tensor = torch.tensor(faceIn, dtype=torch.int32)
geometryIn_tensor = torch.tensor(geometryIn)
nvIn_cumsum_tensor = torch.tensor(nvIn_cumsum, dtype=torch.int32)
mfIn_cumsum_tensor = torch.tensor(mfIn_cumsum, dtype=torch.int32)
nv2Remove_tensor = torch.tensor(nv2Remove, dtype=torch.int32)
vertexOut, faceOut, isDegenerate, repOut, mapOut, nvOut, mfOut = torch.ops.mesh_decimate.simplify(vertexIn_tensor,
faceIn_tensor, geometryIn_tensor, nvIn_cumsum_tensor, mfIn_cumsum_tensor, nv2Remove_tensor)
return vertexOut.numpy(), faceOut.numpy(), isDegenerate.numpy(), repOut.numpy(), mapOut.numpy(), nvOut.numpy(), mfOut.numpy()
# 加载模型时传递 SessionOptions
so = ort.SessionOptions()
so.register_custom_ops_library(_get_library_path())
session = ort.InferenceSession("./simplify.onnx", so, providers=['CPUExecutionProvider'])
vertexIn_in = vertexIn.numpy()
faceIn_in = faceIn.numpy()
geometryIn_in = geometryIn.numpy()
nvIn_cumsum_in = nvIn_cumsum.numpy()
mfIn_cumsum_in = mfIn_cumsum.numpy()
nv2Remove_in = nv2Remove.numpy()
# 运行推理
ort_output = session.run(
output_names=["vertexOut", "faceOut", "isDegenerate", "repOut", "mapOut", "nvOut", "mfOut"],
input_feed={"vertexIn": vertexIn_in,
"faceIn": faceIn_in,
"geometryIn": geometryIn_in,
"nvIn_cumsum": nvIn_cumsum_in,
"mfIn_cumsum": mfIn_cumsum_in,
"nv2Remove": nv2Remove_in,
}
)[0]
# pytorch_output = model(vertexIn , faceIn, geometryIn, nvIn_cumsum, mfIn_cumsum, nv2Remove)
print("自定义算子输出: ", ort_output[0].shape)
vertexOut, faceOut, isDegenerate, repOut, mapOut, nvOut, mfOut = model(vertexIn , faceIn, geometryIn, nvIn_cumsum, mfIn_cumsum, nv2Remove)
print(vertexOut.shape)
我这段代码写的对吗
为什么运行后
2025-06-14 18:50:50.964957476 [W:onnxruntime:, execution_frame.cc:870 VerifyOutputSizes] Expected shape from model of {-1,-1} does not match actual shape of {3738} for output isDegenerate
2025-06-14 18:50:50.965330776 [W:onnxruntime:, execution_frame.cc:870 VerifyOutputSizes] Expected shape from model of {-1,-1} does not match actual shape of {1865} for output repOut
2025-06-14 18:50:50.965528976 [W:onnxruntime:, execution_frame.cc:870 VerifyOutputSizes] Expected shape from model of {-1,-1} does not match actual shape of {1865} for output mapOut
2025-06-14 18:50:50.965745576 [W:onnxruntime:, execution_frame.cc:870 VerifyOutputSizes] Expected shape from model of {-1,-1} does not match actual shape of {1} for output nvOut
2025-06-14 18:50:50.965772376 [W:onnxruntime:, execution_frame.cc:870 VerifyOutputSizes] Expected shape from model of {-1,-1} does not match actual shape of {1} for output mfOut
自定义算子输出: (3,)
torch.Size([1865, 3])