【从零开始学深度学习编译器】七,万字长文入门TVM Pass

0x0. 前言

这篇文章基于TVM 0.8.0.dev版本。在【从零开始学深度学习编译器】五,TVM Relay以及Pass简介 这篇推文中已经简单介绍了Relay和Pass机制。但对Pass的基础设施(Pass Infrastructure)和Relay树结构都没有详细介绍,所以这篇文章主要介绍一下Pass Infrastructure和Relay树结构,再基于这些关键的基础知识详细了解一下Constant Folding Pass,相信读者读完这篇文章会对TVM的Pass有更深的理解,并且在阅读其它Pass和实现自定义Pass时可以很Relax。

0x1. Pass Infrastructure

首先来看Pass Infrastructure,基于官方文档进行介绍。

在讲解Pass通用的注册和运行流程前,先来介绍一下TVM的Pass Infrastructure。参考官方文档:https://tvm.apache.org/docs/dev/pass_infra.html 。

Relay 和 TVM IR 都包含一系列优化passes,可提高模型的性能指标,例如平均推理速度、内存占用或特定设备的功耗。 TVM有一套标准优化方法以及特定于机器学习的优化方法,包括常量折叠、死代码消除、运算符布局更改、算符融合、缓冲区处理和循环变换等。 每一个Pass都使用在traversal期间和/或之前收集的分析结果来构造ir-to-ir的pass。

然而,随着TVM的迅速发展,需要一种更系统、更有效的方法来管理这些passes。此外,一个可以管理跨TVM堆栈不同层(如Relay和tir)的passes的通用框架,为开发人员快速原型化并将实现的passes插入系统铺平了道路。

例如,许多现有的生产编译器,如 GCC 和 LLVM,都采用pass manager来有效管理passes的执行。 最初管理 pass 很简单,因为 pass 的数量很少,但成熟的编译器将包含数百个单独的 pass。 Often external users will want to have custom passes correctly scheduled without having to modify a single handcrafted pass order.

同样,现代深度学习框架,如 Pytorch 和 MXNet Gluon,也有分别通过 Sequential 和 Block 启用pass-style层构建方案的趋势。 有了这样的结构,这些现代框架能够方便地将模块/层添加到它们的容器中,并轻松地构建神经网络。

Relay pass infra 的设计很大程度上受到 LLVM 中使用的分层pass manager和流行的深度学习框架中使用的block-style容器的启发。 pass infra 的主要目标包括:

  • 实现更好的optimizer编程编排。 这允许用户灵活地定制和构建自己的优化管道。
  • 提供一种用户友好的方式来调试passes。
  • 减轻开发人员手动和分别解决passes之间的依赖关系。
  • 为开发人员简化实现新passes的难度。 例如,我们允许用户在 Python 中实现一个 pass 并让 pass infra 操纵它的执行。

The Design

我们专注于为用户提供易于扩展的功能,让用户可以快速添加新passes而不会失去向后兼容性。 该设计包含后端和前端。 前者实现了 pass infra 的主要逻辑。 后者为用户提供简单的 API 进行交互,即允许用户快速创建自己的优化管道。

C++ Backend

我们提供了一个 PassInfo 对象来包含一个pass所需的基本信息。 name 是 pass 名称,opt_level 指示将启用 pass 的优化级别, required 表示执行某个 pass 所需的 pass(更多详细信息请参见include/tvm/ir/transform.h)。 例如,在注册pass的时候(将在后面介绍),pass开发人员可以指定pass的名称、将执行的优化级别和/或所需的pass。 opt_level 可用于帮助 pass infra 识别在用户提供的优化级别下运行时是否需要执行某个 pass。 required字段可以由pass infra用来解决pass依赖关系。

class PassInfoNode : public Object {
   
   
  String name;
  int opt_level;
  Array<String> required;
};

PassContext

PassContext 带有用于优化pass的有用信息。 例如,它包含错误报告系统,因此pass的作者可以提供有关优化失败原因的注释。 PassContext 还旨在替换旧的BuildConfig,它用于帮助用户配置编译选项,包括优化级别和必需/禁用的pass等。例如,我们可能有一个配置,它在 opt_level=3 时执行所有pass,除开使用 PassContext 提供的 disabled_pass=xx禁用的一些passes 。 现在我们可以在 opt_level=3 处对所有passes进行全局处理,并排除禁用pass列表中的那些pass。

