Pytorch模型转换,pth->onnx->trt(TensorRT engine)

本文档详细解析了在将PyTorch模型(pth格式)转换为ONNX格式过程中遇到的常见问题,包括模型输入输出不支持字典类型、tensor.size(dim)操作导致的Gather[0]问题及求余算符在导出ONNX时不支持等,并提供了相应的解决策略。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

pth->onnx常见问题


##模型输入输出不支持字典
在使用torch.onnx导出onnx格式时,模型的输入和输出都不支持字典类型的数据结构。


**解决方法:**
此时,可以将字典数据结构换为torch.onnx支持的列表或者元组。例如:
heads {'hm': 1, 'wh': 2, 'hps': 34, 'reg': 2, 'hm_hp': 17, 'hp_offset': 2}
Centerpose中的字典在导出onnx时将会报错,可以将字典拆为两个列表,然后模型代码中做相应修改。


##tensor.size(dim)操作导致的转出的onnx操作中带有Gather[0]
如果在网络中存在类似tensor.size(dim)或者tensor.shape[1]这种索引操作,就会在转出的onnx模型中产生出Gather[0]操作,而这个操作,在TensorRT中是不支持的。


**解决方法:**
尽量避免类似的操作,代码中不要出现与网络操作无关的tensor,一些尺寸之类的常数全部写成普通变量。


##求余算符在导出onnx时不支持
曾经在导出求余算符‘%’时出现如下报错,是因为此运算符为ATen操作符,但是未在torch/onnx/symbolic_opset9.py中定义。
/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py:562: UserWarning: ONNX export failed on ATen operator remainder because torch.onnx.symbolic_opset9.remainder does not exist
  .format(op_name, opset_version, op_name))
Traceback (most recent call last):
  File "transfer_centernet_dlav0_to_onnx2.py", line 65, in <module>
    onnx_module = torch.onnx.export(model,img,'multipose_dlav0_1x_modellast_process.onnx',verbose=True)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 132, in export
    strip_doc_string, dynamic_axes)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 64, in export
    example_outputs=example_outputs, strip_doc_string=strip_doc_string, dynamic_axes=dynamic_axes)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 329, in _export
    _retain_param_name, do_constant_folding)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 225, in _model_to_graph
    _disable_torch_constant_prop=_disable_torch_constant_prop)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 127, in _optimize_graph
    graph = torch._C._jit_pass_onnx(graph, operator_export_type)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/__init__.py", line 163, in _run_symbolic_function
    return utils._run_symbolic_function(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/utils.py", line 563, in _run_symbolic_function
    op_fn = sym_registry.get_registered_op(op_name, '', opset_version)
  File "/usr/local/lib/python3.6/dist-packages/torch/onnx/symbolic_registry.py", line 91, in get_registered_op
    return _registry[(domain, version)][opname]
KeyError: 'remainder'


**解决方法:**
在torch/onnx/symbolic_opset9.py中补上自己的remainder的定义,添加的代码如下
@parse_args( 'v', 'v')
def remainder(g,input,division):
    return g.op("Remainder",input,division)
 
PyTorch模型可以通过以下步骤转换TensorRT: 1. 定义PyTorch模型并加载预训练权重。 2. 使用ONNXPyTorch模型转换ONNX格式。 3. 使用TensorRTONNX解析器将ONNX模型转换TensorRT引擎。 4. 在TensorRT引擎上执行推理。 下面是一个简单的示例代码: ```python import torch import tensorrt as trt import numpy as np import onnx import onnx_tensorrt.backend as backend # 定义PyTorch模型并加载预训练权重 class MyModel(torch.nn.Module): def __init__(self): super(MyModel, self).__init__() self.fc1 = torch.nn.Linear(10, 5) self.fc2 = torch.nn.Linear(5, 2) def forward(self, x): x = self.fc1(x) x = torch.nn.functional.relu(x) x = self.fc2(x) return x model = MyModel() model.load_state_dict(torch.load('model.pth')) # 使用ONNXPyTorch模型转换ONNX格式 dummy_input = torch.randn(1, 10) input_names = ['input'] output_names = ['output'] torch.onnx.export(model, dummy_input, 'model.onnx', verbose=False, input_names=input_names, output_names=output_names) # 使用TensorRTONNX解析器将ONNX模型转换TensorRT引擎 onnx_model = onnx.load('model.onnx') engine = backend.prepare(onnx_model, device='CUDA:0') # 在TensorRT引擎上执行推理 input_data = np.random.randn(1, 10).astype(np.float32) output_data = engine.run(input_data)[0] ``` 在上面的代码中,我们首先定义了一个简单的PyTorch模型,并加载了预训练权重。然后,我们使用`torch.onnx.export()`将PyTorch模型转换ONNX格式,并定义了输入和输出的名称。接着,我们使用ONNX解析器将ONNX模型转换TensorRT引擎。最后,我们使用TensorRT引擎执行推理,得到输出数据。
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值