将编译器pass添加到Relay

将编译器pass添加到Relay

编译器pass是扩展Relay功能集和对Relay程序执行优化的主要接口。通过编写编译器pass,可以修改AST或收集有关AST的信息,具体取决于目标。事实上,Relay的一些最重要的内置功能(如autodiff和类型推断),只不过是“标准”编译器pass。

在高层次上,写pass有两个关键组成部分:

创建一个或多个遍历程序的C++类

将遍历实现及元数据包装在pass manager API中,以便可以与pass基础结构完整交互。

首先,将概述编写编译器pass的关键机制。然后,将介绍一个Relay中常量折叠pass的具体示例。

AST遍历器

用于遍历Relay程序的基类是ExprFunctor。提供的公共接口是一个VisitExpr方法,接受一个表达式和零个或多个参数,返回某种类型的实例。扩展此类时,可以通过为每种类型的表达式重写VisitExpr_ f的实现,定义AST遍历模式。

VisitExpr和VisitExpr_间的关系与调度有关。每个VisitExpr_定义都针对特定类型的表达式,但不总是知道要访问的节点类型。为了解决这个问题,ExprFunctor提供了一个VisitExpr函数,该函数从给定的表达式路由到处理VisitExpr_案例。尽管C++已经提供了动态调度,ExpPrimor还是定义了VisteExPR使用的VTe表。通过定义vtable,可以更好地控制调度。例如,如果想定义一个PrintVisitor遍历器,在每次访问前打印“Here”,可以覆盖VisitExpr:

void PrintVisitor::VisitExpr(const Expr& expr) {

  std::cout << "Here" << std::endl;

  ExprFunctor::VisitExpr(expr);

}

ExprFunctor本身是一个非常通用的类,这就是为什么经常会扩展ExprVisitor或ExprMutator。这些类扩展了ExprFunctor,提供了VisitExpr_的默认实现,该实现获取每种表达式类型的公共遍历模式。拥有这些默认实现,不同行为的表达式类型提供覆盖实现。在下面的介绍中,将单独描述每个子类。

ExprVisitor

ExprVisitor用于不修改程序,执行程序分析和收集信息的过程。在这个类中,VisitExpr和私有对应项不返回任何内容。此类提供的VisitExpr_实现,只需访问表达式的所有字段即可。IfNode的默认实现如下所示。

void ExprVisitor::VisitExpr_(const IfNode* op) {
  this->VisitExpr(op->cond);
  this->VisitExpr(op->true_branch);
  this->VisitExpr(op->false_branch);
}

在这里调用的是VisitExpr,不是VisitExpr,可以使用vtable in ExprFunctor进行路由。

现在,如果想编写一个类调用检查器,检查程序中是否出现任何函数调用,只需要扩展ExprVisitor,定义以下VisitExpr_方法:

void VisitExpr_(const CallNode* n) final {
  result_ = true;
}

其中result_是一个字段。在这种情况下,不需要在CallNode的字段上进一步递归,因为result_已经为true,原始表达式包含一个调用。为了使visitor可用,将提供以下公共方法:

bool Check(const Expr& expr) final {
  result_ = false;
  VisitExpr(expr);
  return result_;
}

这就是所需要的。在调用顶级递归前,定义一个公共接口,执行一些bookkeeping记录是非常常见的。当然,可以通过创建一个独立的pass,进一步包装API,该pass创建一个CallChecker实例调用Check,只花了很少的努力就实现了目标。

Expression Mutators

ExprMutator用于以某种方式转换程序的pass。使用该类,VisitExpr及私有对应项返回Expr。此类提供的默认VisitExpr_,实现访问表达式的所有字段,这些字段都是表达式,将这些字段设置为访问结果。TupleGetItemNode的默认实现如下所示。

Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
  auto t = this->Mutate(g->tuple);
  if (g->tuple == t) {
    return GetRef<Expr>(g);
  } else {
    return TupleGetItem(t, g->index);
  }
}

