knnQuery 自定义 TensorRT 算子(插件)编写

部署运行你感兴趣的模型镜像

导出 ONNX 模型

import torch
import numpy as np
import pointops
    
    
class CustomModel(torch.nn.Module):
    """Parameter-free wrapper that exposes ``pointops.knnquery`` as a module.

    Wrapping the call in an ``nn.Module`` lets ``torch.onnx.export`` trace it.
    """

    def __init__(self):
        super().__init__()

    def forward(self, nsample, xyz, new_xyz, offset, new_offset):
        """Forward all arguments unchanged to ``pointops.knnquery`` and return its result."""
        return pointops.knnquery(nsample, xyz, new_xyz, offset, new_offset)
    

NUM_POINTS = 47358  # points in both the reference cloud and the query cloud

model = CustomModel().cuda()
nsample = [8]
xyz = torch.randn(NUM_POINTS, 3).cuda()
new_xyz = torch.randn(NUM_POINTS, 3).cuda()
# Single batch: each offset tensor holds the cumulative point count of that batch.
offset = torch.tensor([NUM_POINTS]).cuda().to(torch.int32)
new_offset = torch.tensor([NUM_POINTS]).cuda().to(torch.int32)

# Dump both clouds so the TensorRT inference script can reload the same data.
for path, cloud in (("xyz.txt", xyz), ("new_xyz.txt", new_xyz)):
    np.savetxt(path, cloud.reshape(NUM_POINTS, 3).detach().cpu().numpy())

example_inputs = (nsample, xyz, new_xyz, offset, new_offset)
torch.onnx.export(model, example_inputs, "knnquery.onnx", opset_version=13)

导出的 ONNX 模型结构如下(原文此处为两张模型结构截图,图片未能保留):

编写 TensorRT 插件

本文使用 TensorRT-10.6.0.26。由于 TensorRT 只是部分开源(插件、解析器等源码在 GitHub 上公开,核心库以预编译二进制形式发布),需要先在 https://developer.nvidia.com/tensorrt/download/10x 下载 TensorRT-10.6.0.26 的预编译库,再在 https://github.com/NVIDIA/TensorRT/tree/v10.6.0 下载对应版本的源代码。
在 TensorRT/plugin 目录下新建 knnQuery 文件夹,并添加以下文件:
knnQuery.h

#ifndef TRT_KNNQUERY_PLUGIN_H
#define TRT_KNNQUERY_PLUGIN_H


#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "common/plugin.h"
#include "common/cuda_utils.h"

#include <vector>
#include <cstring>

namespace nvinfer1
{
namespace plugin
{

// Host-side launcher, defined in knnQuery.cu. For each of the m query points in
// new_xyz it writes the nsample nearest points of xyz into idx (indices) and
// dist2 (squared distances); offset/new_offset delimit the batches.
void knnquery_cuda_launcher(int m, int nsample, const float *xyz, const float *new_xyz, const int *offset, const int *new_offset, int *idx, float *dist2, cudaStream_t stream);

// TensorRT dynamic-shape plugin wrapping the pointops-style KNN query.
// Inputs : nsample, xyz (n, 3), new_xyz (m, 3), offset, new_offset.
// Outputs: idx (m, nsample) INT32 and dist2 (m, nsample) FLOAT.
class knnQuery : public nvinfer1::IPluginV2DynamicExt
{
public:
    // Build from the neighbour count (nsample) supplied as a plugin attribute.
    knnQuery(int sample);

    // Rebuild from a blob previously written by serialize().
    knnQuery(void const* data, size_t length);

    ~knnQuery() override;

    // Plugin identity
    char const* getPluginType() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    int getNbOutputs() const noexcept override;

    // Output shape computation
    nvinfer1::DimsExprs getOutputDimensions(int outputIndex,
        const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override;

    // Initialization and teardown
    int initialize() noexcept override;
    void terminate() noexcept override;

    // Execution
    int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
        const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
    size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
        int nbInputs, const nvinfer1::PluginTensorDesc* outputs,
        int nbOutputs) const noexcept override;

    // Data type and format support
    DataType getOutputDataType(int index, DataType const* inputTypes, int nbInputs) const noexcept override;
    bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override;

    // Configuration
    void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
        const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override;

    // Serialization
    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;

    // Miscellaneous
    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
    void destroy() noexcept override;
    void setPluginNamespace(char const* libNamespace) noexcept override;
    char const* getPluginNamespace() const noexcept override;

private:
    int mSample;                   // neighbours returned per query point (nsample); the only serialized state
    std::string mPluginNamespace;  // value reported by getPluginNamespace()
};

// Factory that builds knnQuery instances from attributes or from serialized data.
class knnQueryCreator : public nvinfer1::IPluginCreator
{
public:
    knnQueryCreator();
    ~knnQueryCreator() override = default;

    // Identity
    char const* getPluginName() const noexcept override;
    char const* getPluginVersion() const noexcept override;

    // Attribute list (mFC), filled by the constructor
    PluginFieldCollection const* getFieldNames() noexcept override;

    // Factories: from an attribute collection, or from a serialized blob
    IPluginV2* createPlugin(char const* name, PluginFieldCollection const* fc) noexcept override;
    IPluginV2* deserializePlugin(char const* name, void const* serialData, size_t serialLength) noexcept override;

    // Namespace handling
    void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
    char const* getPluginNamespace() const noexcept override;

private:
    static PluginFieldCollection mFC;                   // attribute collection returned by getFieldNames()
    static std::vector<PluginField> mPluginAttributes;  // storage that mFC.fields points into
    std::string mNamespace;                             // namespace applied to deserialized plugins
};

} // namespace plugin
} // namespace nvinfer1

#endif // TRT_KNNQUERY_PLUGIN_H

knnQuery.cpp

#include "knnQuery.h"
#include "common/dimsHelpers.h"

using namespace nvinfer1;
using namespace nvinfer1::pluginInternal;
using nvinfer1::plugin::knnQuery;
using nvinfer1::plugin::knnQueryCreator;

// 插件实现
// Construct from the neighbour count (nsample).
knnQuery::knnQuery(int sample) : mSample(sample)
{
}

// Rebuild from the blob written by serialize(): exactly one int holding nsample.
// Throws (via PLUGIN_VALIDATE) on a null or wrongly sized blob; deserializePlugin
// catches the exception and returns nullptr.
knnQuery::knnQuery(void const* data, size_t length)
{
    PLUGIN_VALIDATE(data != nullptr && length == sizeof(mSample));
    // memcpy avoids assuming the blob is suitably aligned for an int load.
    std::memcpy(&mSample, data, sizeof(mSample));
}

knnQuery::~knnQuery() = default;

// ---- Plugin identity --------------------------------------------------------
// Type name and version; knnQueryCreator returns the same two literals.
char const* knnQuery::getPluginType() const noexcept
{
    return "knnquery";
}

char const* knnQuery::getPluginVersion() const noexcept
{
    return "1";
}

// Two outputs: neighbour indices (INT32) and squared distances (FLOAT).
int knnQuery::getNbOutputs() const noexcept
{
    return 2;
}

// Both outputs (idx and dist2) have shape (m, nsample), where m is the number of
// query points, i.e. the first dimension of new_xyz (input 2).
nvinfer1::DimsExprs knnQuery::getOutputDimensions(int outputIndex,
    const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    PLUGIN_ASSERT(nbInputs == 5);
    PLUGIN_ASSERT(outputIndex == 0 || outputIndex == 1);

    nvinfer1::DimsExprs outputDims;
    outputDims.nbDims = 2;
    // Forward the symbolic expression instead of folding it with getConstantValue():
    // for a dynamic first dimension that value is not a real size, so the old
    // constant() produced a bogus output shape.
    outputDims.d[0] = inputs[2].d[0];
    outputDims.d[1] = exprBuilder.constant(mSample);
    return outputDims;
}

// The plugin holds only an int and a string, so there is nothing to set up
// or release: initialize() always succeeds and terminate() does nothing.
int knnQuery::initialize() noexcept
{
    return STATUS_SUCCESS;
}

void knnQuery::terminate() noexcept
{
}

// Run the KNN query on TensorRT's stream.
// Inputs : 0 nsample (INT32 scalar; not read, the neighbour count comes from mSample),
//          1 xyz (n, 3) FLOAT, 2 new_xyz (m, 3) FLOAT, 3 offset INT32, 4 new_offset INT32.
// Outputs: 0 idx (m, mSample) INT32, 1 dist2 (m, mSample) FLOAT.
// Returns STATUS_SUCCESS, or -1 if validation, a CUDA call or the kernel launch fails.
int knnQuery::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
    const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
    try
    {
        // Input ranks and point dimensionality.
        PLUGIN_ASSERT(inputDesc[0].dims.nbDims == 0);       // nsample (scalar)
        PLUGIN_ASSERT(inputDesc[1].dims.nbDims == 2);       // xyz: (total points, 3)
        PLUGIN_ASSERT(inputDesc[2].dims.nbDims == 2);       // new_xyz: (total query points, 3)
        PLUGIN_ASSERT(inputDesc[3].dims.nbDims == 1);       // offset: batch offsets of xyz
        PLUGIN_ASSERT(inputDesc[4].dims.nbDims == 1);       // new_offset: batch offsets of new_xyz
        PLUGIN_ASSERT(inputDesc[1].dims.d[1] == 3);         // xyz holds 3-D points
        PLUGIN_ASSERT(inputDesc[2].dims.d[1] == 3);         // new_xyz holds 3-D points

        // Data types.
        PLUGIN_ASSERT(inputDesc[0].type == nvinfer1::DataType::kINT32);   // nsample
        PLUGIN_ASSERT(inputDesc[1].type == nvinfer1::DataType::kFLOAT);   // xyz
        PLUGIN_ASSERT(inputDesc[2].type == nvinfer1::DataType::kFLOAT);   // new_xyz
        PLUGIN_ASSERT(inputDesc[3].type == nvinfer1::DataType::kINT32);   // offset
        PLUGIN_ASSERT(inputDesc[4].type == nvinfer1::DataType::kINT32);   // new_offset
        PLUGIN_ASSERT(outputDesc[0].type == nvinfer1::DataType::kINT32);  // neighbour indices
        PLUGIN_ASSERT(outputDesc[1].type == nvinfer1::DataType::kFLOAT);  // squared distances

        // The CUDA kernel keeps per-query candidates in fixed arrays of 100 entries;
        // a larger neighbour count would overflow them.
        PLUGIN_VALIDATE(mSample >= 0 && mSample <= 100);

        const float* xyz = static_cast<const float*>(inputs[1]);
        // Without new_xyz, query the reference cloud against itself.
        const float* newXyz = (inputs[2] != nullptr) ? static_cast<const float*>(inputs[2]) : xyz;
        const int* offset = static_cast<const int*>(inputs[3]);
        const int* newOffset = static_cast<const int*>(inputs[4]);
        const int m = static_cast<int>(inputDesc[2].dims.d[0]);  // total number of query points

        int* idx = static_cast<int*>(outputs[0]);        // (m, nsample) neighbour indices
        float* dist2 = static_cast<float*>(outputs[1]);  // (m, nsample) squared distances
        const size_t nbElements = static_cast<size_t>(m) * static_cast<size_t>(mSample);

        // Clear the outputs on `stream`; plain cudaMemset is not guaranteed to be
        // ordered with TensorRT's (possibly non-blocking) stream.
        if (cudaMemsetAsync(idx, 0, nbElements * sizeof(int), stream) != cudaSuccess
            || cudaMemsetAsync(dist2, 0, nbElements * sizeof(float), stream) != cudaSuccess)
        {
            return -1;
        }

        // An empty query set needs no kernel (and a zero-block grid is not a valid launch).
        if (m > 0)
        {
            knnquery_cuda_launcher(
                m,              // total number of query points
                mSample,        // neighbours per query (k)
                xyz,            // reference points
                newXyz,         // query points
                offset,         // batch offsets of xyz
                newOffset,      // batch offsets of new_xyz
                idx,            // output indices
                dist2,          // output squared distances
                stream          // TensorRT's CUDA stream
            );
            // Report launch failures instead of unconditionally claiming success.
            if (cudaGetLastError() != cudaSuccess)
            {
                return -1;
            }
        }
        return STATUS_SUCCESS;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return -1;
}

// No scratch memory is requested from TensorRT; the kernel writes straight into the outputs.
size_t knnQuery::getWorkspaceSize(const nvinfer1::PluginTensorDesc* /*inputs*/,
    int /*nbInputs*/, const nvinfer1::PluginTensorDesc* /*outputs*/,
    int /*nbOutputs*/) const noexcept
{
    return 0;
}

// Output 0 holds neighbour indices (INT32); output 1 holds squared distances (FLOAT).
// The previous version fell off the end of this non-void function for any other
// index (undefined behaviour); such an index is now rejected explicitly.
DataType knnQuery::getOutputDataType(int index, DataType const* inputTypes, int nbInputs) const noexcept
{
    PLUGIN_ASSERT(nbInputs == 5);
    PLUGIN_ASSERT(index == 0 || index == 1);
    return (index == 0) ? DataType::kINT32 : DataType::kFLOAT;
}


// Every tensor must be linear. Slot types, in I/O order:
//   inputs : 0 nsample INT32, 1 xyz FLOAT, 2 new_xyz FLOAT, 3 offset INT32, 4 new_offset INT32
//   outputs: 5 idx INT32, 6 dist2 FLOAT
// Any other position is rejected.
bool knnQuery::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept
{
    PLUGIN_ASSERT(pos < nbInputs + nbOutputs);

    nvinfer1::DataType expected;
    switch (pos)
    {
    case 0:
    case 3:
    case 4:
    case 5:
        expected = nvinfer1::DataType::kINT32;
        break;
    case 1:
    case 2:
    case 6:
        expected = nvinfer1::DataType::kFLOAT;
        break;
    default:
        return false;
    }

    const nvinfer1::PluginTensorDesc& desc = inOut[pos];
    return desc.type == expected && desc.format == nvinfer1::PluginFormat::kLINEAR;
}

// Nothing is cached at configuration time; only the I/O arity is checked.
void knnQuery::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* /*in*/, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* /*out*/, int nbOutputs) noexcept
{
    try
    {
        PLUGIN_ASSERT(nbInputs == 5 && nbOutputs == 2);
    }
    catch (std::exception const& error)
    {
        caughtError(error);
    }
}

// Serialized form: exactly one int, the neighbour count (mSample).
size_t knnQuery::getSerializationSize() const noexcept
{
    return sizeof(mSample);
}

// Write mSample into buffer. memcpy avoids assuming the buffer is suitably aligned
// for an int store (the old `*static_cast<int*>(buffer) = ...` did assume that).
void knnQuery::serialize(void* buffer) const noexcept
{
    std::memcpy(buffer, &mSample, sizeof(mSample));
}


// Deep copy: the neighbour count and the plugin namespace. Previously the namespace
// was dropped, so a clone reported an empty namespace.
nvinfer1::IPluginV2DynamicExt* knnQuery::clone() const noexcept
{
    try
    {
        auto* plugin = new knnQuery(mSample);
        plugin->setPluginNamespace(mPluginNamespace.c_str());
        return plugin;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}

// Instances are heap-allocated (clone, createPlugin and deserializePlugin use new),
// so the plugin deletes itself.
void knnQuery::destroy() noexcept
{
    delete this;
}

// Store the namespace. Building std::string from nullptr is undefined behaviour,
// so a null pointer is treated as the empty namespace.
void knnQuery::setPluginNamespace(char const* pluginNamespace) noexcept
{
    try
    {
        mPluginNamespace = (pluginNamespace != nullptr) ? pluginNamespace : "";
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
}

char const* knnQuery::getPluginNamespace() const noexcept
{
    return mPluginNamespace.c_str();
}

// ---- Plugin creator -----------------------------------------------------------
PluginFieldCollection knnQueryCreator::mFC{};
std::vector<PluginField> knnQueryCreator::mPluginAttributes;

// Advertise a single INT32 attribute named "attr"; createPlugin reads it as nsample.
knnQueryCreator::knnQueryCreator()
{
    mPluginAttributes = {nvinfer1::PluginField("attr", nullptr, nvinfer1::PluginFieldType::kINT32)};
    mFC.fields = mPluginAttributes.data();
    mFC.nbFields = static_cast<int32_t>(mPluginAttributes.size());
}

// Identify this creator; the literals mirror knnQuery::getPluginType/getPluginVersion.
char const* knnQueryCreator::getPluginName() const noexcept
{
    return "knnquery";
}

char const* knnQueryCreator::getPluginVersion() const noexcept
{
    return "1";
}

// The attribute collection filled by the constructor (one INT32 field, "attr").
PluginFieldCollection const* knnQueryCreator::getFieldNames() noexcept
{
    return &mFC;
}

// Build a knnQuery from the attribute collection.
// The neighbour count is read from the INT32 field "attr" and defaults to 0 when it
// is absent. Returns nullptr on invalid input:
//   - fc is null, or its field array is missing;
//   - nsample is outside [0, 100] (the CUDA kernel keeps at most 100 candidates per query).
// The new plugin inherits this creator's namespace.
IPluginV2* knnQueryCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
    try
    {
        PLUGIN_VALIDATE(fc != nullptr);
        PLUGIN_VALIDATE(fc->nbFields == 0 || fc->fields != nullptr);

        int sample = 0;  // default when "attr" is not supplied
        for (int i = 0; i < fc->nbFields; ++i)
        {
            const PluginField& field = fc->fields[i];
            // Accept only a named, non-null INT32 payload under "attr".
            if (field.name != nullptr && std::strcmp(field.name, "attr") == 0
                && field.type == PluginFieldType::kINT32 && field.data != nullptr)
            {
                sample = *static_cast<const int*>(field.data);
            }
        }
        PLUGIN_VALIDATE(sample >= 0 && sample <= 100);

        auto* plugin = new knnQuery(sample);
        plugin->setPluginNamespace(mNamespace.c_str());
        return plugin;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}

// Recreate a plugin from the blob produced by knnQuery::serialize() and tag it with
// this creator's namespace. Returns nullptr if construction throws.
IPluginV2* knnQueryCreator::deserializePlugin(char const* name, void const* serialData, size_t serialLength) noexcept
{
    try
    {
        // Released later through knnQuery::destroy().
        auto* restored = new knnQuery(serialData, serialLength);
        restored->setPluginNamespace(mNamespace.c_str());
        return restored;
    }
    catch (std::exception const& error)
    {
        caughtError(error);
    }
    return nullptr;
}

// Store the namespace applied to plugins this creator builds. Building std::string
// from nullptr is undefined behaviour, so a null pointer means "no namespace".
void knnQueryCreator::setPluginNamespace(char const* libNamespace) noexcept
{
    mNamespace = (libNamespace != nullptr) ? libNamespace : "";
}

char const* knnQueryCreator::getPluginNamespace() const noexcept
{
    return mNamespace.c_str();
}

knnQuery.cu

#include <stdio.h>
#include <stdlib.h>

#include "NvInfer.h"
#include "knnQuery.h"
#include <cuda_runtime.h>

namespace nvinfer1
{
namespace plugin
{

// Threads per block used by knnquery_cuda_launcher.
#define THREADS_PER_BLOCK 256
// Ceiling division. Every argument is parenthesized so expressions such as
// DIVUP(x, y + 1) expand correctly (the old form computed (x + y + 1 - 1) / y + 1).
#define DIVUP(a, b) (((a) + (b) - 1) / (b))

// Exchange two floats in place.
__device__ void swap_float(float *x, float *y)
{
    const float held = *y;
    *y = *x;
    *x = held;
}


// Exchange two ints in place.
__device__ void swap_int(int *x, int *y)
{
    const int held = *y;
    *y = *x;
    *x = held;
}


// Sift-down from the root over the first k entries, keeping the largest distance at
// dist[0] (max-heap). idx is permuted in lockstep with dist. Stops as soon as the
// parent is strictly greater than its larger child.
__device__ void reheap(float *dist, int *idx, int k)
{
    for (int parent = 0, c = 1; c < k; parent = c, c = 2 * c + 1)
    {
        // Pick the larger of the two children.
        if (c + 1 < k && dist[c + 1] > dist[c])
            ++c;
        if (dist[parent] > dist[c])
            return;
        swap_float(&dist[parent], &dist[c]);
        swap_int(&idx[parent], &idx[c]);
    }
}


// In-place heap sort of a max-heap of k entries into ascending order of dist,
// permuting idx alongside: repeatedly move the current maximum to the end of the
// unsorted region and restore the heap on what remains.
__device__ void heap_sort(float *dist, int *idx, int k)
{
    int last = k - 1;
    while (last > 0)
    {
        swap_float(dist, dist + last);
        swap_int(idx, idx + last);
        reheap(dist, idx, last);
        --last;
    }
}


// Return the batch that global point index `idx` belongs to.
// `offset` holds cumulative end positions per batch (batch b spans
// [offset[b-1], offset[b]), as the kernel's start/end computation assumes).
// The search had been commented out, which forced every point into batch 0 and gave
// wrong neighbours for any batch after the first.
// Precondition: the last offset exceeds idx. This holds when it equals the total
// number of query points m, because the kernel only calls this for idx < m.
__device__ int get_bt_idx(int idx, const int *offset)
{
    int i = 0;
    while (idx >= offset[i])
        ++i;
    return i;
}


// One thread per query point: for new_xyz[pt_idx] find the nsample nearest points of
// xyz within its batch and write their indices to idx and their squared Euclidean
// distances to dist2, in increasing order of distance.
// The batch index comes from get_bt_idx(pt_idx, new_offset); batch b of xyz spans
// [offset[b-1], offset[b]), with offset[-1] taken as 0.
// NOTE: candidates live in fixed-size per-thread arrays of 100 entries, so nsample
// must not exceed 100. If a batch holds fewer than nsample points, the leftover slots
// keep index `start` and distance 1e10.
__global__ void knnquery_cuda_kernel(int m, int nsample, const float *__restrict__ xyz, const float *__restrict__ new_xyz, const int *__restrict__ offset, const int *__restrict__ new_offset, int *__restrict__ idx, float *__restrict__ dist2) {
    // input: xyz (n, 3) new_xyz (m, 3)
    // output: idx (m, nsample) dist2 (m, nsample)
    int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (pt_idx >= m) return;

    // Move the pointers to this query's row.
    new_xyz += pt_idx * 3;
    idx += pt_idx * nsample;
    dist2 += pt_idx * nsample;
    // Resolve this query's batch and the matching [start, end) range in xyz.
    int bt_idx = get_bt_idx(pt_idx, new_offset);
    int start;
    if (bt_idx == 0)
        start = 0;
    else
        start = offset[bt_idx - 1];
    int end = offset[bt_idx];

    float new_x = new_xyz[0];
    float new_y = new_xyz[1];
    float new_z = new_xyz[2];

    // best_dist/best_idx form a max-heap (largest kept distance at [0]) of the
    // nsample closest candidates seen so far, seeded with a 1e10 sentinel.
    float best_dist[100];
    int best_idx[100];
    for(int i = 0; i < nsample; i++){
        best_dist[i] = 1e10;
        best_idx[i] = start;
    }
    for(int i = start; i < end; i++){
        float x = xyz[i * 3 + 0];
        float y = xyz[i * 3 + 1];
        float z = xyz[i * 3 + 2];
        float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + (new_z - z) * (new_z - z);
        // Closer than the worst kept candidate: replace the root and restore the heap.
        if (d2 < best_dist[0]){
            best_dist[0] = d2;
            best_idx[0] = i;
            reheap(best_dist, best_idx, nsample);
        }
    }
    // Convert the heap into ascending order of distance before writing it out.
    heap_sort(best_dist, best_idx, nsample);
    for(int i = 0; i < nsample; i++){
        idx[i] = best_idx[i];
        dist2[i] = best_dist[i];
    }
}


// Host-side launcher: one thread per query point, THREADS_PER_BLOCK threads per block,
// launched on the caller's `stream` so it is ordered with TensorRT's other work on it.
void knnquery_cuda_launcher(int m, int nsample, const float *xyz, const float *new_xyz, const int *offset, const int *new_offset, int *idx, float *dist2, cudaStream_t stream) {
    // input: new_xyz: (m, 3), xyz: (n, 3), idx: (m, nsample)
    // A zero-block grid is an invalid launch configuration, so skip empty inputs.
    if (m <= 0)
        return;
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK));
    dim3 threads(THREADS_PER_BLOCK);
    knnquery_cuda_kernel<<<blocks, threads, 0, stream>>>(m, nsample, xyz, new_xyz, offset, new_offset, idx, dist2);
}

}
}

CMakeLists.txt

# Collect this plugin's host (.cpp) and device (.cu) sources and hand them to the
# parent CMakeLists.txt through PLUGIN_SOURCES / PLUGIN_CU_SOURCES.
file(GLOB SRCS *.cpp)
file(GLOB CU_SRCS *.cu)

list(APPEND PLUGIN_SOURCES ${SRCS})
list(APPEND PLUGIN_CU_SOURCES ${CU_SRCS})

set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE)

在TensorRT/plugin/inferPlugin.cpp的开头添加

#include "knnQuery/knnQuery.h"

并在initLibNvInferPlugins函数中添加

initializePlugin<nvinfer1::plugin::knnQueryCreator>(logger, libNamespace);

在 TensorRT/plugin/CMakeLists.txt 的 set(PLUGIN_LISTS ...) 列表中添加:

knnQuery

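在 TensorRT/CMakeLists.txt 中设置 TRT_LIB_DIR 和 TRT_OUT_DIR,然后重新编译 TensorRT 开源部分。运行时要确保加载的是新编译出的 libnvinfer_plugin(例如把它所在的目录放在 LD_LIBRARY_PATH 的最前面),否则 trtexec 和 Python 脚本将找不到 knnquery 插件。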

tensorrt推理测试

运行下面的命令,把 ONNX 模型转换为 TensorRT engine:

TensorRT-10.6.0.26/bin/trtexec --onnx=knnquery.onnx --saveEngine=knnquery.engine

编写 Python 推理脚本(其中 common 模块为 TensorRT samples/python 目录下的 common.py):

import numpy as np
import tensorrt as trt
import common


logger = trt.Logger(trt.Logger.WARNING)
# Load TensorRT's plugin library; presumably required so the knnquery node in the
# engine can be resolved.
trt.init_libnvinfer_plugins(logger, "")
with open("knnquery.engine", "rb") as engine_file, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(engine_file.read())
context = engine.create_execution_context()
inputs, outputs, bindings, stream = common.allocate_buffers(engine)

NUM_POINTS = 47358
# nsample is not fed from the host; presumably it was folded into the engine as a
# constant. Confirm against the engine's I/O tensors.
host_inputs = (
    np.loadtxt("xyz.txt").reshape(NUM_POINTS, 3),
    np.loadtxt("new_xyz.txt").reshape(NUM_POINTS, 3),
    np.array([NUM_POINTS]).astype(np.int32),
    np.array([NUM_POINTS]).astype(np.int32),
)
for slot, array in enumerate(host_inputs):
    np.copyto(inputs[slot].host, array.ravel())

output = common.do_inference(context, engine=engine, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
print(output)

您可能感兴趣的与本文相关的镜像

PyTorch 2.5

PyTorch 2.5

PyTorch
Cuda

PyTorch 是一个开源的 Python 机器学习库,基于 Torch 库,底层由 C++ 实现,应用于人工智能领域,如计算机视觉和自然语言处理

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

给算法爸爸上香

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值