深度分析算子融合(TVM)

基本概念

算子

在计算机科学和特别是深度学习和机器学习领域,算子(Operator)或称为操作(Operation),是指在一个或多个数据上执行特定计算或变换的函数或算法。算子是构建和训练神经网络模型的基本构建块,它们定义了数据如何被处理和转换。

算子的类型

  1. 算术算子:包括加法(Add)、减法(Subtract)、乘法(Multiply)、除法(Divide)等,用于执行基本的数学运算。
  2. 线性代数算子:如矩阵乘法(Matrix Multiplication)、转置(Transpose)、求逆(Inverse)等,用于处理向量和矩阵。
  3. 卷积算子:在卷积神经网络(CNN)中,卷积算子用于提取图像和其他高维数据的特征。
  4. 激活函数:如ReLU(Rectified Linear Unit)、Sigmoid、Tanh等,用于引入非线性因素,帮助网络学习复杂的模式。
  5. 池化算子:在CNN中用于降低特征图的空间尺寸,减少参数数量和计算量,如最大池化(Max Pooling)和平均池化(Average Pooling)。
  6. 归一化算子:如批量归一化(Batch Normalization)、层归一化(Layer Normalization)等,用于调整神经网络中间层的输出,提高训练的稳定性和性能。
  7. 损失函数:如交叉熵损失(Cross-Entropy Loss)、均方误差(Mean Squared Error)等,用于在训练过程中评估模型的性能。
  8. 优化算子:如梯度下降(Gradient Descent)、Adam、RMSprop等,用于更新模型的权重以最小化损失函数。

算子融合

算子融合是一种优化技术,它通过将多个计算操作合并为一个单一计算单元,以减少内存访问次数和提高计算效率。这种方法旨在解决深度学习模型推理中的内存瓶颈和并行处理限制,即所谓的“内存墙”和“并行墙”问题。算子融合通常分为两种类型:水平融合和垂直融合。

  • 水平融合:涉及将同一层级中的多个算子合并,以减少内存读写操作和提高数据局部性。
  • 垂直融合:则是将不同层级中的算子合并,以优化整个模型的执行流程和减少中间数据的传输。

通过算子融合,可以减少内存访问延迟,提高缓存利用率,并减少执行过程中的内存移动,从而加快模型的推理速度。这种技术对于优化深度学习模型的性能至关重要,特别是在资源受限的环境中。
在这里插入图片描述

后文将以TVM为基础,介绍具体的算子融合技术。

tvm 中将算子分为7种类型:

  1. kElemWise:2个 tensor 之间按元素逐个操作的算子,实际上所有四则运算都是这种类型
  2. kBroadcast:见上述链接,到操作两个不同形状的 tensor 时
  3. kInjective:一对一映射函数,比如 add / sqrt / exp 等操作算子(operator)
  4. kCommReduce:多到少的映射,输入到输出具有降维性质,如:sum / max / min等操作操作算子(operator)
  5. kOutEWiseFusable:这是计算比较复杂的,如:conv2d / bn / relu等操作算子(operator)
  6. kTuple:元组节点的模式。可以融合到后续注入操作中, 但需要特殊处理
  7. kOpaque:无法被融合的算符,比如 sort

基于算子的分类,TVM制定了完整的算子融合规则,能够自动识别算子融合组合。基于规则融合在一起的算子,将不会有多余的load-store指令执行,而是将中间数据直接传递给下一层的算子,减少内存的读取开销。并且在硬件性能提升的现在,有效的组织算子进行融合能够最大限度的利用好硬件的并行能力。

1059417-20240421174038725-362674089

支配树

支配树是TVM算子重要的数据结构,能够快速的找到算子的直接后支配点,从而识别是否能够进行融合。

在一个有向无环图中,对于一个节点n来说,从初始节点s出发到达n的所有路径都经历一个节点m,那么m就是n的支配点。而距离n最近的支配点被称作立即支配点。以r为树根,将所有立即支配点按照支配关系连接起来就形成了支配树。立即后支配点是从一个点n出发所有到终止节点的路径中通过的最近节点,形成的支配树是后支配树。

有向图转后序支配树

具体实现

根据TVM论文描述,算子融合分为三个步骤:

  1. 通过relay IR 构建数据流图的 DAG 以进行支配分析
  2. 构造一个后支配树,给出每个节点的直接后支配者。
  3. 使用给定的后支配信息运行融合算法