这里有几件事需要注意。首先,Mutate是ExprMutator中VisitExpr的别名。其次,如果Mutate调用修改了tuple字段,只返回一个新节点。这种更新方法称为功能更新,这样做可以避免不必要的分配。

ExprMutator的一个特性是ExprVisitor没有的,一个用于缓存结果的内置备注字段。ExprMutator有一个memoizer是有道理的,知道正在缓存哪些类型的结果(即Expr),ExprVisitor的访问方法不返回任何内容。通常,当想要将结果缓存在ExprVisitor的子类中时,需要定义缓存。

现在,如果想编写一个类IfCollapser,用真正分支替换每个if语句,将覆盖IfNode的VisitExpr_:

Expr ExprMutator::VisitExpr_(const IfNode* op) {
  return this->Mutate(op->true_branch);
}

返回的表达式不一定是IfNode,因为返回类型是Expr。现在,创建公共接口:

Expr CollapseIfs(const Expr& expr) final {
  return this->Mutate(expr);
}

有了这个mutator,不需要做任何记录,但仍然希望遵循使用描述性方法,作为接口的惯例。

示例:常量折叠

为了更好地理解编写pass,将以常量折叠pass(见src/relay/transforms/fold_constant.cc)为指导,因为是一个相对简单的过程,包含了两种类型的遍历。

常量折叠涉及计算程序中,只涉及常量值的表达式,然后用计算结果替换这些表达式。此pass的目标是预先加载所有可以进行的计算。为了实现这一点,常量折叠pass使用访客(ConstantChecker)和变异子(ConstantFolder)。

ConstantChecker Visitor

此访问者用于检查表达式是否为常量。在Relay中,如果表达式是常量节点或只有常量字段的元组节点,将定义为常量。

使用一个memo_字段,从节点映射是否为常量,缓存这些结果。以下是ConstantChecker中的VisitExpr_定义。

void VisitExpr_(const ConstantNode* n) final {
  memo_[GetRef<Constant>(n)] = true;
}
 
void VisitExpr_(const TupleNode* n) final {
  bool result = true;
  for (const auto& field : n->fields) {
    if (!Check(field)) {
      result = false;
      break;
    }
  }
  memo_[GetRef<Tuple>(n)] = result;
}

用于协调这些定义的记录是一个检查方法,返回给定表达式是否被视为常量。

bool Check(const Expr& expr) {
  const auto it = memo_.find(expr);
  if (it != memo_.end())
    return it->second;
  VisitExpr(expr);
  return memo_[expr];
}

不会为遇到的每个节点修改memo_;相反,只在遇到的节点可能是常量时修改memo_。然后,当memo_不包含expr时,依赖于默认值为false。

ConstantFolder Mutator常量折叠变异体

该mutator变异器执行大部分常量折叠pass,在内部使用ConstantChecker。在Relay中,常量折叠涉及三种节点类型:LetNode、TupleItemGetNode和CallNode。在下面的段落中,将解释pass中每个角色的作用。

Expr VisitExpr_(const LetNode* op) final {
  Expr value = this->Mutate(op->value);
  if (value.as<ConstantNode>()) {
    memo_[op->var] = value;
    return this->Mutate(op->body);
  } else {
    Var var = Downcast<Var>(this->Mutate(op->var));
    Expr body = this->Mutate(op->body);
    if (var.same_as(op->var) &&
        value.same_as(op->value) &&
        body.same_as(op->body)) {
      return GetRef<Expr>(op);
    } else {
      return Let(var, value, body);
    }
  }
}

在LetNode的情况下,首先尝试对表达式中绑定的值进行常量折叠。填充memo_,返回访问主体的结果,将绑定值传播到主体中的使用站点。如果不能将绑定值常量化,将模拟默认实现。

Expr VisitExpr_(const TupleGetItemNode* op) final {
  Expr res = ExprMutator::VisitExpr_(op);
  op = res.as<TupleGetItemNode>();
  if (const auto* tuple = op->tuple.as<TupleNode>()) {
    return tuple->fields[op->index];
  } else {
    return res;
  }
}

