TensorRT build engine的流程
- 创建builder:
- 创建网络定义 builder --> network
- 配置参数: builder --> config
- 生成engine: builder --> engine(network, config)
- 序列化保存: engine --> serialize
- 释放资源: delete
1. 创建builder
#include <iostream>
#include <NvInfer.h>
class TRTLogger : public nvinfer1::ILogger
{
void log(Severity severity, const char *msg) noexcept override
{
// 屏蔽INFO级别日志
if (severity != Severity::kINFO)
{
std::cout << msg << std::endl;
}
}
}gLogger;
int main() {
// 1. 创建builder
TRTLogger logger; // logger是必要的,用来捕捉warning和info等
nvinfer1::IBuilder *builder = nvinfer1::createInferBuilder(logger);
return 0;
}
nvinfer1下的数据类型,都是nvinfer1::Data ptr, 后面通过指针操作
Severity 是一个枚举类型,用于指定日志消息的严重程度。在 TensorRT 中,Severity 枚举类型定义了四个常量,分别为:
kINTERNAL_ERROR:内部错误,表示程序出现了无法处理的错误或异常。
kERROR:错误,表示程序执行过程中出现了错误,但程序可以继续运行。
kWARNING:警告,表示程序执行过程中出现了一些可能会导致问题的情况。
kINFO:信息,表示程序执行过程中的一些有用的信息。
在使用 TensorRT 进行模型推理时,我们可以使用 ILogger 接口来定义日志输出,例如将日志输出到控制台或文件中。在实现 ILogger 接口时,需要实现 log() 函数来处理日志消息,该函数的参数包括一个 Severity 类型的参数和一个 const char* 类型的参数,分别表示日志消息的严重程度和内容。
在示例程序中,我们定义了一个名为 TRTLogger 的类,它是 ILogger 接口的实现类。在 TRTLogger 类中,我们通过重载 log() 函数来过滤掉 Severity 为 kINFO 的日志消息,并将其他日志消息输出到控制台中。这样可以避免输出过多无用的信息,从而更好地了解程序的运行情况。
1.1 复习抽象基类和静态函数
nvinfer1::IBuilder *builder = nvinfer1::createInferBuilder(logger);
首先这里的nvinfer1::IBuilder 是一个抽象基类, createInferBuilder()是一个静态函数, 抽象基类需要new开辟空间
#include <iostream>
class Shape {
public:
virtual void draw() = 0; // 纯虚函数
static Shape* createShape(const std::string& type);
};
class Rectangle : public Shape {
public:
void draw() override {
std::cout << "Drawing a rectangle." << std::endl;
}
};
class Circle : public Shape {
public:
void draw() override {
std::cout << "Drawing a circle." << std::endl;
}
};
Shape* Shape::createShape(const std::string& type) {
if (type == "Rectangle") {
return new Rectangle();
} else if (type == "Circle") {
return new Circle();
} else {
return nullptr;
}
}
int main() {
Shape* s1 = new Rectangle();
Shape* s2 = Shape::createShape("Circle");
s1->draw();
s2->draw();
delete s1;
delete s2;
return 0;
}
在这段代码中,Shape 是一个抽象基类,它包含一个纯虚函数 draw(),这意味着该函数没有实现,子类必须覆盖该函数并提供自己的实现。这样做的目的是将 Shape 定义为一个通用概念,但由于它本身是不完整的,因此不能实例化。通过将 Shape 设计为抽象基类,它的子类必须实现 draw(),从而保证了 Shape 的通用性和可扩展性。
Shape 的子类 Rectangle 和 Circle 继承了 Shape 类,并实现了 draw() 函数。Rectangle 类的 draw() 函数打印“Drawing a rectangle.”,而 Circle 类的 draw() 函数打印“Drawing a circle.”。
Shape::createShape() 是一个静态函数,它接受一个 std::string 类型的参数 type,并返回一个 Shape 类型的指针。在此示例中,它检查 type 的值,如果为 “Rectangle”,则创建一个 Rectangle 对象,如果为 “Circle”,则创建一个 Circle 对象。如果 type 的值无法识别,则返回 nullptr。这个函数的目的是封装对象的创建,并提供一种灵活的方式来创建不同的对象,而无需暴露对象创建的具体细节。在这个例子中,createShape() 方法是静态的,这意味着可以直接通过类名来调用该方法,而无需先创建类的对象。
2. 定义网络
- build -> network
- 这里用显性batch, 显性 batch 是指在网络定义时指定 batch size 的大小。这是一种固定的 batch 大小,不会随输入数据而变化,而且可以通过 TensorRT 进行优化。
- 显性 batch 是指在网络定义时指定 batch size 的大小。这是一种固定的 batch 大小,不会随输入数据而变化,而且可以通过 TensorRT 进行优化。
- net->addInput()
- KFLOAT, KHALF, KINT8, KINT32 是 TensorRT 中用于表示张量数据类型的枚举类型
- Dim4{1, 3, 1, 1} batch, channel, height_Size, width_size
- network->

本文介绍了如何使用TensorRT构建网络定义,配置参数,创建Engine并进行序列化保存。主要步骤包括创建Builder,定义网络结构,配置BuilderConfig,创建Engine并将其序列化为磁盘文件。此外,还展示了如何在另一个程序中反序列化Engine,创建执行上下文并进行推理操作。整个过程涉及到CUDA内存管理和异步数据传输。
最低0.47元/天 解锁文章
635

被折叠的 条评论
为什么被折叠?



