Exporting the ONNX model
import torch
import numpy as np
import pointnet2_utils

class CustomModel(torch.nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()

    def forward(self, new_xyz, xyz, radius, nsample):  # torch.Size([1, 2048, 3]) torch.Size([1, 20000, 3]) 0.04 64
        return pointnet2_utils.ball_query(new_xyz, xyz, radius, nsample)

model = CustomModel().cuda()
new_xyz = torch.from_numpy(np.loadtxt("new_xyz.txt").reshape(1, 2048, 3)).cuda().to(torch.float32)
xyz = torch.from_numpy(np.loadtxt("xyz.txt").reshape(1, 20000, 3)).cuda().to(torch.float32)
radius = 0.04
nsample = 64
torch.onnx.export(model, (new_xyz, xyz, radius, nsample), "ball_query.onnx", opset_version=13)
Here pointnet2_utils comes from https://github.com/erikwijmans/Pointnet2_PyTorch.
The structure of the exported ONNX model is as follows:


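The exported graph can also be checked programmatically; a small sketch, assuming the onnx Python package is installed:

import onnx

# Load the exported model and list its node types; the custom ball_query op should
# show up here (its exact name and attributes depend on how the op was exported).
model = onnx.load("ball_query.onnx")
print(onnx.helper.printable_graph(model.graph))
print([node.op_type for node in model.graph.node])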
Writing the TensorRT plugin
TensorRT-10.6.0.26 is used. Since TensorRT is only partially open source, first download the TensorRT-10.6.0.26 binary package from https://developer.nvidia.com/tensorrt/download/10x, then download the matching source code from https://github.com/NVIDIA/TensorRT/tree/v10.6.0.
Create a ballQuery folder under TensorRT/plugin and add the following files:
ballQuery.h
#ifndef TRT_BALL_QUERY_PLUGIN_H
#define TRT_BALL_QUERY_PLUGIN_H
#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "common/plugin.h"
#include "common/cuda_utils.h"
#include <vector>
#include <cstring>

namespace nvinfer1
{
namespace plugin
{
void query_ball_point_kernel_wrapper(int b, int n, int m, float radius, int nsample,
    const float* new_xyz, const float* xyz, int* idx, cudaStream_t stream);

class BallQuery : public nvinfer1::IPluginV2DynamicExt
{
public:
    BallQuery(float radis, int sample);
    BallQuery(void const* data, size_t length);
    ~BallQuery() override;

    // Basic plugin information
    char const* getPluginType() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    int getNbOutputs() const noexcept override;

    // Output dimension computation
    nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs,
        int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override;

    // Initialization and teardown
    int initialize() noexcept override;
    void terminate() noexcept override;

    // Execution
    int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
        const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
    size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
        const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override;

    // Data type and format support
    DataType getOutputDataType(int index, DataType const* inputTypes, int nbInputs) const noexcept override;
    bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override;

    // Plugin configuration
    void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
        const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override;

    // Serialization
    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;

    // Other interfaces
    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
    void destroy() noexcept override;
    void setPluginNamespace(char const* libNamespace) noexcept override;
    char const* getPluginNamespace() const noexcept override;

private:
    int mSample;                  // maximum number of sample points (nsample)
    float mRadis;                 // query radius
    std::string mPluginNamespace;
    Dims mInputDims;              // point-cloud input dims (B, N, 3)
    Dims mSampleDims;             // nsample input dims (usually a scalar or (B,))
};

class BallQueryCreator : public nvinfer1::IPluginCreator
{
public:
    BallQueryCreator();
    ~BallQueryCreator() override = default;
    char const* getPluginName() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    PluginFieldCollection const* getFieldNames() noexcept override;
    IPluginV2* createPlugin(char const* name, PluginFieldCollection const* fc) noexcept override;
    IPluginV2* deserializePlugin(char const* name, void const* serialData, size_t serialLength) noexcept override;
    void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
    const char* getPluginNamespace() const noexcept override;

private:
    static PluginFieldCollection mFC;
    static std::vector<PluginField> mPluginAttributes;
    std::string mNamespace;
};
} // namespace plugin
} // namespace nvinfer1
#endif // TRT_BALL_QUERY_PLUGIN_H
ballQuery.cpp
#include "ballQuery.h"
#include "common/dimsHelpers.h"
using namespace nvinfer1;
using namespace nvinfer1::pluginInternal;
using nvinfer1::plugin::BallQuery;
using nvinfer1::plugin::BallQueryCreator;
// 插件实现
BallQuery::BallQuery(float radis, int sample): mRadis(radis), mSample(sample)
{
//std::cout<<"BallQuery"<<std::endl;
}
BallQuery::BallQuery(void const* data, size_t length)
{
// 验证序列化数据长度是否匹配预期(float类型半径 + int类型采样点数)
PLUGIN_ASSERT(length == sizeof(float) + sizeof(int));
// 从序列化数据中解析半径和采样点数
const char* ptr = static_cast<const char*>(data);
// 读取半径(float类型)
mRadis = *static_cast<const float*>(static_cast<const void*>(ptr));
ptr += sizeof(float);
// 读取最大采样点数(int类型)
mSample = *static_cast<const int*>(static_cast<const void*>(ptr));
ptr += sizeof(int);
}
BallQuery::~BallQuery() {}
// 插件基本信息
char const* BallQuery::getPluginType() const noexcept
{
//std::cout<<"getPluginType"<<std::endl;
return "ball_query";
}
char const* BallQuery::getPluginVersion() const noexcept
{
//std::cout<<"getPluginVersion"<<std::endl;
return "1";
}
int BallQuery::getNbOutputs() const noexcept
{
//std::cout<<"getNbOutputs"<<std::endl;
return 1;
}
nvinfer1::DimsExprs BallQuery::getOutputDimensions(int outputIndex,
    const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    //std::cout << "getOutputDimensions" << std::endl;
    // Check the output index and the number of inputs
    PLUGIN_ASSERT(outputIndex == 0 && nbInputs == 4);
    // Build the output dimensions: (B, M, nsample)
    nvinfer1::DimsExprs outputDims;
    outputDims.nbDims = 3;
    // Batch size B and query-point count M are taken directly from new_xyz (input 0)
    outputDims.d[0] = inputs[0].d[0];
    outputDims.d[1] = inputs[0].d[1];
    outputDims.d[2] = exprBuilder.constant(mSample);
    return outputDims;
}
// Initialization
int BallQuery::initialize() noexcept
{
    return STATUS_SUCCESS;
}

// Release resources
void BallQuery::terminate() noexcept {}

// Run the kernel
int BallQuery::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
    const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
    //std::cout << "enqueue" << std::endl;
    try
    {
        // Input dimension checks
        PLUGIN_ASSERT(inputDesc[0].dims.nbDims == 3);  // new_xyz: (B, M, 3)
        PLUGIN_ASSERT(inputDesc[1].dims.nbDims == 3);  // xyz: (B, N, 3)
        PLUGIN_ASSERT(inputDesc[2].dims.nbDims == 0);  // radius: scalar
        PLUGIN_ASSERT(inputDesc[3].dims.nbDims == 0);  // nsample: scalar
        PLUGIN_ASSERT(outputDesc[0].dims.nbDims == 3); // output indices: (B, M, nsample)
        // Data type checks
        PLUGIN_ASSERT(inputDesc[0].type == nvinfer1::DataType::kFLOAT); // new_xyz is float
        PLUGIN_ASSERT(inputDesc[1].type == nvinfer1::DataType::kFLOAT); // xyz is float
        PLUGIN_ASSERT(inputDesc[2].type == nvinfer1::DataType::kFLOAT); // radius is float
        PLUGIN_ASSERT(inputDesc[3].type == nvinfer1::DataType::kINT32); // nsample is int32
        PLUGIN_ASSERT(outputDesc[0].type == nvinfer1::DataType::kINT32); // output indices are int32
        // Parse the inputs
        const float* new_xyz = static_cast<const float*>(inputs[0]); // (B, M, 3)
        const float* xyz = static_cast<const float*>(inputs[1]);     // (B, N, 3)
        int* idx = static_cast<int*>(outputs[0]);                    // output index matrix (B, M, nsample)
        // Parse the dimensions
        int b = inputDesc[0].dims.d[0]; // batch size
        int m = inputDesc[0].dims.d[1]; // number of query points (points in new_xyz)
        int n = inputDesc[1].dims.d[1]; // number of original points (points in xyz)
        // Alternative: read radius/nsample from the tensor inputs instead of the node attributes
        // const float* inputs2 = static_cast<const float*>(inputs[2]);
        // float* h_inputs2 = new float;
        // cudaMemcpy(h_inputs2, inputs2, sizeof(float), cudaMemcpyDeviceToHost);
        // mRadis = h_inputs2[0];
        // const int* inputs3 = static_cast<const int*>(inputs[3]);
        // int* h_inputs3 = new int;
        // cudaMemcpy(h_inputs3, inputs3, sizeof(int), cudaMemcpyDeviceToHost);
        // mSample = h_inputs3[0];
        //std::cout << "radius: " << mRadis << " sample: " << mSample << std::endl;
        // Launch the ball-query kernel wrapper
        query_ball_point_kernel_wrapper(
            b,        // batch size
            n,        // number of original points
            m,        // number of query points
            mRadis,   // query radius
            mSample,  // maximum sample count
            new_xyz,  // query point coordinates (B, M, 3)
            xyz,      // original point coordinates (B, N, 3)
            idx,      // output indices (B, M, nsample)
            stream    // CUDA stream
        );
        return STATUS_SUCCESS;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return -1;
}
// Workspace size
size_t BallQuery::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
    int nbInputs, const nvinfer1::PluginTensorDesc* outputs,
    int nbOutputs) const noexcept
{
    PLUGIN_ASSERT(nbInputs == 4 && nbOutputs == 1);
    // Parse the input dimensions: new_xyz (B, M, 3), xyz (B, N, 3)
    int B = inputs[0].dims.d[0]; // batch size
    int M = inputs[0].dims.d[1]; // number of query points
    int N = inputs[1].dims.d[1]; // number of original points
    // Reserve scratch space for per-query candidate indices (worst case: N candidates each)
    size_t tempIndicesSize = B * M * N * sizeof(int32_t);
    // Plus a float buffer for distance computations
    size_t extraSpace = B * M * N * sizeof(float);
    // Note: the current kernel writes directly into the output buffer and never touches
    // this scratch space, so returning 0 here would also work.
    return tempIndicesSize + extraSpace;
}

// Output data type: the indices are INT32
DataType BallQuery::getOutputDataType(int index, DataType const* inputTypes, int nbInputs) const noexcept
{
    //std::cout<<"getOutputDataType"<<std::endl;
    PLUGIN_ASSERT(index == 0 && nbInputs == 4);
    return DataType::kINT32;
}
// Supported formats
bool BallQuery::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept
{
    //std::cout << "supportsFormatCombination" << std::endl;
    PLUGIN_ASSERT(pos < nbInputs + nbOutputs);
    if (pos == 0) // new_xyz: float, linear
    {
        return (inOut[pos].type == nvinfer1::DataType::kFLOAT)
            && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
    }
    else if (pos == 1) // xyz: float, linear
    {
        return (inOut[pos].type == nvinfer1::DataType::kFLOAT)
            && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
    }
    else if (pos == 2) // radius: float, linear
    {
        return (inOut[pos].type == nvinfer1::DataType::kFLOAT)
            && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
    }
    else if (pos == 3) // nsample: int32, linear
    {
        return (inOut[pos].type == nvinfer1::DataType::kINT32)
            && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
    }
    else if (pos == 4) // output indices: int32, linear
    {
        return (inOut[pos].type == nvinfer1::DataType::kINT32)
            && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
    }
    return false;
}

void BallQuery::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept
{
    //std::cout<<"configurePlugin"<<std::endl;
    try
    {
        PLUGIN_ASSERT(nbInputs == 4 && nbOutputs == 1); // expect 4 inputs and 1 output
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
}
// Serialization: store the radius and the sample count
size_t BallQuery::getSerializationSize() const noexcept
{
    //std::cout<<"getSerializationSize"<<std::endl;
    return sizeof(float) + sizeof(int);
}

void BallQuery::serialize(void* buffer) const noexcept
{
    //std::cout << "serialize" << std::endl;
    // Serialize the radius and the sample count
    char* ptr = static_cast<char*>(buffer);
    // Store the radius (mRadis)
    *static_cast<float*>(static_cast<void*>(ptr)) = mRadis;
    ptr += sizeof(float);
    // Store the maximum sample count (mSample)
    *static_cast<int*>(static_cast<void*>(ptr)) = mSample;
    ptr += sizeof(int);
}

// Clone the plugin
nvinfer1::IPluginV2DynamicExt* BallQuery::clone() const noexcept
{
    //std::cout<<"clone"<<std::endl;
    try
    {
        return new BallQuery(mRadis, mSample);
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}

void BallQuery::destroy() noexcept
{
    delete this;
}

// Namespace management
void BallQuery::setPluginNamespace(char const* pluginNamespace) noexcept
{
    //std::cout<<"setPluginNamespace"<<std::endl;
    try
    {
        mPluginNamespace = pluginNamespace;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
}

char const* BallQuery::getPluginNamespace() const noexcept
{
    //std::cout<<"getPluginNamespace"<<std::endl;
    return mPluginNamespace.c_str();
}
// Plugin creator implementation
PluginFieldCollection BallQueryCreator::mFC{};
std::vector<PluginField> BallQueryCreator::mPluginAttributes;

BallQueryCreator::BallQueryCreator()
{
    //std::cout<<"BallQueryCreator"<<std::endl;
    mPluginAttributes.clear();
    mPluginAttributes.emplace_back(nvinfer1::PluginField("attr1", nullptr, nvinfer1::PluginFieldType::kFLOAT32)); // radius
    mPluginAttributes.emplace_back(nvinfer1::PluginField("attr2", nullptr, nvinfer1::PluginFieldType::kINT32));   // nsample
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}

char const* BallQueryCreator::getPluginName() const noexcept
{
    //std::cout<<"getPluginName"<<std::endl;
    return "ball_query";
}

char const* BallQueryCreator::getPluginVersion() const noexcept
{
    //std::cout<<"getPluginVersion"<<std::endl;
    return "1";
}

PluginFieldCollection const* BallQueryCreator::getFieldNames() noexcept
{
    //std::cout<<"getFieldNames"<<std::endl;
    return &mFC;
}

IPluginV2* BallQueryCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
    //std::cout<<"createPlugin"<<std::endl;
    try
    {
        float radius = 0.0f;
        int sample = 0;
        for (int i = 0; i < fc->nbFields; ++i)
        {
            const PluginField& field = fc->fields[i];
            if (strcmp(field.name, "attr1") == 0)
            {
                if (field.type == PluginFieldType::kFLOAT32)
                {
                    radius = *static_cast<const float*>(field.data);
                }
            }
            if (strcmp(field.name, "attr2") == 0)
            {
                if (field.type == PluginFieldType::kINT32)
                {
                    sample = *static_cast<const int*>(field.data);
                }
            }
        }
        //std::cout<<"radius: "<<radius<<" sample: "<<sample<<std::endl;
        return new BallQuery(radius, sample);
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}

IPluginV2* BallQueryCreator::deserializePlugin(char const* name, void const* serialData, size_t serialLength) noexcept
{
    //std::cout<<"deserializePlugin"<<std::endl;
    try
    {
        // This object will be deleted when the network is destroyed, which will
        // call BallQuery::destroy()
        IPluginV2Ext* plugin = new BallQuery(serialData, serialLength);
        plugin->setPluginNamespace(mNamespace.c_str());
        return plugin;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}

void BallQueryCreator::setPluginNamespace(char const* libNamespace) noexcept
{
    //std::cout<<"setPluginNamespace"<<std::endl;
    mNamespace = libNamespace;
}

char const* BallQueryCreator::getPluginNamespace() const noexcept
{
    //std::cout<<"getPluginNamespace"<<std::endl;
    return mNamespace.c_str();
}
ballQuery.cu
#include <stdio.h>
#include <stdlib.h>
#include "NvInfer.h"
#include "ballQuery.h"
#include <cuda_runtime.h>

namespace nvinfer1
{
namespace plugin
{
__global__ void query_ball_point_kernel(int b, int n, int m, float radius,
                                        int nsample,
                                        const float *__restrict__ new_xyz,
                                        const float *__restrict__ xyz,
                                        int *__restrict__ idx) {
    int batch_index = blockIdx.x;
    xyz += batch_index * n * 3;
    new_xyz += batch_index * m * 3;
    idx += m * nsample * batch_index;

    int index = threadIdx.x;
    int stride = blockDim.x;

    float radius2 = radius * radius;
    for (int j = index; j < m; j += stride) {
        float new_x = new_xyz[j * 3 + 0];
        float new_y = new_xyz[j * 3 + 1];
        float new_z = new_xyz[j * 3 + 2];
        for (int k = 0, cnt = 0; k < n && cnt < nsample; ++k) {
            float x = xyz[k * 3 + 0];
            float y = xyz[k * 3 + 1];
            float z = xyz[k * 3 + 2];
            float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
                       (new_z - z) * (new_z - z);
            if (d2 < radius2) {
                if (cnt == 0) {
                    for (int l = 0; l < nsample; ++l) {
                        idx[j * nsample + l] = k;
                    }
                }
                idx[j * nsample + cnt] = k;
                ++cnt;
            }
        }
    }
}

void query_ball_point_kernel_wrapper(int b, int n, int m, float radius,
                                     int nsample, const float *new_xyz,
                                     const float *xyz, int *idx, cudaStream_t stream) {
    query_ball_point_kernel<<<b, opt_n_threads(m), 0, stream>>>(
        b, n, m, radius, nsample, new_xyz, xyz, idx);
    CUDA_CHECK_ERRORS();
}
} // namespace plugin
} // namespace nvinfer1
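Note that query_ball_point_kernel_wrapper uses opt_n_threads and CUDA_CHECK_ERRORS, which are not part of the stock TensorRT OSS tree; they come from the cuda_utils.h helper in the Pointnet2_PyTorch repository and are pulled in through the common/cuda_utils.h include in ballQuery.h. A minimal sketch of what that header needs to provide (helper names assumed to match the original):

// common/cuda_utils.h -- minimal sketch of the Pointnet2_PyTorch helper
#ifndef CUDA_UTILS_H
#define CUDA_UTILS_H

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

#define TOTAL_THREADS 512

// Round the work size down to a power of two, capped at TOTAL_THREADS threads per block.
inline int opt_n_threads(int work_size)
{
    const int pow_2 = static_cast<int>(std::log(static_cast<double>(work_size)) / std::log(2.0));
    return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1);
}

// Abort with a message if the last CUDA kernel launch failed.
#define CUDA_CHECK_ERRORS()                                                \
    do                                                                     \
    {                                                                      \
        cudaError_t err = cudaGetLastError();                              \
        if (err != cudaSuccess)                                            \
        {                                                                  \
            fprintf(stderr, "CUDA kernel failed: %s (%s:%d)\n",            \
                cudaGetErrorString(err), __FILE__, __LINE__);              \
            exit(-1);                                                      \
        }                                                                  \
    } while (0)

#endif // CUDA_UTILS_H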
CMakeLists.txt
file(GLOB SRCS *.cpp)
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
file(GLOB CU_SRCS *.cu)
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} ${CU_SRCS})
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE)
At the beginning of TensorRT/plugin/inferPlugin.cpp, add
#include "ballQuery/ballQuery.h"
and in the initLibNvInferPlugins function add
initializePlugin<nvinfer1::plugin::BallQueryCreator>(logger, libNamespace);
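For orientation, a rough sketch of where those two additions end up in inferPlugin.cpp (all of the file's existing includes and plugin registrations are elided here):

// TensorRT/plugin/inferPlugin.cpp (abridged sketch, not the full file)
#include "ballQuery/ballQuery.h"      // added include
// ... existing plugin includes ...

extern "C" {
bool initLibNvInferPlugins(void* logger, char const* libNamespace)
{
    // ... existing initializePlugin<...> calls ...
    initializePlugin<nvinfer1::plugin::BallQueryCreator>(logger, libNamespace);  // added registration
    return true;
}
} // extern "C"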
In TensorRT/plugin/CMakeLists.txt, add the following entry to the set(PLUGIN_LISTS ...) block:
ballQuery
In TensorRT/CMakeLists.txt, set TRT_LIB_DIR and TRT_OUT_DIR, then rebuild TensorRT.
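A minimal build sketch, assuming a Linux environment and placeholder paths (TRT_LIB_DIR points at the downloaded TensorRT-10.6.0.26 libraries, TRT_OUT_DIR at the build output); the rebuilt libnvinfer_plugin must be the one that trtexec and the Python runtime actually load, e.g. by putting the build output first on LD_LIBRARY_PATH:

cd TensorRT
mkdir -p build && cd build
cmake .. -DTRT_LIB_DIR=/path/to/TensorRT-10.6.0.26/lib -DTRT_OUT_DIR=$(pwd)/out
make -j$(nproc)
# Make sure the rebuilt plugin library is picked up instead of the stock one.
export LD_LIBRARY_PATH=$(pwd)/out:$LD_LIBRARY_PATH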
TensorRT inference test
Run the following command to convert the ONNX model into an engine:
TensorRT-10.6.0.26/bin/trtexec --onnx=ball_query.onnx --saveEngine=ball_query.engine
Write the Python inference script (common here is the common.py helper module from the TensorRT Python samples, which provides allocate_buffers and do_inference):
import numpy as np
import tensorrt as trt
import common
logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")
with open("ball_query.engine", "rb") as f, trt.Runtime(logger) as runtime:
engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()
inputs, outputs, bindings, stream = common.allocate_buffers(engine)
new_xyz = np.loadtxt("new_xyz.txt").reshape(1, 2048, 3)
xyz = np.loadtxt("xyz.txt").reshape(1, 20000, 3)
radius = 0.04
nsample = 64
np.copyto(inputs[0].host, new_xyz.ravel())
np.copyto(inputs[1].host, xyz.ravel())
np.copyto(inputs[2].host, radius)
np.copyto(inputs[3].host, nsample)
output = common.do_inference(context, engine=engine, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
print(output)
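do_inference returns flat host buffers, so the index tensor has to be reshaped back to its (B, M, nsample) layout before it is compared or used; a small post-processing sketch:

# Reshape the flat int32 output back to (B, M, nsample) = (1, 2048, 64).
idx = np.asarray(output[0]).reshape(1, 2048, 64)
print(idx.shape, idx.dtype)
print(idx.min(), idx.max())  # valid indices lie in [0, 20000)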
