导出onnx模型
import torch
import numpy as np
import pointnet2_utils
class CustomModel(torch.nn.Module):
    """Minimal module that wraps ``pointnet2_utils.three_interpolate``.

    It exists only so the op can be handed to ``torch.onnx.export`` as a
    ``torch.nn.Module``; it owns no parameters or buffers.
    """

    def __init__(self):
        super().__init__()

    def forward(self, features, idx, weight):
        """Forward the three tensors unchanged to ``three_interpolate``.

        The export script in this file calls it with ``features`` of shape
        (1, 256, 256), an int32 ``idx`` of shape (1, 512, 3) and ``weight``
        of shape (1, 512, 3).
        """
        # The original line carried stray characters ("tpointnet2_utils" and a
        # trailing " p") that made it a syntax error.
        return pointnet2_utils.three_interpolate(features, idx, weight)
# Build the module and example inputs on the GPU.
model = CustomModel().cuda()
features = torch.randn(1, 256, 256).cuda()
# The CUDA kernel uses idx to index the last axis of features, so every value
# must lie in [0, features.shape[2]).  Casting randn() output to int32 (as the
# original did) produces negative values and therefore out-of-range reads.
idx = torch.randint(0, features.shape[2], (1, 512, 3), dtype=torch.int32).cuda()
weight = torch.randn(1, 512, 3).cuda()

# Dump the inputs as text so the TensorRT script can replay the same data.
np.savetxt("features.txt", features.reshape(256, 256).detach().cpu().numpy())
np.savetxt("idx.txt", idx.reshape(512, 3).detach().cpu().numpy())
np.savetxt("weight.txt", weight.reshape(512, 3).detach().cpu().numpy())

torch.onnx.export(model, (features, idx, weight), "three_interpolate.onnx", opset_version=13)
其中pointnet2_utils来自https://github.com/erikwijmans/Pointnet2_PyTorch。
导出onnx模型结构如下:

编写tensorrt插件
采用TensorRT-10.6.0.26。由于TensorRT是部分开源,首先在https://developer.nvidia.com/tensorrt/download/10x下载TensorRT-10.6.0.26的库,然后在https://github.com/NVIDIA/TensorRT/tree/v10.6.0下载源代码。
在TensorRT/plugin下新建threeInterpolate文件夹,添加下面文件:
threeInterpolate.h
#ifndef TRT_THREE_INTERPOLATE_PLUGIN_H
#define TRT_THREE_INTERPOLATE_PLUGIN_H
#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "common/plugin.h"
#include "common/cuda_utils.h"
#include <vector>
#include <cstring>
namespace nvinfer1
{
namespace plugin
{
/// Host-side launcher for the three_interpolate CUDA kernel (defined in
/// threeInterpolate.cu). The launch is issued on `stream`.
///
/// Per the kernel's own layout comment: points is (b, c, m), idx and weight
/// are (b, n, 3), and out is (b, c, n).
void three_interpolate_kernel_wrapper(
    int b, int c, int m, int n,
    const float *points,
    const int *idx,
    const float *weight,
    float *out,
    cudaStream_t stream);
/// TensorRT dynamic-shape plugin exposing the PointNet++ `three_interpolate`
/// operation.
///
/// Inputs : features (B, C, M) float, idx (B, N, 3) int32, weight (B, N, 3) float.
/// Output : interpolated features (B, C, N) float.
///
/// The plugin has no attributes, so nothing is written when it is serialized.
class ThreeInterpolate : public nvinfer1::IPluginV2DynamicExt
{
public:
    ThreeInterpolate();
    /// Deserialization constructor; the serialized blob is empty.
    ThreeInterpolate(void const* data, size_t length);
    ~ThreeInterpolate() override;

    // Basic plugin identity.
    char const* getPluginType() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    int getNbOutputs() const noexcept override;

    // Output shape computation.
    nvinfer1::DimsExprs getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
        nvinfer1::IExprBuilder& exprBuilder) noexcept override;

    // Resource set-up and tear-down.
    int initialize() noexcept override;
    void terminate() noexcept override;

    // Execution.
    int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
        const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
    size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
        const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept override;

    // Supported data types and formats.
    DataType getOutputDataType(int index, DataType const* inputTypes, int nbInputs) const noexcept override;
    bool supportsFormatCombination(
        int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept override;

    // Configuration.
    void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
        const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept override;

    // Serialization.
    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;

    // Lifetime and namespace management.
    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
    void destroy() noexcept override;
    void setPluginNamespace(char const* libNamespace) noexcept override;
    char const* getPluginNamespace() const noexcept override;

private:
    // Namespace assigned through setPluginNamespace().
    // The former mInputDims / mSampleDims members were never read or written
    // and carried comments describing a different plugin, so they were removed.
    std::string mPluginNamespace;
};
/// Factory registered with TensorRT to create and deserialize
/// ThreeInterpolate plugin instances.
class ThreeInterpolateCreator : public nvinfer1::IPluginCreator
{
public:
    ThreeInterpolateCreator();
    ~ThreeInterpolateCreator() override = default;

