【TensorRT】20250826 日志 - 开启FP16的问题

TensorRT开启FP16的调试经验

ModelEngine·创作计划征文活动 10w+人浏览 1.4k人参与

博主最近遇到一个新模型需要转 TensorRT Engine 的任务,打算采用 Ckpt - ONNX - Engine的方式,遇到了一些小问题,记录一下

一、问题

这是一个极其类似 BERT模型 的新模型,所以按照博主的经验来看,其实在构建时设置build_config fp16就可以直接转为 fp16精度。但是这次却出现构建完成后推理时出现 差不匹配比例 过大的问题

二、解决方式

按照比较正常的思路来看,遇到这种问题,就得找一些出现差异的数据,开始抠图了,找找在哪些节点出现输出差异较大的问题

但是这种方式需要的时间比较长,博主也不要做到那么精细的量化,所以也有快刀斩乱麻的糙汉子做法。先将 network 打印出来看看

def print_trt_network_layers(network):
    print(f"网络总层数: {network.num_layers}")
    for layer_idx in range(network.num_layers):
        layer = network.get_layer(layer_idx)
        print(f"\n层 {layer_idx}:")
        print(f"  名称: {layer.name}")
        print(f"  类型: {layer.type}")  # 层类型(如CONVOLUTION、LAYER_NORMALIZATION等)
        print(f"  精度: {layer.precision}")
        print(f"  输入数量: {layer.num_inputs}")
        for i in range(layer.num_inputs):
            in_tensor = layer.get_input(i)
            if in_tensor:
                print(f"    输入{i}: 名称={in_tensor.name}, 形状={in_tensor.shape}, 精度={in_tensor.dtype}")
        print(f"  输出数量: {layer.num_outputs}")
        for i in range(layer.num_outputs):
            out_tensor = layer.get_output(i)
            if out_tensor:
                print(f"    输出{i}: 名称={out_tensor.name}, 形状={out_tensor.shape}, 精度={out_tensor.dtype}")

看看有哪些类型的节点,将一些敏感的节点显示设置为 FP32,设置 build_config 为 FP16 即可

如果不行的话,多实验几次,调整一下敏感节点

# -*- coding: utf-8 -*-
import tensorrt as trt
import argparse
import onnx
import ctypes

TRT_LOGGER = trt.Logger(trt.Logger.INFO)
handle = ctypes.CDLL("libnvinfer_plugin.so", mode=ctypes.RTLD_GLOBAL)
if not handle:
    raise RuntimeError("Could not load plugin library. Is `libnvinfer_plugin.so` on your LD_LIBRARY_PATH?")

def print_trt_network_layers(network):
    print(f"网络总层数: {network.num_layers}")
    for layer_idx in range(network.num_layers):
        layer = network.get_layer(layer_idx)
        print(f"\n层 {layer_idx}:")
        print(f"  名称: {layer.name}")
        print(f"  类型: {layer.type}")  # 层类型(如CONVOLUTION、LAYER_NORMALIZATION等)
        print(f"  精度: {layer.precision}")
        print(f"  输入数量: {layer.num_inputs}")
        for i in range(layer.num_inputs):
            in_tensor = layer.get_input(i)
            if in_tensor:
                print(f"    输入{i}: 名称={in_tensor.name}, 形状={in_tensor.shape}, 精度={in_tensor.dtype}")
        print(f"  输出数量: {layer.num_outputs}")
        for i in range(layer.num_outputs):
            out_tensor = layer.get_output(i)
            if out_tensor:
                print(f"    输出{i}: 名称={out_tensor.name}, 形状={out_tensor.shape}, 精度={out_tensor.dtype}")

