0x0. 前言
这篇文章基于TVM 0.8.0.dev版本。在【从零开始学深度学习编译器】五,TVM Relay以及Pass简介 这篇推文中已经简单介绍了Relay和Pass机制。但对Pass的基础设施(Pass Infrastructure)和Relay树结构都没有详细介绍,所以这篇文章主要介绍一下Pass Infrastructure和Relay树结构,再基于这些关键的基础知识详细了解一下Constant Folding Pass,相信读者读完这篇文章会对TVM的Pass有更深的理解,并且在阅读其它Pass和实现自定义Pass时可以很Relax。
0x1. Pass Infrastructure
首先来看Pass Infrastructure,基于官方文档进行介绍。
在讲解Pass通用的注册和运行流程前,先来介绍一下TVM的Pass Infrastructure。参考官方文档:https://tvm.apache.org/docs/dev/pass_infra.html 。
Relay 和 TVM IR 都包含一系列优化passes,可提高模型的性能指标,例如平均推理速度、内存占用或特定设备的功耗。 TVM有一套标准优化方法以及特定于机器学习的优化方法,包括常量折叠、死代码消除、运算符布局更改、算符融合、缓冲区处理和循环变换等。 每一个Pass都使用在traversal期间和/或之前收集的分析结果来构造ir-to-ir的pass。
然而,随着TVM的迅速发展,需要一种更系统、更有效的方法来管理这些passes。此外,一个可以管理跨TVM堆栈不同层(如Relay和tir)的passes的通用框架,为开发人员快速原型化并将实现的passes插入系统铺平了道路。
例如,许多现有的生产编译器,如 GCC 和 LLVM,都采用pass manager来有效管理passes的执行。 最初管理 pass 很简单,因为 pass 的数量很少,但成熟的编译器将包含数百个单独的 pass。 Often external users will want to have custom passes correctly scheduled without having to modify a single handcrafted pass order.
同样,现代深度学习框架,如 Pytorch 和 MXNet Gluon,也有分别通过 Sequential 和 Block 启用pass-style层构建方案的趋势。 有了这样的结构,这些现代框架能够方便地将模块/层添加到它们的容器中,并轻松地构建神经网络。
Relay pass infra 的设计很大程度上受到 LLVM 中使用的分层pass manager和流行的深度学习框架中使用的block-style容器的启发。 pass infra 的主要目标包括:
- 实现更好的optimizer编程编排。 这允许用户灵活地定制和构建自己的优化管道。
- 提供一种用户友好的方式来调试passes。
- 减轻开发人员手动和分别解决passes之间的依赖关系。
- 为开发人员简化实现新passes的难度。 例如,我们允许用户在 Python 中实现一个 pass 并让 pass infra 操纵它的执行。
The Design
我们专注于为用户提供易于扩展的功能,让用户可以快速添加新passes而不会失去向后兼容性。 该设计包含后端和前端。 前者实现了 pass infra 的主要逻辑。 后者为用户提供简单的 API 进行交互,即允许用户快速创建自己的优化管道。
C++ Backend
我们提供了一个 PassInfo 对象来包含一个pass所需的基本信息。 name 是 pass 名称,opt_level 指示将启用 pass 的优化级别, required 表示执行某个 pass 所需的 pass(更多详细信息请参见include/tvm/ir/transform.h)。 例如,在注册pass的时候(将在后面介绍),pass开发人员可以指定pass的名称、将执行的优化级别和/或所需的pass。 opt_level 可用于帮助 pass infra 识别在用户提供的优化级别下运行时是否需要执行某个 pass。 required字段可以由pass infra用来解决pass依赖关系。
class PassInfoNode : public Object {
String name;
int opt_level;
Array<String> required;
};
PassContext
PassContext 带有用于优化pass的有用信息。 例如,它包含错误报告系统,因此pass的作者可以提供有关优化失败原因的注释。 PassContext 还旨在替换旧的BuildConfig,它用于帮助用户配置编译选项,包括优化级别和必需/禁用的pass等。例如,我们可能有一个配置,它在 opt_level=3 时执行所有pass,除开使用 PassContext 提供的 disabled_pass=xx禁用的一些passes 。 现在我们可以在 opt_level=3 处对所有passes进行全局处理,并排除禁用pass列表中的那些pass。
这个类是为方便用户编写Python而设计的,它的语法可以在特定的配置下执行优化。 此外,用户可以通过 PassContext::Current()以线程安全的方式获取某个程序范围内可用的context,因为ThreadLocalStore用于保存创建的pass context对象,关于ThreadLocalStore建议看这篇文章:https://zhuanlan.zhihu.com/p/61587053,TVM模仿Java中的ThreadLocalStore在C++层自己实现了用来管理线程。 稍后将提供示例以展示我们如何使用 C++ 和 Python API 来创建使用pass context的编译管道。
class PassContextNode : public Object {
public:
ErrorReporter err_reporter;
int opt_level{
2};
tvm::Array<tvm::Expr> required_pass;
tvm::Array<tvm::Expr> disabled_pass;
};
class PassContext : public NodeRef {
public:
TVM_DLL static PassContext Create();
TVM_DLL static PassContext Current();
/* Other fields are omitted. */
private:
// The entry of a pass context scope.
TVM_DLL void EnterWithScope();
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Classes to get the Python `with` like syntax.
friend class tvm::With<PassContext>;
};
struct PassContextThreadLocalEntry {
/*! \brief The default pass context. */
PassContext default_context;
/*! \brief The current pass context. */
std::stack<PassContext> context_stack;
PassContextThreadLocalEntry() {
default_context = PassContext(make_node<PassContextNode>());
}
};
/*! \brief The thread-local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
PassContextThreadLocalStore;
Pass Constructs
pass infra 是以分层方式设计的,它可以在不同粒度的Relay/tir 程序下工作。 引入了一个纯虚拟类 PassNode 作为不同优化pass的基础。 此类包含几个必须由子类在modules, functions, or sequences of passes实现的虚拟方法。
class PassNode : Object {
virtual PassInfo Info() const = 0;
virtual Module operator()(const IRModule& mod
const PassContext& pass_ctx) const = 0;
};
成员函数展示了一个pass应该如何实现,例如它始终在特定context下工作在 IRModule中,所有的pass都被设计在一个Module to Module的管理器中。因此,由 pass infra 控制的优化将始终更新整个module。
已经创建了几个子类来实现不同类型的优化pass,例如,function-level passes, module-level passes, and sequential passes。 每个子类本身都可以充当pass管理器。 例如,他们可以收集所需的passes并执行它们或基于给定的元数据构建依赖关系图。 它们的完整定义可以在src/relay/ir/transform.cc 和 src/ir/transform.cc 中找到。
Module-Level Passes
Module Level Passes主要用于全局和过程间优化 (IPO),类似于 LLVM 中使用的module pass。 Relay 中一些典型的 pass 需要一个模块的global picture,比如 A-normal form conversion 和 lambda lifting等,都属于这个集合。 在此级别,用户甚至可以在一个module中添加和/或删除function。
class ModulePassNode : PassNode {
PassInfo pass_info;
runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
// Other members/methods are omitted
};
pass_info 维护module-level pass所需的信息。 pass_func 实现了真正的optimization。 例如,我们可能需要对module执行死代码消除。 我们可以在 pass_func 中实现算法并让它在module上运行。 然后它将删除死代码,包括module中未使用的函数。 请注意,该字段被设计为一个packed function,所以这个优化不仅可以使用C++还可以使用Python来实现。
Function-Level Passes
Function-level passes用于为给定的 Relay/tir module实现各种内部函数级优化。 它一次从module的函数列表中获取一个函数以进行优化,并生成一个重写的 Relay Function 或 tir PrimFunc。 大多数pass可以归入这一类,例如Relay中的常见子表达式消除和inference simplification 以及tir中的向量化和flattening storage等。
请注意,此级别的passes范围是 Relay Function或 tir PrimFunc。 因此,我们无法通过这些passes添加或删除函数,因为它们不知道全局信息。
class FunctionPassNode : PassNode {
PassInfo pass_info;
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
bool SkipFunction(const Function& func) const;
// Other members/methods are omitted...
};
pass_info 与我们刚刚在Module pass 中描述的相同。 pass_func 需要一个函数进行优化,它还需要一个Module,因为我们可能会使用它来报告错误。 一个函数可以用“SkipOptimization”注释,以便在优化过程中被忽略。
Sequential Passes
SequentialPass 类似于 Pytorch nn.Sequential,它包含许多用于执行的passes。
class SequentialPassNode : PassNode {
PassInfo pass_info;
// Passes need to be executed.
Array<Pass> passes;
bool PassEnabled(const PassInfo& info) const;
Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};
目前在Relay中只有少数passes 被放入这组中。 例如,FoldScaleAxis 需要在内部调度 ForwardFoldScaleAxis 和 BackwardFoldScaleAxis。 此外,建议先完成BackwardFoldScaleAxis。 因此,该pass是SequentialPass的理想候选者。
以下代码显示了如何调用sequential pass中的各个pass。
Module SequentialNode::operator()(const Module& module,
const PassContext& pass_ctx) const {
Module mod = module;
for (const Pass& pass : passes) {
ICHECK(pass.defined()) << "Found undefined pass for optimization.";
const PassInfo& pass_info = pass->Info();
if (!PassEnabled(pass_info)) continue;
for (const auto& it : pass_info->required) {
const auto* name = it.as<tvm::ir::StringImm>();
ICHECK(name);
mod = GetPass(name->value)(mod, pass_ctx);
}
mod = pass(mod, pass_ctx);
}
return mod;
}
在调用pass时,我们首先检查是否启用了此pass。 这是通过首先检查用户是否明确禁用该pass,然后检查它是否被用户指定为必需pass来完成的。 如果仍然不确定是否启用了此传递,则将检查其 opt_level。 只有当它的opt_level不低于pass context中配置的优化级别时,才会启用并因此执行此pass。
要执行pass,我们首先需要使用pass name在 TVM packed function注册表中已注册的pass。 这是可能的,因为每个pass都注册了一个 API 接口,我们将在后面展示。
Pass GetPass(const std::string& pass_name) {
using tvm::runtime::Registry;
std::string fpass_name = "relay._transform." + pass_name;
const auto* f = Registry::Get(fpass_name);
ICHECK(f != nullptr) << "Cannot find " << fpass_name
<< "to create the pass " << pass_name;
return (*f)();
}
提供了一些helper function来创建上述每种类型的Pass。 这些helper function也暴露给 Python 前端,以便用户可以方便地使用 Python API 来创建特定的 pass 对象。
Pass CreateFunctionPass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass CreatePrimFuncPass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass CreateModulePass(
const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
String name,
Array<String> required);
Pass Sequential(tvm::Array<Pass> passes, PassInfo pass_info);
Pass Registration
我们已经介绍了不同级别pass的概念和用于编译的context。 用户可以多么轻松地注册pass是一件有意义的事。,我们以constant folding为例。 这个 pass 已经被实现来折叠 Relay Function中的常量(在 tvm/src/relay/transforms/fold_constant.cc 中找到)。
提供了一个 API 来执行 Expr 到 Expr 的转换。
Expr FoldConstant(const Expr& expr);
为了将这个pass注册到pass infra,我们首先需要决定这个pass将在哪个级别执行。 由于常量折叠发生在单个函数上,我们应该直观地通过 CreateFunctionPass为其创建一个 FunctionPass。 pass_func 作为packed function返回,该函数在 IRModule 中的每个function上调用 Expr to Expr API。 {} 表示此pass不需要先决条件。 否则,pass开发人员必须识别并列出它们。
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(FoldConstant(f));
};
return CreateFunctionPass(pass_func, 2, "FoldConstant", {
});
}
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);
} // namespace transform
为了允许其他 C++ 模块应用此pass,我们在 include/tvm/relay/transform.h中声明了一个free function,如下所示:
TVM_DLL Pass FoldConstant();
Python Frontend
python前端只需要一些简单的 APIs。 例如,我们可以为用户提供以下 APIs 来创建和执行一个 pass(完整的实现在 python/tvm/relay/transform.py 和 python/tvm/ir/transform.py 中提供)。 后端接收信息并决定它应该使用哪个函数来创建 Pass 对象。
PassContext
Python 前端为 PassContext 提供了一个包装器,通过覆盖 __enter__ 和 __exit__ 来启用 with 语法。 为用户提供了一个 current 静态方法来获取在特定范围内使用的上下文。
@tvm._ffi.register_object("transform.PassContext")
class PassContext(tvm.runtime.Object):
def __enter__(self):
_transform.EnterPassContext(self)
return self
def __exit__(self, ptype, value, trace, config):
_transform.ExitPassContext(self)
@staticmethod
def current():
"""Return the current pass context."""
return _transform.GetCurrentPassContext()
PassContext 用于配置编译选项,包括优化级别和必需/禁用的pass。 它还可以带一个配置字典,以便不同的pass可以方便地获取passed的数据,例如回退设备信息和循环展开的步数/深度等。 为了能够获取所需的配置,必须通过TVM_REGISTER_PASS_CONFIG_OPTION注册关键字。 例如,loop unrolling pass使用以下内容:
TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);
更多细节请参考 src/tir/transforms/unroll_loop.cc。
Pass Objects
Pass 是所有 pass 对象的基类。 这里的所有方法都只是在后端实现的简单包装器。 它们是为了用户方便地与 Python 中的基类进行交互而定义的。 在 pass 基类中只定义了一个__call__来使子类成为可调用对象,以便它们可以很容易地被调用(例如 pass_xx(arg))来执行。
@register_relay_node
class Pass(RelayNode):
def __call__(self, mod):
return _transform.RunPass(self, mod)
提供了一些辅助 APIs 以支持从 Python 前端轻松创建pass并让pass infra控制执行。 比如提供给用户module_pass、function_pass、sequential,让他们可以自定义自己的pass或者pass管道。
对于在C++后端实现的所有pass,我们分别在python/tvm/ir/transform.py和python/tvm/relay/transform.py中提供了相应的Python API。 例如,const 折叠有一个 Python API,如下所示:
def FoldConstant():
return _transform.FoldConstant()
用户可以通过装饰器像下面这样构建一个pass:
@relay.transform.module_pass(opt_level=2)
def transform(mod, ctx):
tp = relay.TensorType((10

最低0.47元/天 解锁文章
560





