gather_points 自定义 TensorRT 算子(插件)编写

部署运行你感兴趣的模型镜像

导出onnx模型

import torch
import numpy as np
import pointnet2_utils

    
class CustomModel(torch.nn.Module):
    """Parameter-free module wrapping pointnet2's ``gather_operation`` so it can be exported to ONNX.

    ``forward(features, idx)`` passes both tensors straight through to
    ``pointnet2_utils.gather_operation``. Presumably ``features`` is
    (B, C, N) float32 and ``idx`` is (B, M) int32, giving a (B, C, M) result
    as in the CUDA kernel of this article — TODO confirm against pointnet2.
    """

    def __init__(self):
        super().__init__()

    def forward(self, features, idx):
        """Return the columns of ``features`` selected by ``idx``."""
        gathered = pointnet2_utils.gather_operation(features, idx)
        return gathered
    
    
# Export a minimal model that contains only pointnet2's gather_operation, and
# dump the random inputs used so the TensorRT engine can be checked against them.
NUM_POINTS = 20000  # N: points per cloud
NUM_SAMPLES = 2048  # M: indices gathered per cloud

model = CustomModel().cuda()
features = torch.randn(1, 3, NUM_POINTS).cuda()
idx = torch.randint(0, NUM_POINTS, (1, NUM_SAMPLES), dtype=torch.int32).cuda()

# The inference script loads "features.txt" and "idx.txt" (the indices were
# previously written to "id.txt", which it never reads). Indices are written as
# plain integers so they round-trip exactly.
np.savetxt("features.txt", features.reshape(3, NUM_POINTS).detach().cpu().numpy())
np.savetxt("idx.txt", idx.reshape(1, NUM_SAMPLES).detach().cpu().numpy(), fmt="%d")

torch.onnx.export(model, (features, idx), "gather_operation.onnx", opset_version=13)

其中pointnet2_utils来自https://github.com/erikwijmans/Pointnet2_PyTorch
导出的onnx模型结构如下(原文此处为两张模型结构截图,图片未能保留):

编写tensorrt插件

采用TensorRT-10.6.0.26。由于TensorRT只有部分开源(插件、解析器等源码公开,核心库只以二进制形式发布),需要先在https://developer.nvidia.com/tensorrt/download/10x下载TensorRT-10.6.0.26的预编译库,再在https://github.com/NVIDIA/TensorRT/tree/v10.6.0下载对应版本的开源代码。
在TensorRT/plugin下新建gatherPoints文件夹,添加下面文件:
gatherPoints.h

#ifndef TRT_GATHERPOINTS_PLUGIN_H
#define TRT_GATHERPOINTS_PLUGIN_H


#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "common/plugin.h"
#include "common/cuda_utils.h"

#include <vector>
#include <cstring>

namespace nvinfer1
{
namespace plugin
{
// Launches the gather kernel (defined in gatherPoints.cu) on `stream`.
// Per the kernel: out[b][c][j] = points[b][c][idx[b][j]], with points laid out
// as (b, c, n), idx as (b, npoints) and out as (b, c, npoints), all contiguous.
void gather_points_kernel_wrapper(int b, int c, int n, int npoints, const float *points, const int *idx,float *out, cudaStream_t stream);

// TensorRT plugin "gather_points": out[b][c][j] = points[b][c][idx[b][j]].
//   input 0: points (B, C, N) float32
//   input 1: idx    (B, M)    int32
//   output : out    (B, C, M) float32
class GatherPoints : public nvinfer1::IPluginV2DynamicExt
{
public:
    GatherPoints();

    // Deserialization constructor (receives the bytes produced by serialize()).
    GatherPoints(void const* data, size_t length);

    ~GatherPoints() override;

    // Plugin identity
    char const* getPluginType() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    int getNbOutputs() const noexcept override;

    // Output shape (B, C, M), derived from the input shapes
    nvinfer1::DimsExprs getOutputDimensions(int outputIndex,
        const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override;

    // Initialization / teardown
    int initialize() noexcept override;
    void terminate() noexcept override;

    // Execution
    int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
        const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
    size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
        int nbInputs, const nvinfer1::PluginTensorDesc* outputs,
        int nbOutputs) const noexcept override;

    // Data types and formats: points float32, idx int32, output float32, all linear
    DataType getOutputDataType(int index, DataType const* inputTypes, int nbInputs) const noexcept override;
    bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override;

    // Build-time configuration
    void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
        const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override;

    // Serialization
    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;

    // Other interface methods
    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
    void destroy() noexcept override;
    void setPluginNamespace(char const* libNamespace) noexcept override;
    char const* getPluginNamespace() const noexcept override;

private:
    // Namespace this instance is registered under; the plugin has no other state
    // (all shapes are read from the tensor descriptors at enqueue time).
    std::string mPluginNamespace;
};

// Factory registered with TensorRT for the "gather_points" plugin.
class GatherPointsCreator : public nvinfer1::IPluginCreator
{
public:
    GatherPointsCreator();
    ~GatherPointsCreator() override = default;

    // Identity: the same name/version strings GatherPoints reports.
    char const* getPluginName() const noexcept override;
    char const* getPluginVersion() const noexcept override;

    // Attribute schema (empty: the constructor publishes no fields).
    PluginFieldCollection const* getFieldNames() noexcept override;

    // Factories: build a fresh plugin, or rebuild one from a serialized engine.
    IPluginV2* createPlugin(char const* name, PluginFieldCollection const* fc) noexcept override;
    IPluginV2* deserializePlugin(char const* name, void const* serialData, size_t serialLength) noexcept override;

    // Namespace handling
    void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
    char const* getPluginNamespace() const noexcept override;

private:
    // Field table shared by all creator instances.
    static PluginFieldCollection mFC;
    static std::vector<PluginField> mPluginAttributes;
    // Namespace handed to the plugins this creator builds.
    std::string mNamespace;
};

} // namespace plugin
} // namespace nvinfer1

#endif // TRT_GATHERPOINTS_PLUGIN_H

gatherPoints.cpp

#include "gatherPoints.h"
#include "common/dimsHelpers.h"

using namespace nvinfer1;
using namespace nvinfer1::pluginInternal;
using nvinfer1::plugin::GatherPoints;
using nvinfer1::plugin::GatherPointsCreator;

// Plugin implementation

// Default constructor: there is no configurable state to set up.
GatherPoints::GatherPoints() = default;

// Deserialization constructor. serialize() writes no payload, so there is
// nothing to read back and both arguments are intentionally ignored.
GatherPoints::GatherPoints(void const* /*data*/, size_t /*length*/)
{
}

// Nothing to release: the plugin allocates no buffers of its own.
GatherPoints::~GatherPoints() = default;

// Plugin identity

// Type name; GatherPointsCreator::getPluginName() returns the same string.
char const* GatherPoints::getPluginType() const noexcept
{
    static constexpr char kPluginType[] = "gather_points";
    return kPluginType;
}

// Version string; GatherPointsCreator::getPluginVersion() returns the same value.
char const* GatherPoints::getPluginVersion() const noexcept
{
    static constexpr char kPluginVersion[] = "1";
    return kPluginVersion;
}

// A single output: the gathered feature tensor.
int GatherPoints::getNbOutputs() const noexcept
{
    constexpr int kNumOutputs = 1;
    return kNumOutputs;
}

// Output shape (B, C, M):
//   B, C come from the points tensor inputs[0] = (B, C, N),
//   M comes from the index tensor    inputs[1] = (B, M).
// The dimension expressions are forwarded as-is rather than folded through
// getConstantValue(), which has no meaningful value for non-constant
// (dynamic) dimensions; the plugin therefore works with dynamic shapes too.
nvinfer1::DimsExprs GatherPoints::getOutputDimensions(int outputIndex,
    const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    PLUGIN_ASSERT(outputIndex == 0 && nbInputs == 2);
    PLUGIN_ASSERT(inputs[0].nbDims == 3 && inputs[1].nbDims == 2);
    (void) exprBuilder; // no new expressions need to be built

    nvinfer1::DimsExprs outputDims;
    outputDims.nbDims = 3;
    outputDims.d[0] = inputs[0].d[0]; // B
    outputDims.d[1] = inputs[0].d[1]; // C
    outputDims.d[2] = inputs[1].d[1]; // M
    return outputDims;
}

// Initialization: no resources to acquire before execution, so it always succeeds.
int GatherPoints::initialize() noexcept { return STATUS_SUCCESS; }

// Counterpart of initialize(): nothing was acquired, so nothing is released.
void GatherPoints::terminate() noexcept
{
    // intentionally empty
}

// Runs the gather on `stream`:
//   inputs[0]  points (B, C, N) float32
//   inputs[1]  idx    (B, M)    int32
//   outputs[0] out    (B, C, M) float32
// Returns STATUS_SUCCESS, or -1 if the descriptors are not what the plugin
// supports. Empty tensors (B, C or M == 0) are a no-op.
int GatherPoints::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
    const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
    (void) workspace; // getWorkspaceSize() requests none
    try
    {
        // PLUGIN_VALIDATE reports a failure by throwing, so it reaches the catch
        // below and enqueue returns an error. PLUGIN_ASSERT does not throw, which
        // made the catch block unreachable for these checks.
        PLUGIN_VALIDATE(inputDesc[0].dims.nbDims == 3);  // points: [B, C, N]
        PLUGIN_VALIDATE(inputDesc[1].dims.nbDims == 2);  // idx:    [B, M]
        PLUGIN_VALIDATE(outputDesc[0].dims.nbDims == 3); // out:    [B, C, M]

        PLUGIN_VALIDATE(inputDesc[0].type == nvinfer1::DataType::kFLOAT);
        PLUGIN_VALIDATE(inputDesc[1].type == nvinfer1::DataType::kINT32);
        PLUGIN_VALIDATE(outputDesc[0].type == nvinfer1::DataType::kFLOAT);

        // idx must have one row of indices per batch entry of points.
        PLUGIN_VALIDATE(inputDesc[1].dims.d[0] == inputDesc[0].dims.d[0]);

        // Dims entries may be wider than int; the kernel wrapper takes int.
        int const B = static_cast<int>(inputDesc[0].dims.d[0]); // batch size
        int const C = static_cast<int>(inputDesc[0].dims.d[1]); // channels per point
        int const N = static_cast<int>(inputDesc[0].dims.d[2]); // points per cloud
        int const M = static_cast<int>(inputDesc[1].dims.d[1]); // points to gather

        // Nothing to gather: skip the launch, whose grid/block would be empty.
        if (B == 0 || C == 0 || M == 0)
        {
            return STATUS_SUCCESS;
        }

        const float* points = static_cast<const float*>(inputs[0]);
        const int* idx = static_cast<const int*>(inputs[1]);
        float* out = static_cast<float*>(outputs[0]);

        // NOTE(review): index values are not range-checked against N here.
        gather_points_kernel_wrapper(B, C, N, M, points, idx, out, stream);

        return STATUS_SUCCESS;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return -1;
}

// Workspace size: enqueue() writes directly into the output buffer, so no
// scratch memory is requested.
size_t GatherPoints::getWorkspaceSize(const nvinfer1::PluginTensorDesc* /*inputs*/,
    int /*nbInputs*/, const nvinfer1::PluginTensorDesc* /*outputs*/,
    int /*nbOutputs*/) const noexcept
{
    constexpr size_t kNoWorkspace = 0;
    return kNoWorkspace;
}

// Output data type: the gathered values are float32, like the input points.
// The int32 index tensor only selects which values are copied.
DataType GatherPoints::getOutputDataType(int index, DataType const* /*inputTypes*/, int nbInputs) const noexcept
{
    PLUGIN_ASSERT(index == 0 && nbInputs == 2);
    constexpr DataType kOutputType = DataType::kFLOAT;
    return kOutputType;
}

// Supported formats:
//   pos 0 = points (float32), pos 1 = idx (int32), pos 2 = output (float32).
// Every tensor must use the linear layout.
bool GatherPoints::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept
{
    // The plugin has 2 inputs and 1 output.
    PLUGIN_ASSERT(pos < nbInputs + nbOutputs);

    // True when the tensor at `pos` has the expected type in linear format.
    auto const isLinearOfType = [&](nvinfer1::DataType expected) {
        return inOut[pos].type == expected && inOut[pos].format == nvinfer1::PluginFormat::kLINEAR;
    };

    switch (pos)
    {
    case 0: // points
    case 2: // output
        return isLinearOfType(nvinfer1::DataType::kFLOAT);
    case 1: // idx
        return isLinearOfType(nvinfer1::DataType::kINT32);
    default:
        return false;
    }
}

// Build-time sanity check of the I/O ranks: points (B, C, N), idx (B, M),
// out (B, C, M). The channel count C is deliberately not restricted. The
// kernel gathers every channel, and a dynamic dimension can show up here as
// -1, which the previous "C == 3" check rejected.
// Failures are reported through caughtError() instead of terminating the process.
void GatherPoints::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept
{
    try
    {
        PLUGIN_VALIDATE(nbInputs == 2 && nbOutputs == 1);
        PLUGIN_VALIDATE(in[0].desc.dims.nbDims == 3);
        PLUGIN_VALIDATE(in[1].desc.dims.nbDims == 2);
        PLUGIN_VALIDATE(out[0].desc.dims.nbDims == 3);
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
}

// Serialization size. It must match what serialize() writes. serialize() writes
// nothing, because the plugin carries no state, so the size is 0. Reporting
// sizeof(int) made the engine contain bytes that were never written.
size_t GatherPoints::getSerializationSize() const noexcept
{
    return 0;
}

// Serializes the plugin state into `buffer`. The plugin has no state to
// persist, so nothing is written.
void GatherPoints::serialize(void* buffer) const noexcept
{
    static_cast<void>(buffer);
}


// Clone: returns a new, equivalent plugin instance. The namespace is the only
// member the plugin's methods use, so it is copied over. Previously it was
// lost on the clone. Returns nullptr if allocation fails.
nvinfer1::IPluginV2DynamicExt* GatherPoints::clone() const noexcept
{
    try
    {
        auto* plugin = new GatherPoints();
        plugin->setPluginNamespace(mPluginNamespace.c_str());
        return plugin;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}

// Releases this instance. Instances are always created with `new` (see clone()
// and GatherPointsCreator), so the object deletes itself.
void GatherPoints::destroy() noexcept
{
    delete this;
}

// Namespace handling

// Stores the namespace this instance belongs to. A null pointer is treated as
// the empty namespace, because assigning nullptr to std::string is undefined
// behavior.
void GatherPoints::setPluginNamespace(char const* pluginNamespace) noexcept
{
    try
    {
        mPluginNamespace = (pluginNamespace != nullptr) ? pluginNamespace : "";
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
}

// Namespace last stored via setPluginNamespace() (empty if never set).
char const* GatherPoints::getPluginNamespace() const noexcept
{
    return mPluginNamespace.data();
}

// Plugin creator implementation

// Attribute table shared by all creator instances; the constructor fills it in.
PluginFieldCollection GatherPointsCreator::mFC{};
std::vector<PluginField> GatherPointsCreator::mPluginAttributes{};

// The plugin takes no creation attributes, so an empty field list is published.
GatherPointsCreator::GatherPointsCreator()
{
    mPluginAttributes.clear();
    mFC.fields = mPluginAttributes.data();
    mFC.nbFields = mPluginAttributes.size();
}

// Plugin name; GatherPoints::getPluginType() returns the same string.
char const* GatherPointsCreator::getPluginName() const noexcept
{
    static constexpr char kPluginName[] = "gather_points";
    return kPluginName;
}

// Plugin version; GatherPoints::getPluginVersion() returns the same string.
char const* GatherPointsCreator::getPluginVersion() const noexcept
{
    static constexpr char kPluginVersion[] = "1";
    return kPluginVersion;
}

// Attribute schema; empty, as set up by the constructor.
PluginFieldCollection const* GatherPointsCreator::getFieldNames() noexcept
{
    PluginFieldCollection const* const fields = &mFC;
    return fields;
}

// Creates a new plugin instance. gather_points has no attributes, so `name` and
// `fc` are not inspected. The creator's namespace is propagated to the
// instance, as deserializePlugin() does. Returns nullptr if allocation fails.
IPluginV2* GatherPointsCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
    (void) name;
    (void) fc;
    try
    {
        auto* plugin = new GatherPoints();
        plugin->setPluginNamespace(mNamespace.c_str());
        return plugin;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}

// Rebuilds a plugin from the bytes written by GatherPoints::serialize().
// The bytes go through the deserialization constructor, so any state added to
// serialize() later is read back here. Returns nullptr if allocation fails.
IPluginV2* GatherPointsCreator::deserializePlugin(char const* name, void const* serialData, size_t serialLength) noexcept
{
    (void) name;
    try
    {
        // Ownership passes to TensorRT, which releases the object through
        // GatherPoints::destroy().
        auto* plugin = new GatherPoints(serialData, serialLength);
        plugin->setPluginNamespace(mNamespace.c_str());
        return plugin;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}

void GatherPointsCreator::setPluginNamespace(char const* libNamespace) noexcept
{  
	//std::cout<<"setPluginNamespace"<<std::endl;        
	mNamespace = libNamespace;
}

// Namespace last stored via setPluginNamespace() (empty if never set).
char const* GatherPointsCreator::getPluginNamespace() const noexcept
{
    return mNamespace.data();
}

gatherPoints.cu

#include <stdio.h>
#include <stdlib.h>

#include "NvInfer.h"
#include "gatherPoints.h"
#include <cuda_runtime.h>

namespace nvinfer1
{
namespace plugin
{
// input:  points (b, c, n), idx (b, m)
// output: out    (b, c, m)
// out[i][l][j] = points[i][l][idx[i][j]]. Blocks stride over batches
// (blockIdx.x) and channels (blockIdx.y); threads stride over the m samples.
// Index values are assumed to lie in [0, n); they are not range-checked.
__global__ void gather_points_kernel(int b, int c, int n, int m,
                                     const float *__restrict__ points,
                                     const int *__restrict__ idx,
                                     float *__restrict__ out) {
  for (int batch = blockIdx.x; batch < b; batch += gridDim.x) {
    const int *sample_idx = idx + batch * m;
    for (int ch = blockIdx.y; ch < c; ch += gridDim.y) {
      const int row = batch * c + ch;
      const float *src = points + row * n;
      float *dst = out + row * m;
      for (int k = threadIdx.x; k < m; k += blockDim.x) {
        dst[k] = src[sample_idx[k]];
      }
    }
  }
}

// Host launcher for gather_points_kernel. It uses a (b, c) grid, with
// opt_n_threads(npoints) threads per block, on `stream`.
// Empty problems return immediately: a zero-sized grid or block is an invalid
// launch configuration. opt_n_threads() is defined in a header not shown here
// and may not handle a work size of 0.
void gather_points_kernel_wrapper(int b, int c, int n, int npoints,
                                  const float *points, const int *idx,
                                  float *out, cudaStream_t stream) {
  if (b <= 0 || c <= 0 || npoints <= 0) {
    return;
  }

  gather_points_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0, stream>>>(
      b, c, n, npoints, points, idx, out);

  CUDA_CHECK_ERRORS();
}

// input:  grad_out (b, c, m), idx (b, m)
// output: grad_points (b, c, n)
// Scatter-adds grad_out[i][l][j] into grad_points[i][l][idx[i][j]] with
// atomicAdd, so repeated indices accumulate. The kernel only adds, so
// grad_points presumably must be zeroed by the caller — TODO confirm.
__global__ void gather_points_grad_kernel(int b, int c, int n, int m,
                                          const float *__restrict__ grad_out,
                                          const int *__restrict__ idx,
                                          float *__restrict__ grad_points) {
  for (int batch = blockIdx.x; batch < b; batch += gridDim.x) {
    const int *sample_idx = idx + batch * m;
    for (int ch = blockIdx.y; ch < c; ch += gridDim.y) {
      const int row = batch * c + ch;
      const float *src = grad_out + row * m;
      float *dst = grad_points + row * n;
      for (int k = threadIdx.x; k < m; k += blockDim.x) {
        atomicAdd(dst + sample_idx[k], src[k]);
      }
    }
  }
}

// Host launcher for gather_points_grad_kernel. It uses the same launch shape as
// gather_points_kernel_wrapper. It is not called by the plugin code shown here,
// which only runs the forward gather.
// Empty problems return immediately; this avoids an invalid zero-sized launch
// and a work size of 0 reaching opt_n_threads().
void gather_points_grad_kernel_wrapper(int b, int c, int n, int npoints,
                                       const float *grad_out, const int *idx,
                                       float *grad_points, cudaStream_t stream) {
  if (b <= 0 || c <= 0 || npoints <= 0) {
    return;
  }

  gather_points_grad_kernel<<<dim3(b, c, 1), opt_n_threads(npoints), 0, stream>>>(
      b, c, n, npoints, grad_out, idx, grad_points);

  CUDA_CHECK_ERRORS();
}

}
}

CMakeLists.txt

# Collect this plugin's C++ and CUDA sources, add them to the lists
# accumulated by the parent plugin directory, and export the lists back up.
file(GLOB SRCS *.cpp)
file(GLOB CU_SRCS *.cu)

list(APPEND PLUGIN_SOURCES ${SRCS})
list(APPEND PLUGIN_CU_SOURCES ${CU_SRCS})

set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE)

在TensorRT/plugin/inferPlugin.cpp的开头添加

#include "gatherPoints/gatherPoints.h"

并在initLibNvInferPlugins函数中添加

 // Register the gather_points creator together with the library's other plugin creators.
 initializePlugin<nvinfer1::plugin::GatherPointsCreator>(logger, libNamespace);

在TensorRT/plugin/CMakeLists.txt的set(PLUGIN_LISTS添加

gatherPoints

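在TensorRT/CMakeLists.txt中设置TRT_LIB_DIR(指向下载的TensorRT-10.6.0.26/lib)和TRT_OUT_DIR(编译输出目录),再重新编译TensorRT开源部分。编译得到的libnvinfer_plugin.so中包含gather_points插件;运行trtexec和下面的Python推理脚本时,需要确保加载的是这个新编译的库(例如把TRT_OUT_DIR放在LD_LIBRARY_PATH中原lib目录之前),否则会报找不到该插件。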

tensorrt推理测试

运行下面的命令把onnx 转为engine模型:

TensorRT-10.6.0.26/bin/trtexec --onnx=gather_operation.onnx --saveEngine=gather_operation.engine

编写python推理脚本(其中的common模块即TensorRT发布包samples/python目录下的common.py,提供allocate_buffers和do_inference):

import numpy as np
import tensorrt as trt
import common


logger = trt.Logger(trt.Logger.WARNING)
# Register the plugins shipped in libnvinfer_plugin. gather_points is only
# available if the rebuilt library is the one loaded.
trt.init_libnvinfer_plugins(logger, "")
with open("gather_operation.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
# deserialize_cuda_engine returns None on failure, e.g. when the gather_points
# plugin creator is not registered. Fail here with a clear message instead of
# an AttributeError on the next line.
if engine is None:
    raise RuntimeError(
        "failed to deserialize gather_operation.engine "
        "(is the gather_points plugin registered?)"
    )
context = engine.create_execution_context()
inputs, outputs, bindings, stream = common.allocate_buffers(engine)

# Replay the inputs dumped by the export script:
# features (1, 3, 20000) float32 and idx (1, 2048) int32.
features = np.loadtxt("features.txt").reshape(1, 3, 20000).astype(np.float32)
idx = np.loadtxt("idx.txt").reshape(1, 2048).astype(np.int32)
np.copyto(inputs[0].host, features.ravel())
np.copyto(inputs[1].host, idx.ravel())

output = common.do_inference(
    context, engine=engine, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream
)
# Presumably do_inference returns one flat host array per output; restore the
# plugin's (B, C, M) output layout for readability.
print(np.asarray(output[0]).reshape(1, 3, 2048))

您可能感兴趣的与本文相关的镜像

TensorRT-v8.6

TensorRT-v8.6

TensorRT

TensorRT 是NVIDIA 推出的用于深度学习推理加速的高性能推理引擎。它可以将深度学习模型优化并部署到NVIDIA GPU 上,实现低延迟、高吞吐量的推理过程。

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

给算法爸爸上香

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值