PyTorch教程:扩展ONNX导出器的算子支持
tutorials PyTorch tutorials. 项目地址: https://gitcode.com/gh_mirrors/tuto/tutorials
概述
在深度学习模型部署过程中,将PyTorch模型导出为ONNX格式是一个常见需求。然而,有时我们会遇到PyTorch算子不被ONNX支持的情况。本教程将详细介绍如何扩展ONNX导出器的算子支持,包括三种常见场景:
- 覆盖现有PyTorch算子的实现
- 使用自定义ONNX算子
- 支持自定义PyTorch算子
准备工作
在开始之前,请确保满足以下条件:
- 安装PyTorch 2.6或更高版本
- 熟悉目标PyTorch算子的使用
- 已完成ONNX Script基础教程的学习
- 准备好使用ONNX Script实现的算子代码
覆盖现有PyTorch算子的实现
当ONNX导出器不支持某个PyTorch算子时,我们需要为其提供自定义实现。以下是一个完整示例:
import torch
import onnxscript
from onnxscript import opset18 as op
# 定义使用目标算子的模型
class Model(torch.nn.Module):
def forward(self, input_x, input_y):
return torch.ops.aten.add.Tensor(input_x, input_y)
# 自定义实现函数(注意参数签名必须匹配)
def custom_aten_add(self, other, alpha: float = 1.0):
if alpha != 1.0:
alpha = op.CastLike(alpha, other)
other = op.Mul(other, alpha)
# 为了区分自定义实现,我们调换输入顺序
return op.Add(other, self)
# 导出模型时提供自定义转换表
x, y = torch.tensor([1.0]), torch.tensor([2.0])
onnx_program = torch.onnx.export(
Model().eval(),
(x, y),
dynamo=True,
custom_translation_table={
torch.ops.aten.add.Tensor: custom_aten_add,
},
)
onnx_program.optimize()
通过检查导出的ONNX模型,可以确认自定义实现已成功应用。我们还可以使用ONNX Runtime验证结果:
result = onnx_program(x, y)[0]
torch.testing.assert_close(result, torch.tensor([3.0]))
使用自定义ONNX算子
某些情况下,我们可能需要使用特定运行时提供的自定义ONNX算子。以下示例展示了如何使用ONNX Runtime提供的Gelu算子:
class GeluModel(torch.nn.Module):
def forward(self, input_x):
return torch.ops.aten.gelu(input_x)
# 定义自定义命名空间
microsoft_op = onnxscript.values.Opset(domain="com.microsoft", version=1)
# 使用ONNX Script实现自定义Gelu
@onnxscript.script(microsoft_op)
def custom_aten_gelu(self: FLOAT, approximate: str = "none") -> FLOAT:
return microsoft_op.Gelu(self)
# 导出模型
onnx_program = torch.onnx.export(
GeluModel().eval(),
(x,),
dynamo=True,
custom_translation_table={
torch.ops.aten.gelu.default: custom_aten_gelu,
},
)
onnx_program.optimize()
支持自定义PyTorch算子
对于完全自定义的PyTorch算子,我们需要提供完整的ONNX实现。以下是一个自定义"加后取整"算子的示例:
# 定义并注册自定义PyTorch算子
@torch.library.custom_op("mylibrary::add_and_round_op", mutates_args=())
def add_and_round_op(input: torch.Tensor) -> torch.Tensor:
return torch.round(input + input)
@add_and_round_op.register_fake
def _add_and_round_op_fake(tensor_x):
return torch.empty_like(tensor_x)
# 使用自定义算子的模型
class AddAndRoundModel(torch.nn.Module):
def forward(self, input):
return add_and_round_op(input)
# ONNX实现
def onnx_add_and_round(input):
return op.Round(op.Add(input, input))
# 导出模型
onnx_program = torch.onnx.export(
AddAndRoundModel().eval(),
(x,),
dynamo=True,
custom_translation_table={
torch.ops.mylibrary.add_and_round_op.default: onnx_add_and_round,
},
)
onnx_program.optimize()
最佳实践
- 参数签名匹配:自定义实现函数的参数签名必须与原始PyTorch算子完全匹配
- 类型注解:所有属性参数必须添加类型注解
- 性能优化:导出后使用
optimize()
方法优化ONNX图 - 结果验证:始终使用ONNX Runtime验证导出结果的正确性
结论
通过本教程,我们学习了如何扩展ONNX导出器的算子支持。无论是覆盖现有实现、使用自定义ONNX算子,还是支持全新的PyTorch算子,PyTorch都提供了灵活的方式来实现这些需求。掌握这些技术可以帮助我们更灵活地将PyTorch模型部署到各种支持ONNX的运行时环境中。
在实际应用中,建议先检查ONNX官方支持的算子列表,只有在必要时才使用自定义实现。同时,记得充分测试自定义算子的正确性和性能表现。
tutorials PyTorch tutorials. 项目地址: https://gitcode.com/gh_mirrors/tuto/tutorials
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考