def build_engine(onnx_file, engine_file, precision="fp32", workspace=1 << 30):
    # TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE)
    builder = trt.Builder(TRT_LOGGER)
    config = builder.create_builder_config()
    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace) 
    # config.set_preview_feature(trt.PreviewFeature.FASTER_DYNAMIC_SHAPES_0805, True) # 有问题
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

    # 解析 ONNX
    parser = trt.OnnxParser(network, TRT_LOGGER)
    success = parser.parse_from_file(onnx_file)
    if not success:
        for error in range(parser.num_errors):
            print(parser.get_error(error))
        raise RuntimeError("Failed to parse ONNX model")

    # 设置 shape
    profile = builder.create_optimization_profile()
    min_shape = (1, 4096)
    shape = (2, 4096)
    max_shape = (3, 4096)
    input_names = ["input_ids", "attention_mask", "token_type_ids", "segment_ids"]
    for name in input_names:
        profile.set_shape(name, min=min_shape, opt=shape, max=max_shape)
    config.add_optimization_profile(profile)
    
    # fp16 设置
    if precision == "fp16":
        # sensitive_layers = ["Attention", "LayerNorm", "attention", "layernorm"]
        sensitive_layers = ["LayerNorm", "layernorm"]
        index_int_layers = [
            trt.LayerType.CONSTANT,        # 整数权重常量层
            trt.LayerType.SHUFFLE,         # 维度调整(reshape/transpose等)
            trt.LayerType.CONCATENATION,   # 张量拼接
            trt.LayerType.GATHER,          # 按索引收集
            trt.LayerType.SCATTER,         # 按索引分散
            trt.LayerType.TOPK,            # 取前K个元素(需索引)
            trt.LayerType.SELECT,          # 按条件选择(需索引定位)
            trt.LayerType.SLICE,           # 静态切片(避免动态场景风险)
            trt.LayerType.ONE_HOT,         # 独热编码(整数标签输入)
            trt.LayerType.SHAPE
        ]
        for i in range(network.num_layers):
            layer = network.get_layer(i)
            if layer.type in index_int_layers:
                continue

            is_sensitive = any(name in layer.name for name in sensitive_layers)
            # 非敏感层
            if not is_sensitive:
                if layer.precision == trt.DataType.FLOAT:
                    layer.precision = trt.DataType.HALF  
                    # 输入输出设置
                    for inp_idx in range(layer.num_inputs):
                        input_tensor = layer.get_input(inp_idx)
                        if input_tensor.dtype == trt.DataType.FLOAT:
                            input_tensor.dtype = trt.DataType.HALF
                    for out_idx in range(layer.num_outputs):
                        if layer.get_output_type(out_idx) == trt.DataType.FLOAT:
                            layer.set_output_type(out_idx, trt.DataType.HALF)
            # 敏感层显示设置
            else:
                if layer.precision == trt.DataType.FLOAT:
                    layer.precision = trt.DataType.FLOAT 
                if layer.precision == trt.DataType.HALF:
                    layer.precision = trt.DataType.FLOAT

        config.flags |= 1 << int(trt.BuilderFlag.FP16)

    print_trt_network_layers(network)

    # build engine
    engine = builder.build_serialized_network(network, config)
    with open(engine_file, 'wb') as f:
        f.write(engine)


def main(args):
    precision = "fp16"
    if args.use_fp16 == "False":
        print("use fp32")
        precision = "fp32"
    else:
        print("use fp16")
    build_engine(args.onnx_file, args.output_path, precision=precision)


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument('--use_fp16', default=True)
    argparser.add_argument('-onnx', '--onnx_file',  required=True)
    argparser.add_argument('-output', '--output_path',  required=True)
    args = argparser.parse_args()
    main(args)

三、总结

通过上述方式转出来的 TensorRT Engine 大小为170多 MB,原本为 300多 MB。单次推理耗时从 25ms 降低到 17ms 左右。博主觉得这个模型其实有些其他问题的,这样一个模型不该有这么高的耗时,但是确实没有时间研究了,本次只是想记录一下这种 "土方子",若日后有时间和兴趣再进行探讨吧

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

GG_Bond21

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值