构建DAG

  Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) {
    // setup the group map.
    auto graph = IndexedForwardGraphCreator::Create(&arena_, body);
    auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth, max_function_args_)
                      .Partition(graph);
    for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
      ICHECK(graph.post_dfs_order[nid]->ref != nullptr);
      gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
    }
    // The following line can be used for debug.
    // this->DebugDumpGroup(body);
    return this->Mutate(body);
  }

构建DAG的代码位:

 auto graph = IndexedForwardGraphCreator::Create(&arena_, body); //其中body 是 relay ir的结构,传递到 此处已经是一个fuction node

IndexedForwardGraphCreator类中 声明了 该Node节点存储了引用对象ref,拓扑序index,算子类型pattern,是否被引用extern_ref以及与节点输出的边outputs这些信息;IndexedForwardGraph还存储了对象和节点的映射关系node_map,所有节点的post-dfs遍历顺序post_dfs_order。 简单理解就是该类做了一个数据结构的转换,将relayIR转为Graph nodeIR。

  // The output.
  IndexedForwardGraph graph_;
class IndexedForwardGraph {
 public:
  struct Node;
  /*!
   * The forward edge in the dataflow graph.
   */
  struct Edge {
    /*! \brief The corresponding node */
    Node* node{nullptr};
    /*! \brief The respective pattern of this op */
    OpPatternKind pattern{kOpaque};
  };
  /*! \brief A node in the graph. */
  struct Node {
    /*! \brief weak reference to the corresponding edge. */
    const tvm::Object* ref{nullptr};
    /*! \brief The index of the node in topological order. */
    size_t index{0};
    /*! \brief Whether this node is referenced by external source */
    bool extern_ref{false};
    /*! \brief The general pattern in the node */
    OpPatternKind pattern{kOpaque};
    /*! \brief The outputs of the node. */
    LinkedList<Edge> outputs;
  };
  /*! \brief The node map that maps node to graph */
  std::unordered_map<const tvm::Object*, Node*> node_map;
  /*! \brief All the nodes in post DFS order */
  std::vector<Node*> post_dfs_order;
  }

IndexedForwardGraphCreator 继承 ExprVisitor,主要对 FunctionNodeCallNodeConstantNode等节点的遍历进行重写

FunctionNode举例:

  // Post order tree
  void VisitExpr_(const FunctionNode* op) final {
    // Skip the function that should be handled by external codegen.
    if (op->GetAttr<String>(attr::kCompiler).defined()) return;

    for (auto param : op->params) {
      this->Update(param, nullptr, kOpaque);
    }
    this->Update(op->body, nullptr, kOpaque);
    ExprVisitor::VisitExpr_(op);
  }

然后会调用Update函数,为graph 创建或更新node操作,如果有parent 参数,需要用创建edge

  void Update(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) {
    const tvm::Object* key = node.get();
    IndexedForwardGraph::Node* current;
    auto it = graph_.node_map.find(key);
    if (it != graph_.node_map.end()) {
      current = it->second;
    } else {
      current = arena_->make<IndexedForwardGraph::Node>();
      graph_.node_map[key] = current;
    }
    if (parent != nullptr) {
      auto* link = arena_->make<LinkNode<IndexedForwardGraph::Edge>>();
      link->value.node = parent;
      link->value.pattern = pattern;
      current->outputs.Push(link);
    } else {
      current->extern_ref = true;
    }
  }

遍历完算子的参数,再去对算子本身进行Update,但是算子本身是一个CallNode(TVM中是如此规定)

  void VisitExpr_(const CallNode* call) final {
    ICHECK(graph_.node_map.count(call));
    IndexedForwardGraph::Node* node = graph_.node_map.at(call);
    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
    // Now we set the pattern of this call.
    //
    // If we see a call mentioning an operator we should mark it with its
    // annotated pattern.
    //
    // If the pattern is not annotated we will default to opaque.
    //
    // Finally if the operator position is not a call node we will
    // need to call Update, as it may be an arbitrary expression.
    OpPatternKind op_pattern = kOpaque;
    if (auto optional = call->op.as<Op>()) {
      auto op = optional.value();
      if (IsDynamic(call->checked_type()) && IsDataDependent(call)) {
        // output of a shape func can't be fed to a data-dependent shape func
        op_pattern = kOpaque;
      } else {
        op_pattern = static_cast<OpPatternKind>(fpattern[op]);
      }
    } else {
      this->Update(call->op, node, kOpaque);
    }

    node->pattern = op_pattern;
    this->Update(call->op, nullptr, kOpaque);
    const auto* rtype = call->checked_type().as<TensorTypeNode>();
    // pass the analysis back to all the children it references.
    for (size_t i = 0; i < call->args.size(); ++i) {
      const auto* arg_type = call->args[i]->checked_type().as<TensorTypeNode>();
      // specifically check if result type is the same as arguments type
      OpPatternKind edge_pattern = op_pattern;
      if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr &&
          attr_equal_(rtype->shape, arg_type->shape)) {
        edge_pattern = kElemWise;
      }
      this->Update(call->args[i], node, edge_pattern);
    }
    ExprVisitor::VisitExpr_(call);
    this->AddNode(call);
  }

