#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "NvOnnxConfig.h"
#include "NvOnnxParser.h"
#include <iostream>
#include <fstream>
using namespace nvinfer1;
int main()
{
// 创建推理构建器
IBuilder* builder = createInferBuilder(gLogger);
INetworkDefinition* network = builder->createNetwork();
// 定义输入张量
ITensor* input = network->addInput("input", DataType::kFLOAT, Dims3(3, 224, 224));
// 添加Softmax层
ISoftMaxLayer* softmax = network->addSoftMax(*outputTensor);
// 设置网络的输出
softmax->getOutput(0)->setName("output");
network->markOutput(*softmax->getOutput(0));
// 创建推理引擎
ICudaEngine* engine = builder->buildCudaEngine(*network);
// 将引擎序列化为ONNX格式
nvonnxparser::IOnnxConfig* onnxConfig = nvonnxparser::createONNXConfig();
nvonnxparser::IONNXParser* onnxParser = nvonnxparser::createONNXParser(*network, gLogger);
onnxParser->parse(onnxConfig);
// 保存ONNX模型到文件
std::ofstream onnxFile("test.onnx", std::ios::binary);
onnxFile.write(onnxParser->getModelAsText(onnxConfig), strlen(onnxParser->getModelAsText(onnxConfig)));
onnxFile.close();
// 释放资源
onnxParser->destroy();
onnxConfig->destroy();
engine->destroy();
network->destroy();
builder->destroy();
return 0;
}
// TensorRT: defining a network (blog title)
// First published 2024-03-05 19:36:29