mmpose hrnet python tensorrt推理

代码分为两个部分:

  • Tensorrt.py:Python tensorrt推理接口
  • trt_model:mmpose hrnet推理模型构建、前后处理

有两个注意点:
(1)注意图像预处理,标准化的参数值
(2)注意Opencv读取图像时,要将BGR转换为RGB格式

Tensorrt.py

import torch
import tensorrt as trt 
import os
from typing import Union, Optional, Sequence, Dict


class TrtModelMMPose(torch.nn.Module):
    def __init__(self, engine: Union[str, trt.ICudaEngine], output_names: Optional[Sequence[str]] = None) -> None:
        super().__init__()
        self.engine = engine
        if isinstance(self.engine, str):
            with trt.Logger() as logger, trt.Runtime(logger) as runtime:
                with open(self.engine, mode='rb') as f:
                    engine_bytes = f.read()
                self.engine = runtime.deserialize_cuda_engine(engine_bytes)
        self.context = self.engine.create_execution_context()
        names = [_ for _ in self.engine]
        input_names = list(filter(self.engine.binding_is_input, names))
        self._input_names = input_names
        self._output_names = output_names

        if self._output_names is None:
            output_names = list(set(names) - set(input_names))
            self._output_names = output_names

    def forward(self, inputs: Dict[str, torch.Tensor]):
        assert self._input_names is not None
        assert self._output_names is not None
        bindings = [None] * (len(self._input_names) + len(self._output_names))
        profile_id = 0
        for input_name, input_tensor in inputs.items():
            # check if input shape is valid
            input_name = 'input'
            profile = self.engine.get_profile_shape(profile_id, input_name)
            assert 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值