OpenPPL NN RISC-V架构自定义算子开发指南
概述
在深度学习推理框架中,自定义算子的开发是扩展框架功能的重要手段。本文将详细介绍如何在OpenPPL NN项目的RISC-V架构上添加自定义算子。通过本文,开发者可以了解从算子定义到实现的全流程,掌握在RISC-V架构上扩展深度学习算子的关键技术。
自定义算子开发流程
在OpenPPL NN的RISC-V架构上添加自定义算子主要分为四个步骤:
- 算子参数定义与解析
- 算子定义实现
- 算子调用接口开发
- 核心计算函数实现
1. 算子参数定义与解析
算子参数定义与解析是算子开发的第一步,需要定义算子所需的参数结构体,并实现参数的解析逻辑。这部分内容与其他架构类似,开发者可以参考X86架构的实现方式。
2. 算子定义实现
2.1 创建算子定义类
在ppl.nn/src/ppl/nn/engines/riscv/optimizer/ops/onnx目录下创建算子定义文件,包括头文件(<opname>_op.h)和实现文件(<opname>_op.cc)。
以Clip算子为例,其定义类结构如下:
class ClipOp final : public RiscvOptKernel {
public:
ClipOp(const ir::Node* node) : RiscvOptKernel(node) {}
ppl::common::RetCode Init(const OptKernelOptions& options) override;
KernelImpl* CreateKernelImpl() const override;
ppl::common::RetCode SelectFormat(const InputOutputInfo& info,
std::vector<ppl::common::dataformat_t>* selected_input_formats,
std::vector<ppl::common::dataformat_t>* selected_output_formats) override;
ppl::common::RetCode SelectDataType(const InputOutputInfo& info,
std::vector<ppl::common::datatype_t>* selected_input_data_types,
std::vector<ppl::common::datatype_t>* selected_output_data_types) override;
};
2.2 维度计算函数注册
维度计算函数负责根据输入张量的维度推断输出张量的维度。在Init函数中,开发者需要将维度计算函数注册到infer_dims_func_成员变量中。
OpenPPL NN提供了默认的维度计算函数GenericInferDims,对于简单的算子可以直接使用。对于需要特殊维度计算的算子,开发者可以自定义lambda表达式实现计算逻辑。
2.3 数据类型选择函数
RISC-V架构支持混合精度计算,因此需要实现SelectDataType函数来选择合适的数据类型。该函数根据输入数据类型和算子支持的类型,确定输出数据类型。
RetCode ClipOp::SelectDataType(const InputOutputInfo& info,
std::vector<datatype_t>* selected_input_data_types,
std::vector<datatype_t>* selected_output_data_types) {
if (DATATYPE_FLOAT16 == selected_input_data_types->at(0)) {
selected_output_data_types->at(0) = DATATYPE_FLOAT16;
} else if (DATATYPE_FLOAT32 == selected_input_data_types->at(0)) {
selected_output_data_types->at(0) = DATATYPE_FLOAT32;
}
return RC_SUCCESS;
}
2.4 数据排布选择函数
RISC-V架构支持多种数据排布格式,包括NDARRAY、N4CX和N8CX。SelectFormat函数负责选择最适合当前算子的数据排布。
RetCode ClipOp::SelectFormat(const InputOutputInfo& info,
vector<dataformat_t>* selected_input_formats,
vector<dataformat_t>* selected_output_formats) {
if (DATAFORMAT_N8CX == selected_input_formats->at(0)) {
selected_output_formats->at(0) = DATAFORMAT_N8CX;
} else if(DATAFORMAT_N4CX == selected_input_formats->at(0)) {
selected_output_formats->at(0) = DATAFORMAT_N4CX;
}
return RC_SUCCESS;
}
2.5 创建Kernel实现
CreateKernelImpl函数负责创建算子调用接口。根据算子是否需要参数,可以使用以下两种方式:
// 无参数算子
KernelImpl* TestOp::CreateKernelImpl() const {
return CreateKernelImplWithoutParam<TestKernel>();
}
// 有参数算子
KernelImpl* TestOp::CreateKernelImpl() const {
return CreateKernelImplWithParam<TestKernel>(param_.get());
}
2.6 算子注册
完成算子定义后,需要在OptKernelCreatorManager中注册算子:
REGISTER_OPT_KERNEL_CREATOR("", "Test", 7, 11, TestOp);
参数说明:
- 第一个参数:算子域(domain)
- 第二个参数:算子类型(op_type)
- 第三、四个参数:支持的opset范围
- 第五个参数:算子定义类名
3. 算子调用接口实现
在ppl.nn/src/ppl/nn/engines/riscv/kernels/onnx目录下创建算子调用接口文件。
基本结构如下:
class ClipKernel : public RiscvKernel {
public:
ClipKernel(const ir::Node* node) : RiscvKernel(node) {}
private:
ppl::common::RetCode DoExecute(KernelExecContext*) override;
bool CanDoExecute(const KernelExecContext&) const override;
};
DoExecute函数是核心实现,负责调用具体的计算函数并将结果写入输出张量。CanDoExecute函数用于在执行前检查条件是否满足。
4. 核心计算函数实现
4.1 函数声明
建议将kernel函数声明放在ppl.nn/src/ppl/nn/engines/riscv/impls/include/ppl/kernel/riscv目录下,按数据类型组织:
<data_type>/<opname>.h
4.2 函数实现
实现文件建议放在ppl.nn/src/ppl/nn/engines/riscv/impls/src/ppl/kernel/riscv目录下:
<data_type>/<opname>/<opname>_<data_type>.cpp
由于kernel函数与框架耦合度低,开发者可以根据算子特点自由组织代码结构。建议考虑以下因素:
- 数据排布处理
- 向量化优化
- 内存访问优化
- RISC-V特定指令集利用
开发建议
- 性能优化:充分利用RISC-V的向量扩展指令集(RVV)来提升计算性能
- 内存优化:注意数据排布对内存访问效率的影响
- 精度处理:合理处理不同数据类型的计算精度
- 边界条件:充分考虑各种边界条件的处理
- 测试验证:开发完成后进行充分的单元测试和性能测试
通过本文的指导,开发者可以系统地完成OpenPPL NN在RISC-V架构上的自定义算子开发工作,为深度学习推理在RISC-V平台上的应用扩展提供支持。
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考