这个类是为方便用户编写Python而设计的,它的语法可以在特定的配置下执行优化。 此外,用户可以通过 PassContext::Current()以线程安全的方式获取某个程序范围内可用的context,因为ThreadLocalStore用于保存创建的pass context对象,关于ThreadLocalStore建议看这篇文章:https://zhuanlan.zhihu.com/p/61587053,TVM模仿Java中的ThreadLocalStore在C++层自己实现了用来管理线程。 稍后将提供示例以展示我们如何使用 C++ 和 Python API 来创建使用pass context的编译管道。

class PassContextNode : public Object {
   
   
 public:
  ErrorReporter err_reporter;
  int opt_level{
   
   2};
  tvm::Array<tvm::Expr> required_pass;
  tvm::Array<tvm::Expr> disabled_pass;
};

class PassContext : public NodeRef {
   
   
 public:
  TVM_DLL static PassContext Create();
  TVM_DLL static PassContext Current();
  /* Other fields are omitted. */

 private:
  // The entry of a pass context scope.
  TVM_DLL void EnterWithScope();
  // The exit of a pass context scope.
  TVM_DLL void ExitWithScope();

  // Classes to get the Python `with` like syntax.
  friend class tvm::With<PassContext>;
};

struct PassContextThreadLocalEntry {
   
   
  /*! \brief The default pass context. */
  PassContext default_context;
  /*! \brief The current pass context. */
  std::stack<PassContext> context_stack;
  PassContextThreadLocalEntry() {
   
   
    default_context = PassContext(make_node<PassContextNode>());
  }
};

/*! \brief The thread-local store to hold the pass context. */
typedef dmlc::ThreadLocalStore<PassContextThreadLocalEntry>
     PassContextThreadLocalStore;

Pass Constructs

pass infra 是以分层方式设计的,它可以在不同粒度的Relay/tir 程序下工作。 引入了一个纯虚拟类 PassNode 作为不同优化pass的基础。 此类包含几个必须由子类在modules, functions, or sequences of passes实现的虚拟方法。

class PassNode : Object {
   
   
  virtual PassInfo Info() const = 0;
  virtual Module operator()(const IRModule& mod
                            const PassContext& pass_ctx) const = 0;
};

成员函数展示了一个pass应该如何实现,例如它始终在特定context下工作在 IRModule中,所有的pass都被设计在一个Module to Module的管理器中。因此,由 pass infra 控制的优化将始终更新整个module。

已经创建了几个子类来实现不同类型的优化pass,例如,function-level passes, module-level passes, and sequential passes。 每个子类本身都可以充当pass管理器。 例如,他们可以收集所需的passes并执行它们或基于给定的元数据构建依赖关系图。 它们的完整定义可以在src/relay/ir/transform.cc 和 src/ir/transform.cc 中找到。

Module-Level Passes

Module Level Passes主要用于全局和过程间优化 (IPO),类似于 LLVM 中使用的module pass。 Relay 中一些典型的 pass 需要一个模块的global picture,比如 A-normal form conversion 和 lambda lifting等,都属于这个集合。 在此级别,用户甚至可以在一个module中添加和/或删除function。