CallNodeVisitExpr_中首先确定节点的类型,如输入节点是add则将其分类为kElemWise类型,并将节点加入graph。接下来处理输入的args,此处会判断如果输入args的shape和返回值shape一致,则将edge类型从kBroadcast转换为kElemWise,之后更新到arg节点,建立arg到CallNode(Call(Add, …))的边,如下图第一阶段处理所示;

接下来进行VisitExpr_的递归阶段,来完善该callNode节点的args,如下图在第一阶段之后,将递归找寻%2节点的分支,更新完%2节点,再更新%3,直到完全构造出DAG

构建DAG的全流程图解:

  1. 后序遍历计算图,保存到 post_dfs_order 里面,由于遍历是从计算图出口开始的,而且是后序遍历,所以post_dfs_order最后一个保存的就是计算图的出口节点
  2. 推断节点类型,OpPatternKind,并调用 Update() 函数把节点类型保存到 Node 结构里

img

构建支配树

接下来看后序支配树的构建。构建后支配树的目的主要是为了能快速找出任一节点的直接后支配点

因为根节点(DAG图的出口)在post_dfs_order中最后,所以从根节点开始寻找每个节点出点的LCA,这个LCA就是后序支配点。

    auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth, max_function_args_)
                      .Partition(graph);

// 
std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
    const IndexedForwardGraph& graph) {
  this->InitGroups(graph);
  if (opt_level_ == 0) return std::move(groups_);
  // get post dominator tree
  auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
  // run fusion algorithm.
	...
}

DominatorTree的数据结构如下:

/*!
 * \brief Dominator tree that represent domination or
 *  post domination relation of the node.
 */
class DominatorTree {
 public:
  /*!
   * \brief A node in the dominator tree.
   */
  struct Node {
    /*! \brief The node in the tree */
    IndexedForwardGraph::Node* gnode{nullptr};
    /*! \brief parent of the tree */
    Node* parent{nullptr};
    /*! \brief current depth*/
    int depth{0};
    /*! \brief aggregated pattern to parent */
    OpPatternKind pattern{kOpaque};
  };
  // index -> node.
  std::vector<Node*> nodes;
  .....
  }

此处定义的支配树包括了index到节点的映射,节点包括以下字段,填充这些数据结构即完成了Graph -> DominatorTree数据结构的转换

  • gnode:相对Graph的节点引用
  • parent:父节点
  • depth:深度,方便计算LCA
  • pattern:算子类型

构建后序支配树,是根据逆向拓扑排序来处理graph中的节点,通过getNode来获取节点信息

DominatorTree::Node* DominatorTree::GetNode(support::Arena* arena,
                                            IndexedForwardGraph::Node* gnode) {
  Node* tnode = arena->make<Node>();
  tnode->gnode = gnode;
  if (gnode->extern_ref) {
    tnode->depth = 1;
    tnode->parent = nullptr;
    tnode->pattern = kOpaque;
  } else {
    // find the LCAs of all outputs.
    OpPatternKind pattern = kElemWise;
    Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
    tnode->depth = parent ? parent->depth + 1 : 1;
    tnode->parent = parent;
    tnode->pattern = pattern;
  }
  return tnode;
}

其中LeastCommonAncestor 是最小公共祖先算法(LCA算法),可自行了解。在LeastCommonAncestor中,TVM还通过CombinePattern返回两个算子类型中更不容易融合的类型。

