
在飞桨黑客松比赛的第三期,飞桨社区核心开发者李健铭参与了CINN算子开发方向的任务。
本文将由李健铭分享其开发过程,手把手教大家为神经网络编译器CINN增加One-Hot算子。

任务介绍
CINN(Compiler Infrastructure for Neural Networks)是一种在不改变模型代码的条件下加速飞桨模型运行速度的深度学习编译器。不同于深度学习框架算子,深度学习编译器算子的粒度更细,算子数目也更少,因此在算子融合和自动调优方面具有更大的优势。
在对接上层框架时,编译器会将上层的框架算子进一步拆分为若干基础算子,这样做的目的一方面是为了减少算子开发的工作量,仅实现有限的基础算子便可以组合出大量的上层框架算子;另一方面便于算子融合技术在编译器中可以实现跨算子自动融合,减少最终执行时的kernel数目和访存开销,达到更好的性能;此外,结合自动调优技术使得编译器可以自动优化融合后的kernel,提升kernel性能。
我完成的是One-Hot算子的开发任务。该任务需要具备编译器的基础知识,了解神经网络的基本原理。如果你还学习过编译器框架LLVM,开发过程会更轻松。如果没学过也没关系,我们可以参照CINN中现有的基础算子,学习相关API的使用方法。这个任务的重点是理解和运用CINN IR。
设计文档
算子介绍
One-Hot算子(在本项目中,该算子函数名为OneHot,后文将统一称为OneHot)接受5个参数,输出1个张量。算子的参数含义如下。
indices索引张量。
on_value索引位置填充的值。
off_value非索引位置填充的值。
axis填充的轴。
depth填充的长度。
算子的功能是按照张量indices表示的索引位置填充数值,生成一个新张量。在新张量中,indices索引位置上的值为on_value,其它位置上的值为off_value。算子的功能描述可能比较难理解,我们可以看一些算子计算示例。
OneHot(
indices=[0, 2, 2],
on_value=1,
off_value=0,
depth=4,
axis=0,
dtype="float32"
)
代码输出
# [[1. 0. 0]
# [0. 0. 1]
# [0. 0. 1]
# [0. 0. 0]]
OneHot(
indices=[0, 2, 2],
on_value=1,
off_value=0,
depth=4,
axis=-1,
dtype="float32"
)
代码输出

本文由飞桨社区核心开发者李健铭分享CINN中One-Hot算子的开发过程,详细解释了算子的功能、参数和实现方法,包括前端和后端的开发、单测编写等关键步骤,帮助读者理解深度学习编译器算子开发。
最低0.47元/天 解锁文章
1832

被折叠的 条评论
为什么被折叠?