class ModulePassNode : PassNode {
   
   
  PassInfo pass_info;
  runtime::TypedPackedFunc<Module(Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  // Other members/methods are omitted
};

pass_info 维护module-level pass所需的信息。 pass_func 实现了真正的optimization。 例如,我们可能需要对module执行死代码消除。 我们可以在 pass_func 中实现算法并让它在module上运行。 然后它将删除死代码,包括module中未使用的函数。 请注意,该字段被设计为一个packed function,所以这个优化不仅可以使用C++还可以使用Python来实现。

Function-Level Passes

Function-level passes用于为给定的 Relay/tir module实现各种内部函数级优化。 它一次从module的函数列表中获取一个函数以进行优化,并生成一个重写的 Relay Functiontir PrimFunc。 大多数pass可以归入这一类,例如Relay中的常见子表达式消除和inference simplification 以及tir中的向量化和flattening storage等。

请注意,此级别的passes范围是 Relay Function或 tir PrimFunc。 因此,我们无法通过这些passes添加或删除函数,因为它们不知道全局信息。

class FunctionPassNode : PassNode {
   
   
  PassInfo pass_info;
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
  bool SkipFunction(const Function& func) const;
  // Other members/methods are omitted...
};

pass_info 与我们刚刚在Module pass 中描述的相同。 pass_func 需要一个函数进行优化,它还需要一个Module,因为我们可能会使用它来报告错误。 一个函数可以用“SkipOptimization”注释,以便在优化过程中被忽略。

Sequential Passes

SequentialPass 类似于 Pytorch nn.Sequential,它包含许多用于执行的passes。

class SequentialPassNode : PassNode {
   
   
  PassInfo pass_info;
  // Passes need to be executed.
  Array<Pass> passes;
  bool PassEnabled(const PassInfo& info) const;
  Module operator()(const Module& mod, const PassContext& pass_ctx) const final;
};

目前在Relay中只有少数passes 被放入这组中。 例如,FoldScaleAxis 需要在内部调度 ForwardFoldScaleAxisBackwardFoldScaleAxis。 此外,建议先完成BackwardFoldScaleAxis。 因此,该pass是SequentialPass的理想候选者。

以下代码显示了如何调用sequential pass中的各个pass。

Module SequentialNode::operator()(const Module& module,
                                  const PassContext& pass_ctx) const {
   
   
  Module mod = module;
  for (const Pass& pass : passes) {
   
   
    ICHECK(pass.defined()) << "Found undefined pass for optimization.";
    const PassInfo& pass_info = pass->Info();
    if (!PassEnabled(pass_info))  continue;
    for (const auto& it : pass_info->required) {
   
   
      const auto* name = it.as<tvm::ir::StringImm>();
      ICHECK(name);
      mod = GetPass(name->value)(mod, pass_ctx);
    }
    mod = pass(mod, pass_ctx);
  }
  return mod;
}

在调用pass时,我们首先检查是否启用了此pass。 这是通过首先检查用户是否明确禁用该pass,然后检查它是否被用户指定为必需pass来完成的。 如果仍然不确定是否启用了此传递,则将检查其 opt_level。 只有当它的opt_level不低于pass context中配置的优化级别时,才会启用并因此执行此pass。

要执行pass,我们首先需要使用pass name在 TVM packed function注册表中已注册的pass。 这是可能的,因为每个pass都注册了一个 API 接口,我们将在后面展示。

Pass GetPass(const std::string& pass_name) {
   
   
  using tvm::runtime::Registry;
  std::string fpass_name = "relay._transform." + pass_name;
  const auto* f = Registry::Get(fpass_name);
  ICHECK(f != nullptr) << "Cannot find " << fpass_name
                      << "to create the pass " << pass_name;
  return (*f)();
}

提供了一些helper function来创建上述每种类型的Pass。 这些helper function也暴露给 Python 前端,以便用户可以方便地使用 Python API 来创建特定的 pass 对象。

Pass CreateFunctionPass(
    const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass CreatePrimFuncPass(
    const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass CreateModulePass(
    const runtime::TypedPackedFunc<PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
    int opt_level,
    String name,
    Array<String> required);

Pass Sequential(tvm::Array<Pass> passes, PassInfo pass_info);

Pass Registration

我们已经介绍了不同级别pass的概念和用于编译的context。 用户可以多么轻松地注册pass是一件有意义的事。,我们以constant folding为例。 这个 pass 已经被实现来折叠 Relay Function中的常量(在 tvm/src/relay/transforms/fold_constant.cc 中找到)。

提供了一个 API 来执行 ExprExpr 的转换。

Expr FoldConstant(const Expr& expr);

为了将这个pass注册到pass infra,我们首先需要决定这个pass将在哪个级别执行。 由于常量折叠发生在单个函数上,我们应该直观地通过 CreateFunctionPass为其创建一个 FunctionPasspass_func 作为packed function返回,该函数在 IRModule 中的每个function上调用 Expr to Expr API。 {} 表示此pass不需要先决条件。 否则,pass开发人员必须识别并列出它们。

namespace transform {
   
   

Pass FoldConstant() {
   
   
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
    [=](Function f, IRModule m, PassContext pc) {
   
   
      return Downcast<Function>(FoldConstant(f));
  };
  return CreateFunctionPass(pass_func, 2, "FoldConstant", {
   
   });
}

TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);

}  // namespace transform

为了允许其他 C++ 模块应用此pass,我们在 include/tvm/relay/transform.h中声明了一个free function,如下所示:

TVM_DLL Pass FoldConstant();

Python Frontend

python前端只需要一些简单的 APIs。 例如,我们可以为用户提供以下 APIs 来创建和执行一个 pass(完整的实现在 python/tvm/relay/transform.pypython/tvm/ir/transform.py 中提供)。 后端接收信息并决定它应该使用哪个函数来创建 Pass 对象。

PassContext

Python 前端为 PassContext 提供了一个包装器,通过覆盖 __enter____exit__ 来启用 with 语法。 为用户提供了一个 current 静态方法来获取在特定范围内使用的上下文。

@tvm._ffi.register_object("transform.PassContext")
class PassContext(tvm.runtime.Object):
    def __enter__(self):
        _transform.EnterPassContext(self)
        return self

    def __exit__(self, ptype, value, trace, config):
        _transform.ExitPassContext(self)

    @staticmethod
    def current():
        """Return the current pass context."""
        return _transform.GetCurrentPassContext()

PassContext 用于配置编译选项,包括优化级别和必需/禁用的pass。 它还可以带一个配置字典,以便不同的pass可以方便地获取passed的数据,例如回退设备信息和循环展开的步数/深度等。 为了能够获取所需的配置,必须通过TVM_REGISTER_PASS_CONFIG_OPTION注册关键字。 例如,loop unrolling pass使用以下内容:

TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig);

更多细节请参考 src/tir/transforms/unroll_loop.cc

Pass Objects

Pass 是所有 pass 对象的基类。 这里的所有方法都只是在后端实现的简单包装器。 它们是为了用户方便地与 Python 中的基类进行交互而定义的。 在 pass 基类中只定义了一个__call__来使子类成为可调用对象,以便它们可以很容易地被调用(例如 pass_xx(arg))来执行。

@register_relay_node
class Pass(RelayNode):
   def __call__(self, mod):
       return _transform.RunPass(self, mod)

提供了一些辅助 APIs 以支持从 Python 前端轻松创建pass并让pass infra控制执行。 比如提供给用户module_passfunction_passsequential,让他们可以自定义自己的pass或者pass管道。

对于在C++后端实现的所有pass,我们分别在python/tvm/ir/transform.pypython/tvm/relay/transform.py中提供了相应的Python API。 例如,const 折叠有一个 Python API,如下所示:

def FoldConstant():
    return _transform.FoldConstant()

用户可以通过装饰器像下面这样构建一个pass:

