pytorch版本导出onnx的代码大致雷同,op12,定义input,outname,唯一注意的是最后一层argmax不要onnx导出。这个op用onnx导出直接导致推理结果全为0.
如果还有提示pad的错误,可以直接修改inputsize为7的整数倍,去掉pad部分的代码。
tensorrt版本本次选用8.2.3,不需要自定义plugin,layernorm ,gelu都能正确解析推理。
argmax的后处理直接用c++重写,适用于分割。
本文介绍了如何将PyTorch模型导出为ONNX格式,特别提到argmax操作不建议直接导出,因为会导致推理结果出错。解决方法是使用C++重写argmax的后处理步骤。此外,针对TensorRT 8.2.3,无需自定义插件,layernorm和gelu等操作能够成功解析并进行推理。
pytorch版本导出onnx的代码大致雷同,op12,定义input,outname,唯一注意的是最后一层argmax不要onnx导出。这个op用onnx导出直接导致推理结果全为0.
如果还有提示pad的错误,可以直接修改inputsize为7的整数倍,去掉pad部分的代码。
tensorrt版本本次选用8.2.3,不需要自定义plugin,layernorm ,gelu都能正确解析推理。
argmax的后处理直接用c++重写,适用于分割。
您可能感兴趣的与本文相关的镜像
PyTorch 2.5
PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理
2876
4128