    // Identity used to look the creator up in the plugin registry.
    char const* getPluginName() const noexcept override;
    char const* getPluginVersion() const noexcept override;

    // Attribute schema of the plugin.
    PluginFieldCollection const* getFieldNames() noexcept override;

    // Instance construction.
    IPluginV2* createPlugin(char const* name, PluginFieldCollection const* fc) noexcept override;
    IPluginV2* deserializePlugin(
        char const* name, void const* serialData, size_t serialLength) noexcept override;

    // Namespace management.
    void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
    const char* getPluginNamespace() const noexcept override;

private:
    static PluginFieldCollection mFC;                  // returned by getFieldNames()
    static std::vector<PluginField> mPluginAttributes; // backing storage for mFC.fields
    std::string mNamespace;                            // set via setPluginNamespace()
};
} // namespace plugin
} // namespace nvinfer1
#endif // TRT_THREE_INTERPOLATE_PLUGIN_H
threeInterpolate.cpp
#include "threeInterpolate.h"
#include "common/dimsHelpers.h"
using namespace nvinfer1;
using namespace nvinfer1::pluginInternal;
using nvinfer1::plugin::ThreeInterpolate;
using nvinfer1::plugin::ThreeInterpolateCreator;
// ---------------------------------------------------------------------------
// Plugin implementation
// ---------------------------------------------------------------------------

// Default constructor: the plugin has no attributes to initialise.
ThreeInterpolate::ThreeInterpolate() {}
// Deserialization constructor. Both arguments are ignored: the body is empty,
// matching getSerializationSize() == 0 and the empty serialize().
ThreeInterpolate::ThreeInterpolate(void const* /*data*/, size_t /*length*/) {}
// Destructor: nothing beyond the implicit member clean-up.
ThreeInterpolate::~ThreeInterpolate()
{
}
// Basic plugin identity.

