TVM Pass优化 -- 算子融合(FuseOps)


算子融合就是将多个计算单元合并到一个计算单元里完成计算,减少中间结果在内存中的读写,从而节省计算时间。TVM 按照计算模式(OpPatternKind)把算子分为以下几类,融合规则正是基于这些类型来定义的:

  • kElemWise:逐元素(element-wise)计算的算子,输出的每个元素只依赖输入中相同位置的元素,如relu/exp/sqrt等
  • kBroadcast:带有广播语义的算子,如add/multiply等二元四则运算;当某个输入与输出的形状完全相同时,FuseOps会把对应边的类型视为kElemWise(见下文)
  • kInjective:单射算子,输出的每个元素都能映射到输入中唯一的元素,如reshape/transpose/concatenate等
  • kCommReduce:多对一的归约映射,输出相对输入具有降维性质,如sum/max/min等算子
  • kOutEWiseFusable:计算比较复杂的算子,其输出可以与后续的kElemWise算子融合,如conv2d/dense等
  • kTuple:操作元组的算子,如TupleNode、TupleGetItemNode等
  • kOpaque:无法进行融合的算子,如sort

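根据TVM论文,TVM提供了三种融合规则:
  1. 多个 injective 算子可以融合成一个 injective 算子;
  2. reduction 算子可以与其输入端的 injective 算子融合(例如将 scale 与 sum 融合);
  3. conv2d 这类 complex-out-fusable 算子可以把其输出端的 element-wise 算子融合进来。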


import tvm
from tvm import te
import tvm.relay as relay
import numpy as np
from tvm.relay.testing import run_opt_pass

def get_relay_ir():
    """Build the demo Relay function used to illustrate FuseOps.

    The graph is ``relu(conv2d(x, weight) + c) + conv2d(x, weight) * 0.5``,
    where ``c`` is a (1, 3, 14, 14) float32 constant of ones. The conv2d
    output has two consumers (the add and the multiply).

    Returns
    -------
    relay.Function
        Function with parameters ``x`` (1, 3, 16, 16) and ``weight`` (3, 3, 3, 3).
    """
    bias_value = np.ones((1, 3, 14, 14)).astype('float32')
    bias = relay.const(bias_value)

    kernel = relay.var('weight', shape=(3, 3, 3, 3))
    data = relay.var('x', relay.TensorType((1, 3, 16, 16), 'float32'))
    conv_out = relay.nn.conv2d(data, kernel)

    # Branch 1: conv2d -> add(constant) -> relu
    relu_branch = relay.nn.relu(relay.add(conv_out, bias))
    # Branch 2: conv2d -> multiply by a scalar constant
    scale_branch = relay.multiply(conv_out, relay.const(0.5, 'float32'))

    return relay.Function([data, kernel], relu_branch + scale_branch)

f = get_relay_ir()
mod = tvm.IRModule.from_expr(f)
# Show the IR before fusion (the original printed the header but not the module).
print('src module:')
print(mod)

# Run FuseOps at fusion level 4; run_opt_pass returns the transformed relay Function.
mod = run_opt_pass(f, relay.transform.FuseOps(fuse_opt_level=4))
print('dst module:')
print(mod)


def @main(%x: Tensor[(1, 3, 16, 16), float32], %weight: Tensor[(3, 3, 3, 3), float32]) {
  %0 = nn.conv2d(%x, %weight, padding=[0, 0, 0, 0]);
  %1 = add(%0, meta[relay.Constant][0]);
  %2 = nn.relu(%1);
  %3 = multiply(%0, 0.5f);
  add(%2, %3)

fn (%x: Tensor[(1, 3, 16, 16), float32], %weight: Tensor[(3, 3, 3, 3), float32]) -> Tensor[(1, 3, 14, 14), float32] {
  %4 = fn (%p0: Tensor[(1, 3, 16, 16), float32], %p1: Tensor[(3, 3, 3, 3), float32], %p2: Tensor[(1, 3, 14, 14), float32], Primitive=1) -> Tensor[(1, 3, 14, 14), float32] {
    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0]);
    %1 = add(%0, %p2);
    %2 = nn.relu(%1);
    %3 = multiply(%0, 0.5f);
    add(%2, %3) 
  %4(%x, %weight, meta[relay.Constant][0])

根据运行结果可以发现,经过算子融合pass后,conv2d、add、relu和multiply算子被融合进同一个带有Primitive=1属性的函数(FunctionNode)中,原来的计算被替换为对该函数的一次调用(CallNode,即上面的%4(...))。


算子融合的目的最终是要解决 AI 处理器的内存墙、并行墙的问题,提升 Tensor 数据的访存局部性。


算子融合pass的python入口位于 python/tvm/relay/transform/transform.py 中:

def FuseOps(fuse_opt_level=-1):
    """Fuse operators in an expr to a larger operator according to some rules.

    Parameters
    ----------
    fuse_opt_level : int
        The level of fuse optimization. -1 indicates that the level will be
        inferred from pass context.

    Returns
    -------
    ret : tvm.transform.Pass
        The registered pass for operator fusion.
    """
    # Thin wrapper: the pass itself is implemented in C++ and exposed via FFI.
    return _ffi_api.FuseOps(fuse_opt_level)

TVM通过 PackedFunc FFI 机制实现了 python 和 c++ 之间的相互调用,FuseOps 的 c++ 后端实现位于 src/relay/transforms/fuse_ops.cc:

/*!
 * \brief Create the function-level FuseOps pass.
 * \param fuse_opt_level Fusion level; -1 means "use the opt_level of the current PassContext".
 * \return A FunctionPass named "FuseOps" that requires InferType to have run.
 */
Pass FuseOps(int fuse_opt_level) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        // link-params: taken from the module's executor (if any), then overridable by config.
        // NOTE(review): executor lookup reconstructed from an elided line — confirm upstream.
        bool link_params = false;
        Executor executor =
            m->GetAttr<Executor>(tvm::attr::kExecutor).value_or(NullValue<Executor>());
        link_params = executor.defined()
                          ? executor->attrs.GetAttr<Bool>("link-params").value_or(Bool(link_params))
                          : link_params;
        link_params = pc->GetConfig("relay.FuseOps.link_params", Bool(link_params)).value();
        int opt_level = fuse_opt_level == -1 ? pc->opt_level : fuse_opt_level;
        auto max_fuse_depth = pc->GetConfig("relay.FuseOps.max_depth", Integer(kMaxFusedOps));
        // Guard against an undefined current target before dereferencing it.
        auto target = Target::Current();
        size_t max_function_args =
            target.defined()
                ? target->GetAttr<Integer>("max_function_args", Integer(0)).value().IntValue()
                : 0;
        return Downcast<Function>(FuseOps(f, opt_level, max_fuse_depth.value().IntValue(),
                                          max_function_args, link_params, m));
      };
  return CreateFunctionPass(pass_func, 0, "FuseOps", {"InferType"});
}


可以发现,该pass是一个Function级别的pass(通过CreateFunctionPass创建,并依赖InferType)。此处只需关注优化级别选项fuse_opt_level:它可以直接传入,传入-1时则由PassContext的opt_level决定。其余参数(max_depth、link_params、max_function_args)暂未用到,使用其默认值即可。



  1. 遍历relay树,建立DAG用于后支配树分析
  2. 构建后支配树,能够快速求取任意节点的后支配点
  3. 根据当前节点的后支配点信息,在两节点路径之间进行融合算法


// Run the transform
  Expr Transform(const Expr& body) {
    // Use the options stored on this mutator.
    return Transform(body, fuse_opt_level_, max_fuse_depth_, link_params_);
  }
  // Run the transform
  Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) {
    // 1. Build the indexed forward dataflow graph from the Relay expression.
    auto graph = IndexedForwardGraphCreator::Create(&arena_, body);
    // 2. Partition the graph into fusion groups; groups[i] is the group of post_dfs_order[i].
    auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth, max_function_args_)
                      .Partition(graph);
    // setup the group map.
    for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
      ICHECK(graph.post_dfs_order[nid]->ref != nullptr);
      gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
    }
    // The following line can be used for debug.
    // this->DebugDumpGroup(body);
    // 3. Rewrite the expression using the group assignment.
    return this->Mutate(body);
  }



auto graph = IndexedForwardGraphCreator::Create(&arena_, body);

其中,arena_ 是用于分配图节点的内存池(Arena),body 是待处理的 Relay 表达式,此处是一个 FunctionNode。


class IndexedForwardGraph {
  struct Node;
   * The forward edge in the dataflow graph.
  struct Edge {
    /*! \brief The corresponding node */
    Node* node{nullptr};
    /*! \brief The respective pattern of this op */
    OpPatternKind pattern{kOpaque};
  /*! \brief A node in the graph. */
  struct Node {
    /*! \brief weak reference to the corresponding edge. */
    const tvm::Object* ref{nullptr};
    /*! \brief The index of the node in topological order. */
    size_t index{0};
    /*! \brief Whether this node is referenced by external source */
    bool extern_ref{false};
    /*! \brief The general pattern in the node */
    OpPatternKind pattern{kOpaque};
    /*! \brief The outputs of the node. */
    LinkedList<Edge> outputs;
  /*! \brief The node map that maps node to graph */
  std::unordered_map<const tvm::Object*, Node*> node_map;
  /*! \brief All the nodes in post DFS order */
  std::vector<Node*> post_dfs_order;

Node表示图中的节点,存储了对应Relay对象的引用ref、拓扑序index、是否被外部引用extern_ref、算子类型pattern以及节点的输出边outputs等信息;Edge表示一条输出边,记录了该边指向的节点node及该边的pattern。



post_dfs_order 按照后序DFS顺序保存了图中的所有节点,节点的index就是它在该数组中的下标。

该图由 IndexedForwardGraphCreator 构建,它负责把 Relay IR 转换为上述由 Node/Edge 组成的图数据结构。

IndexedForwardGraphCreator 继承自 ExprVisitor,主要重写了 FunctionNode、CallNode、ConstantNode 等节点的遍历逻辑。
由于用户传入该pass的是一个 FunctionNode,因此首先进入 FunctionNode 的处理逻辑:

  // Post order tree
  void VisitExpr_(const FunctionNode* op) final {
    // Skip the function that should be handled by external codegen.
    if (op->GetAttr<String>(attr::kCompiler).defined()) return;

    for (auto param : op->params) {
      this->Update(param, nullptr, kOpaque);
    this->Update(op->body, nullptr, kOpaque);

其逻辑先对参数和函数体进行 Update,之后进入父类的VisitExpr_方法进行递归遍历。

  • Update过程即在Graph中创建或更新Node:节点不存在时先新建;如果parent不为空,则创建一条从当前节点指向parent的Edge,否则将该节点标记为extern_ref。其代码如下:
 // Update the message stored at the node.
  /*!
   * \brief Get or create the graph node for `node`, then record how it is used.
   * \param node The Relay expression being registered.
   * \param parent The consumer node, or nullptr when the use comes from outside the graph.
   * \param pattern The pattern of the edge node -> parent.
   */
  void Update(const Expr& node, IndexedForwardGraph::Node* parent, OpPatternKind pattern) {
    const tvm::Object* key = node.get();
    IndexedForwardGraph::Node* current;
    auto it = graph_.node_map.find(key);
    if (it != graph_.node_map.end()) {
      current = it->second;
    } else {
      current = arena_->make<IndexedForwardGraph::Node>();
      graph_.node_map[key] = current;
    }
    if (parent != nullptr) {
      // Attach an output edge current -> parent; without the Push the edge would be lost.
      auto* link = arena_->make<LinkNode<IndexedForwardGraph::Edge>>();
      link->value.node = parent;
      link->value.pattern = pattern;
      current->outputs.Push(link);
    } else {
      // No consumer inside the graph: mark the node as externally referenced.
      current->extern_ref = true;
    }
  }
  • 父类的 VisitExpr_ 方法首先访问 FunctionNode 的参数 %x、%weight,并通过 AddNode 更新节点信息:%x 的拓扑序是 0,%weight 的拓扑序为 1,同时把它们追加到 graph 的 post-dfs 顺序中:
void ExprVisitor::VisitExpr_(const FunctionNode* op) {
  for (auto param : op->params) {
  void VisitExpr_(const VarNode* op) final { this->AddNode(op); }
    void AddNode(const tvm::Object* key) {
    auto it = graph_.node_map.find(key);
    ICHECK(it != graph_.node_map.end()) << "Cannot find node " << GetRef<ObjectRef>(key);
    IndexedForwardGraph::Node* node = it->second;
    ICHECK(node->ref == nullptr);
    node->ref = key;
    node->index = graph_.post_dfs_order.size();

接下来访问 FunctionNode 的函数体 body,它是一个 CallNode 节点,即 add(%2, %3):

void ExprVisitor::VisitExpr_(const FunctionNode* op) {


  void VisitExpr_(const CallNode* call) final {
    IndexedForwardGraph::Node* node =;
    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
    // Now we set the pattern of this call.
    // If we see a call mentioning an operator we should mark it with its
    // annotated pattern.
    // If the pattern is not annotated we will default to opaque.
    // Finally if the operator position is not a call node we will
    // need to call Update, as it may be an arbitrary expression.
    OpPatternKind op_pattern = kOpaque;
    if (auto optional = call-><Op>()) {
      auto op = optional.value();
      if (IsDynamic(call->checked_type()) && IsDataDependent(call)) {
        // output of a shape func can't be fed to a data-dependent shape func
        op_pattern = kOpaque;
      } else {
        op_pattern = static_cast<OpPatternKind>(fpattern[op]);
    } else {
      this->Update(call->op, node, kOpaque);

    node->pattern = op_pattern;
    this->Update(call->op, nullptr, kOpaque);


  void VisitExpr_(const CallNode* call) final {
	const auto* rtype = call->checked_type().as<TensorTypeNode>();
    // pass the analysis back to all the children it references.
    for (size_t i = 0; i < call->args.size(); ++i) {
      const auto* arg_type = call->args[i]->checked_type().as<TensorTypeNode>();
      // specifically check if result type is the same as arguments type
      OpPatternKind edge_pattern = op_pattern;
      if (edge_pattern == kBroadcast && arg_type != nullptr && rtype != nullptr &&
          attr_equal_(rtype->shape, arg_type->shape)) {
        edge_pattern = kElemWise;
      this->Update(call->args[i], node, edge_pattern);

接下来处理输入的args,此处会判断如果输入args的shape和返回值shape一致,则将edge类型从kBroadcast转换为kElemWise,之后更新到arg节点,建立arg到CallNode(Call(Add, ...))的边,如下图第一阶段处理所示;

  • 接下来继续进入ExprVisitor::VisitExpr_(call)的CallNode节点处理函数中,依次处理参数(%2, %3)、body,处理参数%2,如图第二阶段;
  • 继续递归处理(post-dfs),如下图第三阶段所示;
  • %2分支更新完,如下图第四阶段;
  • 接下来更新%3分支,直到图被更新完成,如下图第五阶段。





  Expr Transform(const Expr& body, int fuse_opt_level, size_t max_fuse_depth, bool link_params) {
    // setup the group map.
    auto graph = IndexedForwardGraphCreator::Create(&arena_, body);
    auto groups = GraphPartitioner(&arena_, fuse_opt_level, max_fuse_depth, max_function_args_)


std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
    const IndexedForwardGraph& graph) {
  if (opt_level_ == 0) return std::move(groups_);
  // get post dominator tree
  auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
  // run fusion algorithm.


  auto post_dom_tree = DominatorTree::PostDom(arena_, graph);


 * \brief Dominator tree that represent domination or
 *  post domination relation of the node.
class DominatorTree {
   * \brief A node in the dominator tree.
  struct Node {
    /*! \brief The node in the tree */
    IndexedForwardGraph::Node* gnode{nullptr};
    /*! \brief parent of the tree */
    Node* parent{nullptr};
    /*! \brief current depth*/
    int depth{0};
    /*! \brief aggregated pattern to parent */
    OpPatternKind pattern{kOpaque};
  // index -> node.
  std::vector<Node*> nodes;

此处定义的支配树包括了index到节点的映射,节点包括以下字段,填充这些数据结构即完成了Graph -> DominatorTree数据结构的转换

  • gnode:对应的Graph节点
  • parent:支配树中的父节点,即该节点的直接后支配点
  • depth:节点在支配树中的深度,用于计算LCA
  • pattern:从该节点到其parent的各条路径上,所有边的pattern合并后的结果


/*!
 * \brief Build the post-dominator tree of an indexed forward graph.
 *
 * Nodes are processed in reverse post-DFS order so that the tree nodes of all
 * outputs of a node already exist when GetNode computes their common ancestor.
 */
DominatorTree DominatorTree::PostDom(support::Arena* arena, const IndexedForwardGraph& graph) {
  DominatorTree tree;
  tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
  // reverse topo order
  for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
    size_t index = i - 1;
    tree.nodes[index] = tree.GetNode(arena, graph.post_dfs_order[index]);
  }
  return tree;
}


/*!
 * \brief Create the post-dominator tree node for a graph node.
 *
 * Externally referenced nodes become roots (no parent, depth 1, kOpaque). Otherwise
 * the parent is the least common ancestor of all the node's outputs, and pattern is
 * the combination of the edge/tree patterns met while computing it.
 */
DominatorTree::Node* DominatorTree::GetNode(support::Arena* arena,
                                            IndexedForwardGraph::Node* gnode) {
  Node* tnode = arena->make<Node>();
  tnode->gnode = gnode;
  if (gnode->extern_ref) {
    tnode->depth = 1;
    tnode->parent = nullptr;
    tnode->pattern = kOpaque;
  } else {
    // find the LCAs of all outputs.
    OpPatternKind pattern = kElemWise;
    Node* parent = LeastCommonAncestor(gnode->outputs, &pattern);
    tnode->depth = parent ? parent->depth + 1 : 1;
    tnode->parent = parent;
    tnode->pattern = pattern;
  }
  return tnode;
}


  // Combine pattern together.
  inline static OpPatternKind CombinePattern(OpPatternKind lhs, OpPatternKind rhs) {
    // Keep the larger enum value, i.e. the more restrictive of the two patterns.
    if (lhs > rhs) return lhs;
    return rhs;
  }


/*!
 * \brief Least common ancestor, in the dominator tree, of all targets of the given edges.
 * \param input_nodes Output edges of a graph node.
 * \param edge_pattern In/out: combined with every edge pattern and every tree-node
 *        pattern visited while climbing.
 * \return The LCA, or nullptr if the list is empty or the targets share no ancestor.
 */
DominatorTree::Node* DominatorTree::LeastCommonAncestor(
    const LinkedList<IndexedForwardGraph::Edge>& input_nodes, OpPatternKind* edge_pattern) {
  auto link = input_nodes.head;
  if (link == nullptr) {
    return nullptr;
  }
  // Map an edge to the already-built tree node of its target.
  auto get_node = [&](const IndexedForwardGraph::Edge& edge) {
    size_t oindex = edge.node->index;
    ICHECK_LT(oindex, nodes.size());
    Node* onode = nodes[oindex];
    ICHECK(onode != nullptr);
    return onode;
  };
  Node* parent = get_node(link->value);
  *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
  link = link->next;
  // Fold the remaining edges: LCA(current result, next target).
  for (; link != nullptr; link = link->next) {
    parent = LeastCommonAncestor(parent, get_node(link->value), edge_pattern);
    *edge_pattern = CombinePattern(*edge_pattern, link->value.pattern);
  }
  return parent;
}


  • 节点2有两条输出边,第一条边指向节点4,第二条边指向节点7。首先处理第一条边,处理完成后parent为节点4,edge_pattern为kElemWise;
  • 接下来处理第二条边,进入以下代码逻辑:根据depth信息找到两个节点的最近公共祖先(LCA),在此过程中不断合并更新edge_pattern;
  • 处理完成后,parent为节点8,edge_pattern为kElemWise。
/*!
 * \brief Least common ancestor of two dominator-tree nodes.
 *
 * Moves the deeper node up (or both when equally deep) until they meet, combining
 * the pattern of every node stepped over into edge_pattern.
 * \return The common ancestor, or nullptr if either walk runs past a root.
 */
DominatorTree::Node* DominatorTree::LeastCommonAncestor(Node* lhs, Node* rhs,
                                                        OpPatternKind* edge_pattern) {
  while (lhs != rhs) {
    if (lhs == nullptr) return nullptr;
    if (rhs == nullptr) return nullptr;
    if (lhs->depth < rhs->depth) {
      edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
      rhs = rhs->parent;
    } else if (rhs->depth < lhs->depth) {
      edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
      lhs = lhs->parent;
    } else {
      edge_pattern[0] = CombinePattern(edge_pattern[0], lhs->pattern);
      edge_pattern[0] = CombinePattern(edge_pattern[0], rhs->pattern);
      lhs = lhs->parent;
      rhs = rhs->parent;
    }
  }
  return lhs;
}




std::vector<GraphPartitioner::Group*> GraphPartitioner::Partition(
    const IndexedForwardGraph& graph) {
  // run fusion algorithm.
  for (int phase = 0; phase < 3; ++phase) {
    this->RunFuse(graph, post_dom_tree, phase);
  return std::move(groups_);


 * \brief A partition of the graph marked by union find data structure.
class GraphPartitioner {
  explicit GraphPartitioner(support::Arena* arena, int opt_level, size_t max_fuse_depth,
                            size_t max_function_args)
      : arena_(arena),
        max_function_args_(max_function_args) {}
   * \brief Group as a union find data structure.
  struct Group {
    /*! \brief The parent in the union find data structure. */
    Group* parent{nullptr};
    /*! \brief The pattern of the group */
    OpPatternKind pattern;
    /*! \brief reference to the root node. */
    const tvm::Object* root_ref{nullptr};
     * \brief Reference to the anchor node,
     * this field is not nullptr only if pattern is kOutEWiseFusable.
    const tvm::Object* anchor_ref{nullptr};
     * \brief The number of nodes belonging to this group
    uint32_t num_nodes{1};
     * \brief The number of function arguments belonging to this group
    size_t args_num{0};

    /*! \brief Optional attributes to annotate the grouped function. */
    runtime::Map<runtime::String, ObjectRef> attrs;
     * \brief Find the group root, perform path compression
     * \return The root type node.
    Group* FindRoot();
   * \brief Partition a graph.
   * \return group assignments of each node.
  std::vector<Group*> Partition(const IndexedForwardGraph& graph);

  /*! \brief The internal arena for temporary space. */
  support::Arena* arena_;
  /*! \brief optimization level for fuse operation. */
  int opt_level_;
  /*! \brief The maximum number of operations in one fused function */
  size_t max_fuse_depth_;
  /*! \brief The maximum number of arguments in one fused function */
  size_t max_function_args_;
  /*! \brief The internal groups. */
  std::vector<Group*> groups_;
  /*! \brief internal field used for deduplication */
  std::unordered_set<IndexedForwardGraph::Node*> visited_;
  /*! \brief The map with nodes which were postponed for fusing. */
  std::unordered_multimap<const IndexedForwardGraph::Node*, IndexedForwardGraph::Node*>

Group是一个并查集(union-find)数据结构,可以快速判断两个节点是否属于同一个分组,FindRoot 会进行路径压缩以保证查找效率;


void GraphPartitioner::InitGroups(const IndexedForwardGraph& graph) {
  auto args_counter = [this](const tvm::Object* obj) {
    size_t args_num = 0;
    if (auto call_node = GetRef<ObjectRef>(obj).as<CallNode>()) {
      for (auto& it : call_node->args) {
        if (<VarNode>() ||<TupleGetItemNode>()) {
          if (const auto* ttype =<ExprNode>()->checked_type().as<TensorTypeNode>()) {
            args_num += CountAdditionalArgs_(ttype);
    } else if (auto tuple_node = GetRef<ObjectRef>(obj).as<TupleNode>()) {
      for (auto& it : tuple_node->fields) {
        if (<VarNode>() ||<TupleGetItemNode>()) {
          if (const auto* ttype =<ExprNode>()->checked_type().as<TensorTypeNode>()) {
            args_num += CountAdditionalArgs_(ttype);
    } else if (GetRef<ObjectRef>(obj).as<VarNode>()) {
      if (const auto* ttype =
              GetRef<ObjectRef>(obj).as<ExprNode>()->checked_type().as<TensorTypeNode>()) {
        args_num += CountAdditionalArgs_(ttype);
    return args_num;

  for (size_t nid = 0; nid < groups_.size(); ++nid) {
    const auto* graph_node = graph.post_dfs_order[nid];
    auto* group_node = arena_->make<Group>();
    group_node->pattern = graph_node->pattern;
    group_node->root_ref = graph_node->ref;
    // set anchor ref if necessary.
    if (group_node->pattern == relay::kOutEWiseFusable) {
      group_node->anchor_ref = graph_node->ref;
    group_node->args_num = args_counter(graph_node->ref);
    groups_[nid] = group_node;


void RunFuse(const IndexedForwardGraph& graph, const DominatorTree& post_dom_tree, int phase) {
  for (size_t nid = 0; nid < groups_.size(); ++nid) {
    // 取得graph_node, dom_node和group_node;
    // the group of current node has been specified already.
    auto* graph_node = graph.post_dfs_order[nid];
    auto* dom_node = post_dom_tree.nodes[nid];
    Group* group_node = groups_[nid];
    ICHECK(group_node != nullptr);

    // 遇到不可融合算子kOpaque,直接返回
    if (group_node->pattern == kOpaque) continue;

    // 没有支配点信息的算子直接返回
    if (dom_node->parent == nullptr) continue;

    // 获取该节点后支配点graph索引
    size_t dom_parent_gindex = dom_node->parent->gnode->index;

    // 此处先省略不看
    // refuse the fusion if too many ops are going to be fused together
    if (CountFusedNodesWithNewChild(graph_node, dom_node->parent->gnode) > max_fuse_depth_)

    // 第三阶段处理逻辑(见下文)

    // 当前节点已和其后支配点融合,则跳过
    // Skip if current node is already fused to the parent.
    if (groups_[dom_parent_gindex] != nullptr &&
        group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {

    // 跳过tuple相关操作
    // Do not fuse into tuple for now
    if (groups_[dom_parent_gindex]->pattern == kTuple) continue;

    // 第一阶段处理kOutEltwiseFusable,见下文
    if (group_node->pattern == kOutEWiseFusable) {
    // 每一阶段都会对 kEltwise 或 kBroadcast 处理,见下文
    else if (group_node->pattern <= kBroadcast) {
    // 第二阶段处理 kInjective 或 kTuple,见下文
    else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
    // kCommReduce相关逻辑
    else {
      // do nothing.
      ICHECK(group_node->pattern == kCommReduce);


template <typename F>
bool GraphPartitioner::CheckPath_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                  F fcond) {
  if (visited_.count(src)) return true;
  Group* gnode = groups_[src->index];
  ICHECK(gnode != nullptr);
  gnode = gnode->FindRoot();
  if (!fcond(gnode->pattern, src == sink)) return false;
  if (src == sink) return true;
  for (auto link = src->outputs.head; link != nullptr; link = link->next) {
    if (!CheckPath_(link->value.node, sink, fcond)) return false;
  return true;

template <typename F>
bool GraphPartitioner::CheckPath(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                 F fcond) {
  ICHECK(src != sink);
  for (auto link = src->outputs.head; link != nullptr; link = link->next) {
    if (!CheckPath_(link->value.node, sink, fcond)) return false;
  return true;


void GraphPartitioner::MergeFromTo(Group* child, Group* parent) {
  child = child->FindRoot();
  parent = parent->FindRoot();
  if (child == parent) return;
  // update the number of nodes of the parent group
  parent->num_nodes += child->num_nodes;
  parent->args_num += child->args_num;
  child->parent = parent;
  // update anchor ref and pattern
  if (child->anchor_ref != nullptr) {
    ICHECK(parent->anchor_ref == nullptr);
    parent->anchor_ref = child->anchor_ref;
    parent->pattern = CombinePattern(child->pattern, parent->pattern);

void GraphPartitioner::CommitFuse_(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink,
                                   Group* target) {
  if (postpone_node_ != nullptr) {
    postponed_fusing_map_.insert({postpone_node_, src});
  if (src == sink) return;
  if (visited_.count(src)) return;
  Group* gnode = groups_[src->index];
  ICHECK(gnode != nullptr);
  // merge the current group to the parent if possible.
  MergeFromTo(gnode, target);
  for (auto link = src->outputs.head; link != nullptr; link = link->next) {
    CommitFuse_(link->value.node, sink, target);

void GraphPartitioner::CommitFuse(IndexedForwardGraph::Node* src, IndexedForwardGraph::Node* sink) {
  Group* target = groups_[sink->index];
  ICHECK(src != sink);
  CommitFuse_(src, sink, target);


else if (group_node->pattern <= kBroadcast) {
      // Pre-condition: can only be fused to parent which is injective or reduction.
      if (dom_node->parent != nullptr &&
          (dom_node->pattern <= kInjective || dom_node->pattern == kCommReduce)) {
        // Check if all the intermediate ops are still broadcast.
        // The final terminal node can already be fused to a OutEWiseFusable group.
        auto fcond = [](OpPatternKind kind, bool is_sink) {
          if (!is_sink) {
            // Elemwise, broadcast, and injective ops on the parallel branches
            // are allowed be fused to the elemwise/broadcast anchor.
            return kind <= kInjective;
          } else {
            return (kind <= kBroadcast || kind == kCommReduce || kind == kInjective ||
                    kind == kOutEWiseFusable);
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);


// Try to fuse current node to its post-dominator.
    if (group_node->pattern == kOutEWiseFusable) {
      if (phase != 0) continue;
      // Path for OutEWiseFusable: conv2d
      // Check if the dominator relation is elemwise.
      if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
        ICHECK(dom_node->parent->gnode != nullptr);
        // The fuse can be executed if all the intermediate ops are still broadcast.
        auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kBroadcast; };
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);



else if (group_node->pattern == kInjective || group_node->pattern == kTuple) {
      // defer injective fusion to second phase.
      // so conv2d always finishes fusing.
      if (phase != 1) continue;
      // Check if all path are injective.
      auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
      if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
        CommitFuse(graph_node, dom_node->parent->gnode);

当当前节点为kInjective或kTuple,且到后支配点路径上的所有节点均满足pattern <= kInjective时,才可以融合;


  if (phase == 2) {
      // Fuse injective ops into intermediate tuples, if any
      if (group_node->pattern > relay::kInjective) continue;
      Group* dom_parent_group = groups_[dom_parent_gindex];
      Group* dom_root_group = dom_parent_group->FindRoot();
      // If dom node group has a tuple as its root, we do not fuse tuple fields into it
      if (dom_root_group->pattern == relay::kTuple) continue;
      if (dom_parent_group->pattern == kTuple && dom_root_group->pattern <= relay::kInjective) {
        // Now we know the tuple has been fused into subsequent injective ops
        auto fcond = [](OpPatternKind kind, bool is_sink) { return kind <= kInjective; };
        // dom_root_group can also be tuple, as in inception layers
        // CheckPath is needed to avoid fusing two intermediate tuples
        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
          CommitFuse(graph_node, dom_node->parent->gnode);

其实经过第一阶段的处理,我们的示例已经被完全融合了: 1. 0号节点和1号节点没有parent信息,跳过; 2. 处理2号节点时,其后支配点是8,依次遍历了4、5、8、7号节点(如下图绿色虚线部分),均满足fcond条件,进行了融合,此时2,4,5,7节点的parent均为8;8号节点的num_nodes为5; 3. 当遍历到3号节点(非标量常量,pattern为kOpaque)时,因其pattern为kOpaque而直接跳过,不参与融合; 4. 当遍历到6号节点时,其后支配点是7,满足fcond条件,进行融合,6的parent被设为7节点的parent即8,8号节点的num_nodes此时为6; 5. 当遍历其余节点时,均已被fuse,直接返回;



def @main(%x: Tensor[(1, 3, 16, 16), float32] /* ty=Tensor[(1, 3, 16, 16), float32] */, %weight: Tensor[(3, 3, 3, 3), float32] /* ty=Tensor[(3, 3, 3, 3), float32] */) -> Tensor[(1, 3, 14, 14), float32] {
  %4 = fn (%p0: Tensor[(1, 3, 16, 16), float32] /* ty=Tensor[(1, 3, 16, 16), float32] */, %p1: Tensor[(3, 3, 3, 3), float32] /* ty=Tensor[(3, 3, 3, 3), float32] */, %p2: Tensor[(1, 3, 14, 14), float32] /* ty=Tensor[(1, 3, 14, 14), float32] */, Primitive=1) -> Tensor[(1, 3, 14, 14), float32] {
    %0 = nn.conv2d(%p0, %p1, padding=[0, 0, 0, 0]) /* ty=Tensor[(1, 3, 14, 14), float32] */;
    %1 = add(%0, %p2) /* ty=Tensor[(1, 3, 14, 14), float32] */;
    %2 = nn.relu(%1) /* ty=Tensor[(1, 3, 14, 14), float32] */;
    %3 = multiply(%0, 0.5f /* ty=float32 */) /* ty=Tensor[(1, 3, 14, 14), float32] */;
    add(%2, %3) /* ty=Tensor[(1, 3, 14, 14), float32] */
  } /* ty=fn (Tensor[(1, 3, 16, 16), float32], Tensor[(3, 3, 3, 3), float32], Tensor[(1, 3, 14, 14), float32]) -> Tensor[(1, 3, 14, 14), float32] */;
  %4(%x, %weight, meta[relay.Constant][0] /* ty=Tensor[(1, 3, 14, 14), float32] */) /* ty=Tensor[(1, 3, 14, 14), float32] */

FuseMutator继承自MixedModeMutator,并重写了FunctionNode、CallNode等节点的遍历方式;
MixedModeMutator对dataflow节点(如CallNode、TupleNode等)采用后序拓扑(post-order)的方式进行遍历;

class FuseMutator : private MixedModeMutator {
  int fuse_opt_level_;
  size_t max_fuse_depth_;
  bool link_params_;
  /*! \brief The group assignment map. */
  std::unordered_map<const Object*, GraphPartitioner::Group*> gmap_;
  /* \brief Internal group information map. */
  std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_;

首先看FunctionNode的处理方式:带有Primitive属性的函数直接跳过;否则进入父类ExprMutator的处理逻辑,即依次处理参数(params)和函数体(body);

// Skip primitive function: a function already carrying the Primitive attribute
// (e.g. one produced by fusion) is returned unchanged; others are traversed by ExprMutator.
Expr VisitExpr_(const FunctionNode* fn_node) {
  if (fn_node->HasNonzeroAttr(attr::kPrimitive)) {
    return GetRef<Expr>(fn_node);
  } else {
    return ExprMutator::VisitExpr_(fn_node);
  }
}

对于参数的处理,依次处理%x和%weight,处理结果存储在成员变量std::unordered_map<Expr, Expr> memo_中;接下来处理body,按照后序拓扑(post-order)的顺序依次遍历 CallNode(conv2d, ...) -> ConstantNode -> CallNode(add, ...) -> CallNode(relu, ...) -> ConstantNode -> CallNode(multiply, ...) -> CallNode(add, ...);

对ConstantNode的处理直接继承自父类ExprMutator::VisitExpr_(const ConstantNode* op),在memo_中存储一份引用;

对CallNode的处理是核心如下: 1. 找到当前Call节点所属Group; 2. 构造输入参数:当输入参数所属Group不同于当前Group,则创建形参和实参; 3. 构造对应的CallNode节点; 4. 如果当前节点不是Group的root->ref节点,则直接返回,否则根据GroupInfo中存储的形参和实参构造一个新的FunctionNode和CallNode。


// Stores the formal parameters and actual arguments of each group
// (params[i] is bound to arguments[i] when the fused function is called).
/*! \brief Temporary information from each group. */
struct GroupInfo {
  // The parameters of the function.
  Array<Var> params;
  // The arguments to call the functions.
  Array<Expr> arguments;
  /*!
   * \brief Return the parameter already bound to expr, or allocate a new one named "p<k>".
   * \param expr The actual argument expression.
   * \param type The type given to a newly created parameter.
   */
  Var GetOrAllocParam(const Expr& expr, const Type& type) {
    // run linear scan as most fused groups contain only a few inputs.
    for (size_t i = 0; i < arguments.size(); ++i) {
      if (expr.same_as(arguments[i])) return params[i];
    }
    // create a new parameter.
    std::ostringstream os;
    os << "p" << params.size();
    auto var = Var(os.str(), type);
    // Record the binding; otherwise later lookups miss and every param is named p0.
    params.push_back(var);
    arguments.push_back(expr);
    return var;
  }
};

Array<Expr> GetNewArguments(const tvm::Array<Expr>& args,
                            GraphPartitioner::Group* current_group) {
  Array<Expr> new_args;
  for (auto arg : args) {
    auto* arg_group =>FindRoot();
    auto type = arg->checked_type();
    Expr new_arg = this->Mutate(arg);
    if (current_group != arg_group) {
      if (!link_params_ ||<ConstantNode>() == nullptr) {
        Var param = ginfo_[current_group].GetOrAllocParam(new_arg, type);
      } else {
    } else {
  return new_args;

Expr MakeNewFunction(GraphPartitioner::Group* group, Type ret_type, Expr body) {
  const GroupInfo& ginfo = ginfo_[group];
  auto func = Function(ginfo.params, body, ret_type, {});
  func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(visitor.has_call));
  return Call(func, ginfo.arguments, Attrs());

// Transform calls.
Expr Rewrite_(const CallNode* call, const Expr& post) {
  if (call-><OpNode>()) {
    // 找到其所属group
    auto* ret_group =>FindRoot();
    // 构造输入args列表
    Array<Expr> new_args = GetNewArguments(call->args, ret_group);
    // 构造CallNode节点
    auto new_call = Call(call->op, new_args, call->attrs, call->type_args, call->span);

    if (ret_group->root_ref == call) {
      // This is the root of the group
      // create the new call node.
      // 构造FunctionNode节点和对应的CallNode节点
      return MakeNewFunction(ret_group, call->checked_type(), new_call);
    } else {
      // This is an intermediate node of a fused function
      // simply return the new call.
      return std::move(new_call);
  } else {
    return ExprMutator::VisitExpr_(call);


