代码分为两个部分:
- Tensorrt.py:TensorRT 推理的 Python 接口
- trt_model:MMPose HRNet 推理模型的构建,以及图像的前处理与后处理
需要注意两点:
(1)图像预处理时,标准化(normalize)所用的均值(mean)和标准差(std)参数必须与模型训练时保持一致;
(2)OpenCV 读取的图像为 BGR 格式,送入模型前需要转换为 RGB 格式。
Tensorrt.py
import torch
import tensorrt as trt
import os
from typing import Union, Optional, Sequence, Dict
class TrtModelMMPose(torch.nn.Module):
def __init__(self, engine: Union[str, trt.ICudaEngine], output_names: Optional[Sequence[str]] = None) -> None:
    """Wrap a TensorRT engine and create an execution context for it.

    Args:
        engine: Either a filesystem path to a serialized TensorRT engine,
            or an already deserialized ``trt.ICudaEngine``.
        output_names: Names of the output bindings to read back. When
            ``None``, every binding that is not an input is used, in the
            same order the engine enumerates its bindings.

    Raises:
        RuntimeError: If ``engine`` is a path and TensorRT fails to
            deserialize the file.
    """
    super().__init__()
    self.engine = engine
    if isinstance(self.engine, str):
        with trt.Logger() as logger, trt.Runtime(logger) as runtime:
            with open(self.engine, mode='rb') as f:
                engine_bytes = f.read()
            self.engine = runtime.deserialize_cuda_engine(engine_bytes)
        # deserialize_cuda_engine returns None on failure instead of raising;
        # fail here with a clear message rather than later with an
        # AttributeError on create_execution_context.
        if self.engine is None:
            raise RuntimeError(f'Failed to deserialize TensorRT engine from {engine!r}')
    self.context = self.engine.create_execution_context()
    # Iterating the engine yields its binding names.
    names = [name for name in self.engine]
    input_names = list(filter(self.engine.binding_is_input, names))
    self._input_names = input_names
    self._output_names = output_names
    if self._output_names is None:
        # Filter instead of using a set difference: set iteration order of
        # strings depends on per-process hash randomization, which would make
        # the output order nondeterministic across runs.
        input_set = set(input_names)
        self._output_names = [name for name in names if name not in input_set]
def forward(self, inputs: Dict[str, torch.Tensor]):
assert self._input_names is not None
assert self._output_names is not None
bindings = [None] * (len(self._input_names) + len(self._output_names))
profile_id = 0
for input_name, input_tensor in inputs.items():
# check if input shape is valid
input_name = 'input'
profile = self.engine.get_profile_shape(profile_id, input_name)
assert