具体的构建流程为:

  1. 根据步骤1生成DAG的post_dfs_order,来获取节点
  2. 通过获取节点的父亲节点,并将其标记节点在后序支配树中的子节点即可
  3. 如果存在已经在后序 支配树中存在父亲节点的节点(或者说是在DAG中有多个孩子节点的节点),那么需要将找到这些节点的最小公共祖先来作为冲突节点的父节点。如上图的节点2,有两个孩子节点(节点4 和节点7),这样在后序支配树生成的时候,节点4和节点7 都需要让自己成为节点2 的父亲节点,这是和树的结构冲突的。所以需要寻找节点4和节点7的公共祖先节点(节点8),让其成为节点2在后序支配树中的父亲节点
  4. 构建完成 后序支配树

进行融合

完成支配树构建之后,就可以开始融合操作了。

整个融合分为3个阶段,每个阶段执行不同的融合操作,具体逻辑都在 RunFuse() 函数里面

std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
    const IndexedForwardGraph& graph) {
  this->InitGroups(graph);
  if (opt_level_ == 0) return std::move(groups_);
  // get post dominator tree
  auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
  // run fusion algorithm.
  for (int phase = 0; phase < 3; ++phase) { this->RunFuse(graph, post_dom_tree, phase);
  }
  return std::move(groups_);
}

其中融合算法主要是分别遍历dag,postDominator tree,以及group图中节点,来判断算符是否能被融合。注意这次遍历,是从计算图入口而不是出口开始遍历。

不急,接下来我们进行代码详解:

首先,我们对计算图中所有的节点进行分组,这样可以快速判断两个节点是否属于同一个分组。其中Group的是Union find的数据结构。

// 初始化 graph的分组
this->InitGroups(graph);
//gropu的数据结构
  struct Group {
    /*! \brief The parent in the union find data structure. */
    Group* parent{nullptr};
    /*! \brief The pattern of the group */
    OpPatternKind pattern;
    /*! \brief reference to the root node. */
    const tvm::Object* root_ref{nullptr};
    /*!
     * \brief Reference to the anchor node,
     * this field is not nullptr only if pattern is kOutEWiseFusable.
     */
    const tvm::Object* anchor_ref{nullptr};
    /*!
     * \brief The number of nodes belonging to this group
     */
    uint32_t num_nodes{1};
    /*!
     * \brief The number of function arguments belonging to this group
     */
    size_t args_num{0};

    /*! \brief Optional attributes to annotate the grouped function. */
    runtime::Map<runtime::String, ObjectRef> attrs;
    /*!
     * \brief Find the group root, perform path compression
     * \return The root type node.
     */
    Group* FindRoot();
  };

知道分组之后,我们再来详解算子融合的三个阶段:

  1. 阶段1:处理kOutEltwiseFusable
  2. 阶段2:处理 kInjective 或 kTuple,
  3. 第三阶段尝试将patten<=kInjective的算子融入kTuple
  4. 每一阶段都会处理kElemWise和kBroadcast:当前节点与其后支配点中的任意节点都满足patten<=kInjective且后支配点满足patten<=kOutEWiseFusable则可以融合;
// 首先获取,当前节点的组号,在计算图 中的节点,在后序支配树中的节点
		auto* graph_node = graph.post_dfs_order[nid];
    auto* dom_node = post_dom_tree.nodes[nid];
    Group* group_node = groups_[nid];
    ICHECK(group_node != nullptr);
    postpone_node_ = nullptr;
    // 获取该节点后支配点graph索引
    size_t dom_parent_gindex = dom_node->parent->gnode->index;

阶段1

    if (group_node->pattern == kOutEWiseFusable) {
      if (phase != 0) continue;
      // Path for OutEWiseFusable: conv2d
      // Check if the dominator relation is elemwise.
      if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
        ICHECK(dom_node->parent->gnode != nullptr);
        // The fuse can be executed if all the intermediate ops are still broadcast.
        auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
    }

当前节点为kOutEWiseFusable,后支配点为kElemWise,且两节点的路径中所有算子均满足patten<=kBroadcast则可以融合;

上诉代码中的 CheckPath是 判断当前节点和后支配节点之间的所有节点是否都满足给定的条件的函数,具体代码如下:

template <typename F>
bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                  F fcond) {
  if (visited_.count(src)) return true;
  visited_.insert(src);
  Group* gnode = groups_[src->index];
  ICHECK(gnode != nullptr);
  gnode = gnode->FindRoot();
  if (!fcond(gnode->pattern, src == sink)) return false;
  if (src == sink) return true;
  for (auto link = src->outputs.head; link != nullptr; link = link->next) {
    if (!CheckPath_(link->value.node, sink, fcond)) return false;
  }
  return true;
}

