背景
- 基于pytorch的量化模型转为onnx过程中出现异常
- 使用tensorrt的pytorch_quantization实现进行模型int8量化
- 在转为onnx的过程中,出现异常
RuntimeError: NYI: Named tensors are not supported with the tracer
原因分析
原因定位
- 首先搜索错误:RuntimeError: NYI: Named tensors are not supported with the tracer,发现这种错误还不少,这句话的意思也很明确,tracer不支持命名张量,更深入的原理大家可以自己研究(搜索pytorch转onnx的原理即可)
- 尝试了一些方法,都不管用,包括升级torch、升级onnx、调整opset_version等,均不可行,mmdeploy的issue上有人出现了类似问题,官方说已经修复(issue),也不适用于笔者的问题
- 使用deepseek分析、kimi、Gemini分析,也没有结果
- 至此,只能回头再看报错信息,尝试定位具体的报错位置
if min_amax <= epsilon: # Treat amax smaller than minimum representable of fp16 0
zero_amax_mask = (amax <= epsilon)
# 异常报错的位置在此 ********************
scale[zero_amax_mask] = 0 # Value quantized with amax=0 should all be 0
outputs = torch.clamp((inputs * scale).round_(), min_bound, max_bound)
if min_amax <= epsilon:
scale[zero_amax_mask] = 1. # Return 1 makes more sense for values quantized to 0 with amax=0
- 最初怀疑scale或者zero_amax_mask是named tensor,但打印names发现均为None,所以不是这个问题
- 将zero_amax_mask = (amax <= epsilon),scale[zero_amax_mask] = 0 两行注释掉,异常出现在scale[zero_amax_mask] = 1行,将该行也注释掉,通过了
- 至此,说明报错的原因和scale[zero_amax_mask] 的赋值这两行相关,终于有点儿头绪了
原因分析
- 继续通过goole检索资料,收获甚微
- 直接将所在函数扔给大模型,大模型全面分析了这段代码,分析的很好,没任何毛病(顺便汇报下模型效果:kimi、deepseek、Gemini分析的很优秀,豆包、千问、文心一言分析的比较一般)
- 继续研scale[zero_amax_mask] 赋值的问题,发现问题可能出在布尔掩码的生成和索引操作上,而这种索引操作又是动态的(因为mask依赖于amax的值),而tracer对于动态跟踪一直都不行
- 那么为什么 ONNX 跟踪器无法处理这些动态行为呢?ONNX 跟踪器的工作原理是通过记录程序执行过程中对张量的操作,从而构建出一个静态的计算图。 这个计算图需要是静态的,意味着它的结构和操作序列在导出时就必须确定下来,不能在运行时根据数据动态改变。当代码中出现像布尔索引这样操作的执行依赖于张量的值的情况时,跟踪器就难以静态地捕捉到所有可能的执行路径。 对于 ONNX 跟踪器来说,它需要知道计算图的每一步操作都是固定的,才能将其转换为 ONNX 格式。 当遇到条件判断或基于值的动态修改时,它就无法有效地进行静态分析和转换。
解决方法
使用 torch.where 替换布尔索引是调整代码以避免 ONNX 导出异常的最推荐方法
zero_amax_mask = (amax <= epsilon)
# 使用torch.where
scale = torch.where(zero_amax_mask, torch.zeros_like(scale), scale) # 弾S zero_amax_mask 为 True 彗¶Lscale 设置为 0L佐¦佈~Y侾]彌~A低~_ scale 佀¼
# 注释掉以下3行
# if min_amax <= epsilon: # Treat amax smaller than minimum representable of fp16 0
# zero_amax_mask = (amax <= epsilon)
# scale[zero_amax_mask] = 0 # Value quantized with amax=0 should all be 0
outputs = torch.clamp((inputs * scale).round_(), min_bound, max_bound)
# 注释掉以下2行
# if min_amax <= epsilon:
# scale[zero_amax_mask] = 1. # Return 1 makes more sense for values quantized to 0 with amax=0
# 使用torch.where
scale = torch.where(zero_amax_mask, torch.ones_like(scale), scale) # 弾S zero_amax_mask 为 True 彗¶Lscale 设置为 1L佐¦佈~Y侾]彌~A低~_ scale 佀¼