// Returns the plugin type string; ThreeInterpolateCreator::getPluginName()
// returns the same literal.
char const* ThreeInterpolate::getPluginType() const noexcept
{
    static constexpr char const* kTypeName = "three_interpolate";
    return kTypeName;
}
// Returns the plugin version string.
char const* ThreeInterpolate::getPluginVersion() const noexcept
{
    static constexpr char const* kVersion = "1";
    return kVersion;
}
// The plugin produces a single output tensor.
int ThreeInterpolate::getNbOutputs() const noexcept
{
    constexpr int kNumOutputs = 1;
    return kNumOutputs;
}
// Computes the symbolic output shape (B, C, N) from
//   inputs[0] = features (B, C, M) and inputs[1] = idx (B, N, 3).
//
// The input dimension expressions are forwarded as-is. The previous version
// called getConstantValue() on them and rebuilt constants, which only yields a
// meaningful value for build-time constants and therefore broke dynamic shapes.
nvinfer1::DimsExprs ThreeInterpolate::getOutputDimensions(int outputIndex,
    const nvinfer1::DimsExprs* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    // Validate the output index and the number of inputs.
    PLUGIN_ASSERT(outputIndex == 0 && nbInputs == 3);
    // d[0..1] of features and d[1] of idx are read below, so both must be rank 3.
    PLUGIN_ASSERT(inputs[0].nbDims == 3 && inputs[1].nbDims == 3);
    // No new expressions need to be built; keep the parameter "used".
    static_cast<void>(exprBuilder);

    nvinfer1::DimsExprs outputDims;
    outputDims.nbDims = 3;
    outputDims.d[0] = inputs[0].d[0]; // B
    outputDims.d[1] = inputs[0].d[1]; // C
    outputDims.d[2] = inputs[1].d[1]; // N
    return outputDims;
}
// Initialization: there are no resources to acquire, so this always reports
// success.
int ThreeInterpolate::initialize() noexcept
{
    int const status = STATUS_SUCCESS;
    return status;
}
// Resource release: nothing was acquired in initialize(), so this is a no-op.
void ThreeInterpolate::terminate() noexcept
{
}
// Runs the interpolation on `stream`.
//
// Inputs : [0] features (B, C, M) float, [1] idx (B, N, 3) int32,
//          [2] weight (B, N, 3) float.
// Output : [0] (B, C, N) float.
//
// Returns STATUS_SUCCESS after the kernel has been launched, or -1 if a
// std::exception was caught. `workspace` is unused (getWorkspaceSize() is 0).
int ThreeInterpolate::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
    const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
    try
    {
        nvinfer1::Dims const& featDims = inputDesc[0].dims;
        nvinfer1::Dims const& idxDims = inputDesc[1].dims;
        nvinfer1::Dims const& weightDims = inputDesc[2].dims;
        nvinfer1::Dims const& outDims = outputDesc[0].dims;

        // Rank checks.
        PLUGIN_ASSERT(featDims.nbDims == 3);
        PLUGIN_ASSERT(idxDims.nbDims == 3);
        PLUGIN_ASSERT(weightDims.nbDims == 3);
        PLUGIN_ASSERT(outDims.nbDims == 3);

        // Data type checks.
        PLUGIN_ASSERT(inputDesc[0].type == nvinfer1::DataType::kFLOAT);
        PLUGIN_ASSERT(inputDesc[1].type == nvinfer1::DataType::kINT32);
        PLUGIN_ASSERT(inputDesc[2].type == nvinfer1::DataType::kFLOAT);
        PLUGIN_ASSERT(outputDesc[0].type == nvinfer1::DataType::kFLOAT);

        // Extract sizes; the kernel interface takes plain int, so narrow explicitly.
        int const b = static_cast<int>(featDims.d[0]); // batch size
        int const c = static_cast<int>(featDims.d[1]); // feature channels
        int const m = static_cast<int>(featDims.d[2]); // number of source points
        int const n = static_cast<int>(idxDims.d[1]);  // number of interpolated points

        PLUGIN_ASSERT(b > 0 && c > 0 && m > 0 && n > 0);
        // The kernel offsets idx/weight/out by the batch index, so every tensor
        // must share the batch size of the features (previously unchecked).
        PLUGIN_ASSERT(idxDims.d[0] == b && weightDims.d[0] == b);
        PLUGIN_ASSERT(idxDims.d[2] == 3);                            // 3 source indices per target point
        PLUGIN_ASSERT(weightDims.d[1] == n && weightDims.d[2] == 3); // weight matches idx
        PLUGIN_ASSERT(outDims.d[0] == b && outDims.d[1] == c && outDims.d[2] == n);

        // Typed views of the device buffers.
        const float* points = static_cast<const float*>(inputs[0]); // (B, C, M)
        const int* idx = static_cast<const int*>(inputs[1]);        // (B, N, 3)
        const float* weight = static_cast<const float*>(inputs[2]); // (B, N, 3)
        float* out = static_cast<float*>(outputs[0]);               // (B, C, N)

        static_cast<void>(workspace);

        // Launch the CUDA kernel through its host wrapper.
        three_interpolate_kernel_wrapper(b, c, m, n, points, idx, weight, out, stream);
        return STATUS_SUCCESS;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return -1;
}
// Workspace size: the kernel needs no scratch memory, so zero bytes are
// requested regardless of the tensor descriptions.
size_t ThreeInterpolate::getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept
{
    size_t const requiredBytes = 0;
    return requiredBytes;
}
// Output data type: the single output is always FP32 (the input types are not
// consulted).
DataType ThreeInterpolate::getOutputDataType(int index, DataType const* inputTypes, int nbInputs) const noexcept
{
    PLUGIN_ASSERT(index == 0 && nbInputs == 3);
    DataType const outputType = DataType::kFLOAT;
    return outputType;
}
// Supported combinations: every tensor must be in linear format; positions
// 0, 2 and 3 must be FP32 and position 1 must be INT32.
// Any other position is reported as unsupported.
bool ThreeInterpolate::supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) noexcept
{
    PLUGIN_ASSERT(pos < nbInputs + nbOutputs);

    // Pick the element type required at this position.
    nvinfer1::DataType requiredType;
    switch (pos)
    {
    case 0:
    case 2:
    case 3: requiredType = nvinfer1::DataType::kFLOAT; break;
    case 1: requiredType = nvinfer1::DataType::kINT32; break;
    default: return false;
    }

    nvinfer1::PluginTensorDesc const& desc = inOut[pos];
    return (desc.type == requiredType) && (desc.format == nvinfer1::PluginFormat::kLINEAR);
}
// Configuration hook: only checks that the plugin is wired with exactly three
// inputs and one output; no state is stored.
void ThreeInterpolate::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept
{
    try
    {
        bool const arityOk = (nbInputs == 3 && nbOutputs == 1);
        PLUGIN_ASSERT(arityOk);
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
}
// Serialization: zero bytes are reported, matching the empty serialize().
size_t ThreeInterpolate::getSerializationSize() const noexcept
{
    size_t const serializedBytes = 0;
    return serializedBytes;
}
// Writes nothing: getSerializationSize() reports zero bytes, so `buffer` is
// left untouched.
void ThreeInterpolate::serialize(void* /*buffer*/) const noexcept {}
// Creates a copy of this plugin. The only state is the plugin namespace, which
// is now carried over to the clone (previously it was dropped).
// Returns nullptr if a std::exception was caught.
nvinfer1::IPluginV2DynamicExt* ThreeInterpolate::clone() const noexcept
{
    try
    {
        auto* plugin = new ThreeInterpolate();
        plugin->setPluginNamespace(mPluginNamespace.c_str());
        return plugin;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
// Releases this plugin object; instances are created with `new` in clone()
// and in the creator, and freed here.
void ThreeInterpolate::destroy() noexcept
{
    // The object must not be touched after this statement.
    delete this;
}
// Namespace management.

// Stores the plugin namespace. A null pointer is treated as the empty
// namespace, because assigning nullptr to a std::string is undefined behaviour.
void ThreeInterpolate::setPluginNamespace(char const* pluginNamespace) noexcept
{
    try
    {
        mPluginNamespace = (pluginNamespace != nullptr) ? pluginNamespace : "";
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
}
// Returns the namespace previously stored by setPluginNamespace(). The pointer
// refers to the internal string and stays valid until that string changes.
char const* ThreeInterpolate::getPluginNamespace() const noexcept
{
    char const* const ns = mPluginNamespace.c_str();
    return ns;
}
// ---------------------------------------------------------------------------
// Plugin creator implementation
// ---------------------------------------------------------------------------

// Static storage for the (empty) attribute schema.
PluginFieldCollection ThreeInterpolateCreator::mFC{};
std::vector<PluginField> ThreeInterpolateCreator::mPluginAttributes;

// The plugin has no attributes: publish an empty field collection.
// (A duplicated assignment of mFC.nbFields was removed.)
ThreeInterpolateCreator::ThreeInterpolateCreator()
{
    mPluginAttributes.clear();
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}
// Name under which the creator is registered; identical to the literal
// returned by ThreeInterpolate::getPluginType().
char const* ThreeInterpolateCreator::getPluginName() const noexcept
{
    static constexpr char const* kPluginName = "three_interpolate";
    return kPluginName;
}
// Version string of the plugin this creator produces.
char const* ThreeInterpolateCreator::getPluginVersion() const noexcept
{
    static constexpr char const* kPluginVersion = "1";
    return kPluginVersion;
}
// Returns the attribute schema; the constructor sets it up as an empty
// collection.
PluginFieldCollection const* ThreeInterpolateCreator::getFieldNames() noexcept
{
    PluginFieldCollection const* const schema = &mFC;
    return schema;
}
// Creates a new plugin instance. `name` and `fc` are not consulted because the
// plugin has no attributes. The creator's namespace is now propagated to the
// new plugin, consistent with deserializePlugin().
// Returns nullptr if a std::exception was caught.
IPluginV2* ThreeInterpolateCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
    try
    {
        auto* plugin = new ThreeInterpolate();
        plugin->setPluginNamespace(mNamespace.c_str());
        return plugin;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
// Recreates a plugin from its serialized form and assigns it the creator's
// namespace. The blob is handed to the deserialization constructor (which
// currently has nothing to restore) instead of being silently bypassed.
// Returns nullptr if a std::exception was caught.
IPluginV2* ThreeInterpolateCreator::deserializePlugin(char const* name, void const* serialData, size_t serialLength) noexcept
{
    try
    {
        // Ownership passes to the caller; the object is freed through
        // ThreeInterpolate::destroy().
        auto* plugin = new ThreeInterpolate(serialData, serialLength);
        plugin->setPluginNamespace(mNamespace.c_str());
        return plugin;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
void ThreeInterpolateCreator::setPluginNamespace(char const* libNamespace) noexcept
{
//std::cout<<"setPluginNamespace"<<std::endl;
mNamespace = libNamespace;
}
// Returns the namespace stored by setPluginNamespace(); the pointer refers to
// the creator's internal string.
char const* ThreeInterpolateCreator::getPluginNamespace() const noexcept
{
    char const* const ns = mNamespace.c_str();
    return ns;
}
threeInterpolate.cu
#include <stdio.h>
#include <stdlib.h>
#include "NvInfer.h"
#include "threeInterpolate.h"
#include <cuda_runtime.h>
namespace nvinfer1
{
namespace plugin
{
// input:  points(b, c, m), idx(b, n, 3), weight(b, n, 3)
// output: out(b, c, n)
//
// blockIdx.x selects the batch element; the threads of that block step through
// the c * n output values of the batch element with a stride equal to the
// block's thread count. Each output value is the weighted sum of three source
// features picked by idx. The indices are not range-checked here.
__global__ void three_interpolate_kernel(int b, int c, int m, int n,
                                         const float *__restrict__ points,
                                         const int *__restrict__ idx,
                                         const float *__restrict__ weight,
                                         float *__restrict__ out) {
  // Move every pointer to the start of this block's batch element.
  const int batch = blockIdx.x;
  points += batch * m * c;
  idx += batch * n * 3;
  weight += batch * n * 3;
  out += batch * n * c;

  // Flattened thread id within the block and the number of threads per block.
  const int tid = threadIdx.y * blockDim.x + threadIdx.x;
  const int threads_per_block = blockDim.y * blockDim.x;
  const int total = c * n;

  for (int e = tid; e < total; e += threads_per_block) {
    const int channel = e / n;  // row of the (c, n) output
    const int target = e % n;   // column of the (c, n) output

    const float *w = weight + target * 3;
    const int *src = idx + target * 3;
    const float *row = points + channel * m;

    const float wa = w[0];
    const float wb = w[1];
    const float wc = w[2];
    const int ia = src[0];
    const int ib = src[1];
    const int ic = src[2];

    out[e] = row[ia] * wa + row[ib] * wb + row[ic] * wc;
  }
}
// Launches three_interpolate_kernel on `stream` with a grid of `b` blocks
// (one per batch element) and the block shape returned by
// opt_block_config(n, c), then runs CUDA_CHECK_ERRORS().
void three_interpolate_kernel_wrapper(int b, int c, int m, int n,
                                      const float *points, const int *idx,
                                      const float *weight, float *out,
                                      cudaStream_t stream) {
  const auto block_shape = opt_block_config(n, c);
  three_interpolate_kernel<<<b, block_shape, 0, stream>>>(b, c, m, n, points,
                                                          idx, weight, out);
  CUDA_CHECK_ERRORS();
}
}
}
CMakeLists.txt
# Append this directory's C++ sources to PLUGIN_SOURCES and publish the
# updated list to the parent scope.
file(GLOB SRCS *.cpp)
list(APPEND PLUGIN_SOURCES ${SRCS})
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)

# Same for the CUDA sources, collected in PLUGIN_CU_SOURCES.
file(GLOB CU_SRCS *.cu)
list(APPEND PLUGIN_CU_SOURCES ${CU_SRCS})
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE)
在TensorRT/plugin/inferPlugin.cpp的开头添加
#include "threeInterpolate/threeInterpolate.h"
并在initLibNvInferPlugins函数中添加
initializePlugin<nvinfer1::plugin::ThreeInterpolateCreator>(logger, libNamespace);
在TensorRT/plugin/CMakeLists.txt的set(PLUGIN_LISTS添加
threeInterpolate
在TensorRT/CMakeLists.txt中设置TRT_LIB_DIR、TRT_OUT_DIR,再重新编译tensorrt。
tensorrt推理测试
运行下面的命令把onnx 转为engine模型:
TensorRT-10.6.0.26/bin/trtexec --onnx=three_interpolate.onnx --saveEngine=three_interpolate.engine
编写python推理脚本:
import numpy as np
import tensorrt as trt
import common
# Initialise TensorRT's plugin library before the engine is deserialized.
logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")

# Load the serialized engine produced by trtexec.
with open("three_interpolate.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

context = engine.create_execution_context()
inputs, outputs, bindings, stream = common.allocate_buffers(engine)

# Replay the inputs dumped by the export script, restoring shape and dtype.
features = np.loadtxt("features.txt").reshape(1, 256, 256).astype(np.float32)
idx = np.loadtxt("idx.txt").reshape(1, 512, 3).astype(np.int32)
weight = np.loadtxt("weight.txt").reshape(1, 512, 3).astype(np.float32)

# Copy each array, flattened, into the host buffer of the matching input.
for buffer, array in zip(inputs, (features, idx, weight)):
    np.copyto(buffer.host, array.ravel())

output = common.do_inference(
    context,
    engine=engine,
    bindings=bindings,
    inputs=inputs,
    outputs=outputs,
    stream=stream,
)

# The first output is viewed as (1, 256, 512) for printing and as
# (256, 512) when written to disk.
print(output[0].reshape(1, 256, 512))
np.savetxt("output[0].txt", output[0].reshape(256, 512))

458

被折叠的 条评论
为什么被折叠?