在TupleItemGetNode的情况下,检查op->tuple字段是否是TupleNode。用op->index指向的元组字段替换元组get。需要检查的原因是op->tuple可能计算为一个tuple,本身不是tuple。

Expr VisitExpr_(const CallNode* call) final {
  static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
  Expr res = ExprMutator::VisitExpr_(call);
  call = res.as<CallNode>();
  // We don't constant fold function with zero arguments.
  // This is a heuristic that is useful.
  // For example it is harmful to fold ones(shape=(4, 5)).
  if (call->args.size() == 0) return res;
  const OpNode* op = call->op.as<OpNode>();
  if (op == nullptr) return res;
  // skip stateful ops.
  if (op_stateful.get(GetRef<Op>(op), false)) return res;
  bool all_const_args = true;
  for (Expr arg : call->args) {
    if (!checker_.Check(arg)) {
      all_const_args = false;
    }
  }
  if (all_const_args) {
    return ConstEvaluate(res);
  } else {
    return res;
  }
}

在CallNode的情况下,首先使用ExprMutator的VisitExpr_访问调用,将调用的所有字段折叠起来。使用ExprMutator::VisitExpr_uu而不是VisitExpr,因为希望绕过vtable(避免无限循环),使用ExprMutator提供的默认实现。然后,仅在所有参数都是常量时(使用ConstantChecker)计算调用。对调用求值会产生一个值,因此使用help方法ValueToExpr,将求值表达式放回AST中。

现在,为常量文件夹构造一个更方便的接口FoldConstant。FoldConstant是ConstantFolder类外的一个独立函数,接受一个表达式,在内部创建和使用ConstantFolder实例(完整定义可在src/relay/transforms/fold_constant.cc中找到)。

向pass管理器注册pass

参阅:ref:`pass infra`上的文档,了解有关此主题的更多详细信息。

编写AST遍历器后,可以使用以下代码,将pass注册为TVM API端点:

namespace transform {
 
Pass FoldConstant() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
    [=](Function f, Module m, PassContext pc) {
      return Downcast<Function>(FoldConstant(f));
  };
  return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
 
}  // namespace transform

如果将上述代码生成的Pass对象,提供给Pass基础设施,将确保将AST遍历应用于给定Relay模块中的每个函数,这是常量折叠Pass的预期行为(它应尽可能折叠所有常数)。

函数CreateFunctionPass允许注册pass的优化级别(在本例中为2),该级别可用于根据pass的通用工具、pass名称以及pass的任何依赖项,将pas分组。pass的依赖项以任何pass的列表的形式给出,这些pass的结果是运行当前pass所必需的。FoldConstant没有任何依赖项,但许多Relay pass确实依赖于类型信息,因此InferType是一个常见的依赖项;另一些可能依赖于程序,通过ToANormalForm pass处于A-normal形式。

注意,PassContext对象包含pass用于错误报告和配置选项的信息;FoldConstant不需要此信息,但其它pass可能会引用PassContext对象。

现在可以通过pass基础设施调用pass,不过最好为pass添加一个Python绑定,如下面的代码片段所示:

TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);

一旦以上述方式定义了Pass对象,就可以使用Pass基础设施的顺序构造调用,该构造获取一个Pass列表,按顺序应用于Relay模块,从而获得转换后的模块。例如,下面的代码将FoldConstant和ToANormalForm pass(一个接一个),应用于mod中的每个函数,获得一个新模块。

seq = transform.Sequential([
    relay.transform.FoldConstant(),
    relay.transform.ToANormalForm()
])
new_mod = seq(mod)

有关注册的更多详细信息,可以在TVM Runtime系统中找到,有关pass manager接口的更多信息可以在pass基础设施中找到。Relay的标准pass在include/tvm/Relay/transform.h中列出,在src/Relay/transforms/中实现。

 

参考链接:

https://tvm.apache.org/docs/dev/how_to/relay_add_pass.html

 

posted @ 2021-11-20 06:05  吴建明wujianming  阅读(179)  评论(0编辑  收藏  举报