 @relay.transform.module_pass(opt_level=2)
 def transform(mod, ctx):
    tp = relay.TensorType((10
### SAM模型概述 SAM(Segment Anything Model)是一种由Meta开发的通用分割模型,旨在解决图像中的任意目标分割问题。它通过习一组可以泛化到新类别和场景的目标表示来实现这一目的[^1]。 #### 工作原理 SAM的核心理念在于其能够生成高质量的掩码(masks),这些掩码用于精确描述输入图像中特定区域的内容。具体来说: - **编码器部分**:SAM利用了一个强大的视觉Transformer作为骨干网络,该网络负责提取高层次特征并理解整个图像语义信息。 - **解码器部分**:基于来自编码器的信息以及用户的提示(prompts),如点击位置或者边界框等简单指示,解码器会生成最终所需的像素级精度的分割结果[^2]。 这种设计使得即使是在未见过的数据集上也能表现出色,因为模型已经会了如何根据不同类型的提示去适应各种可能的任务需求。 #### 应用场景 由于其灵活性与高效性,SAM适用于多个领域内的实际应用案例之中: 1. **医疗影像分析**:通过对医扫描图片进行精准标注从而辅助医生诊断疾病状态; 2. **自动驾驶技术**:帮助车辆识别道路上行人、障碍物及其他重要元素以便做出安全决策; 3. **增强现实/虚拟现实环境构建**:允许开发者更方便快捷地创建交互式的三维空间体验产品原型。 以下是关于如何加载预训练权重的一个Python代码片段示例: ```python import torch from segment_anything import sam_model_registry, SamPredictor device = 'cuda' if torch.cuda.is_available() else 'cpu' sam_checkpoint = "path/to/sam_vit_h_4b8939.pth" model_type = "vit_h" sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device) predictor = SamPredictor(sam) image_path = "./example.jpg" input_image = cv2.imread(image_path) input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB) predictor.set_image(input_image) ```
评论 3
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值