手把手教你为神经网络编译器CINN增加One-Hot算子

本文由飞桨社区核心开发者李健铭分享CINN中One-Hot算子的开发过程,详细解释了算子的功能、参数和实现方法,包括前端和后端的开发、单测编写等关键步骤,帮助读者理解深度学习编译器算子开发。

1164e3cf34d1151725e0935be7e7863b.gif

在飞桨黑客松比赛的第三期,飞桨社区核心开发者李健铭参与了CINN算子开发方向的任务。

本文将由李健铭分享其开发过程,手把手教大家为神经网络编译器CINN增加One-Hot算子。

2b9dc7e7d9546a0cda6594d82cbcead3.png

9973258bb44f4ef50a9a6d38b92868ba.png任务介绍

CINN(Compiler Infrastructure for Neural Networks)是一种在不改变模型代码的条件下加速飞桨模型运行速度的深度学习编译器。不同于深度学习框架算子,深度学习编译器算子的粒度更细,算子数目也更少,因此在算子融合和自动调优方面具有更大的优势。

在对接上层框架时,编译器会将上层的框架算子进一步拆分为若干基础算子,这样做的目的一方面是为了减少算子开发的工作量,仅实现有限的基础算子便可以组合出大量的上层框架算子;另一方面便于算子融合技术在编译器中可以实现跨算子自动融合,减少最终执行时的kernel数目和访存开销,达到更好的性能;此外,结合自动调优技术使得编译器可以自动优化融合后的kernel,提升kernel性能。

我完成的是One-Hot算子的开发任务。该任务需要具备编译器的基础知识,了解神经网络的基本原理。如果你还学习过编译器框架LLVM,开发过程会更轻松。如果没学过也没关系,我们可以参照CINN中现有的基础算子,学习相关API的使用方法。这个任务的重点是理解和运用CINN IR。

fecf168e4e2d5784e4489d8df917c380.png设计文档

d0d5b90bb92f9afd263bc57b0d397633.png算子介绍

One-Hot算子(在本项目中,该算子函数名为OneHot,后文将统一称为OneHot)接受5个参数,输出1个张量。算子的参数含义如下。

  • indices索引张量。

  • on_value索引位置填充的值。

  • off_value非索引位置填充的值。

  • axis填充的轴。

  • depth填充的长度。

算子的功能是按照张量indices表示的索引位置填充数值,生成一个新张量。在新张量中,indices索引位置上的值为on_value,其它位置上的值为off_value。算子的功能描述可能比较难理解,我们可以看一些算子计算示例。

OneHot(
    indices=[0, 2, 2],
    on_value=1,
    off_value=0,
    depth=4,
    axis=0,
    dtype="float32"
)
  • 代码输出

# [[1. 0. 0]
#  [0. 0. 1]
#  [0. 0. 1]
#  [0. 0. 0]]

OneHot(
    indices=[0, 2, 2],
    on_value=1,
    off_value=0,
    depth=4,
    axis=-1,
    dtype="float32"
)
  • 代码输出


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值