template <typename F>
bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                 F fcond) {
  ICHECK(!src->extern_ref);
  visited_.clear();
  ICHECK(src != sink);
  for (auto link = src->outputs.head; link != nullptr; link = link->next) {
    if (!CheckPath_(link->value.node, sink, fcond)) return false;
  }
  return true;
}

经过 CheckPath的判断,再进行融合CommitFuse函数

阶段2

阶段2 主要是来判断算子为kInjectivekTuple的融合情况。和阶段1的步骤一直,判断算子和 后支配节路径之间的所有节点是否足patten <= kInjective

还是通过CheckPath函数来判断是否满足融合条件,再通过CommitFuse进行融合

if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
      // defer injective fusion to second phase.
      // so conv2d always finishes fusing.
      if (phase != 1) continue;
      // Check if all path are injective.
      auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
      if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
        CommitFuse(graph_node, dom_node->parent->gnode);
      }
    }

阶段3

尝试将patten<=kInjective的算子融入kTuple

  if (phase == 2) {
      // Fuse injective ops into intermediate tuples, if any
      if (group_node->pattern > relay::kInjective) continue;
      Group* dom_parent_group = groups_[dom_parent_gindex];
      Group* dom_root_group = dom_parent_group->FindRoot();
      // If dom node group has a tuple as its root, we do not fuse tuple fields into it
      if (dom_root_group->pattern == relay::kTuple) continue;
      if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) {
        // Now we know the tuple has been fused into subsequent injective ops
        auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
        // dom_root_group can also be tuple, as in inception layers
        // CheckPath is needed to avoid fusing two intermediate tuples
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
      continue;
    }

kElemWise和kBroadcast

每一阶段都会处理kElemWise和kBroadcast:当前节点与其后支配点中的任意节点都满足patten<=kInjective且后支配点满足patten<=kOutEWiseFusable则可以融合;

if (group_node->pattern <= kBroadcast) {
      // Pre-condition: can only be fused to parent which is injective or reduction.
      if (dom_node->parent != nullptr &&
          (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
        // Check if all the intermediate ops are still broadcast.
        // The final terminal node can already be fused to a OutEWiseFusable group.
        auto fcond = [](OpPatternKind kind, bool is_sink) {
          if (!is_sink) {
            // Elemwise, broadcast, and injective ops on the parallel branches
            // are allowed be fused to the elemwise/broadcast anchor.
            return kind <= kInjective;
          } else {
            return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
                    kind == kOutEWiseFusable);
          }
        };
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);
        }
      }
    }

注意的是,TVM的算子融合没有对reduce算子 进行任何的处理,只是进行了一个ICHECK的检查

else {
      // do nothing.
      ICHECK(group_node->pattern == kCommReduce);
    }

样例详解

构建了一个简单的算子组合,并对其使用了 算子融合的pass。

        x = relay.var("x", shape=(10, 20))
        y = relay.add(x, relay.const(1, "float32"))
        z = relay.exp(y)
        w = relay.squeeze(z)
        return relay.Function([x], w)

结果如下,通过DAG和后序支配树以及分组的数据结构,经历三个阶段的算子融合最终将所有算子融合在一起,减少了三个算子分离情况下需要反复读取内存中的中间结果这个操作,从而提高了执行效率。(具体提高了多少下次有机会再做分析)

before model:fn (%x: Tensor[(10, 20), float32]) {
  %0 = add(%x, 1f);
  %1 = exp(%0);
  squeeze(%1)
}
fuse model:fn (%x: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */) -> Tensor[(10, 20), float32] {
  %2 = fn (%p0: Tensor[(10, 20), float32] /* ty=Tensor[(10, 20), float32] */, Primitive=1) -> Tensor[(10, 20), float32] {
    %0 = add(%p0, 1f /* ty=float32 */) /* ty=Tensor[(10, 20), float32] */;
    %1 = exp(%0) /* ty=Tensor[(10, 20), float32] */;
    squeeze(%1) /* ty=Tensor[(10, 20), float32] */
  } /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */;
  %2(%x) /* ty=Tensor[(10, 20), float32] */
} /* ty=fn (Tensor[(10, 20), float32]) -> Tensor[(10, 20), float32] */

参考:

[1] https://zhuanlan.zhihu.com/p/589619468

[2] http://0fd.org/2023/06/05/dive-into-tvm-the-relay-pass-of-fuse-ops/

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值