将编译器pass添加到Relay
将编译器pass添加到Relay
编译器pass是扩展Relay功能集和对Relay程序执行优化的主要接口。通过编写编译器pass,可以修改AST或收集有关AST的信息,具体取决于目标。事实上,Relay的一些最重要的内置功能(如autodiff和类型推断),只不过是“标准”编译器pass。
在高层次上,写pass有两个关键组成部分:
创建一个或多个遍历程序的C++类
将遍历实现及元数据包装在pass manager API中,以便可以与pass基础结构完整交互。
首先,将概述编写编译器pass的关键机制。然后,将介绍一个Relay中常量折叠pass的具体示例。
AST遍历器
用于遍历Relay程序的基类是ExprFunctor。提供的公共接口是一个VisitExpr方法,接受一个表达式和零个或多个参数,返回某种类型的实例。扩展此类时,可以通过为每种类型的表达式重写VisitExpr_ f的实现,定义AST遍历模式。
VisitExpr和VisitExpr_间的关系与调度有关。每个VisitExpr_定义都针对特定类型的表达式,但不总是知道要访问的节点类型。为了解决这个问题,ExprFunctor提供了一个VisitExpr函数,该函数从给定的表达式路由到处理VisitExpr_案例。尽管C++已经提供了动态调度,ExpPrimor还是定义了VisteExPR使用的VTe表。通过定义vtable,可以更好地控制调度。例如,如果想定义一个PrintVisitor遍历器,在每次访问前打印“Here”,可以覆盖VisitExpr:
void PrintVisitor::VisitExpr(const Expr& expr) {
std::cout << "Here" << std::endl;
ExprFunctor::VisitExpr(expr);
}
ExprFunctor本身是一个非常通用的类,这就是为什么经常会扩展ExprVisitor或ExprMutator。这些类扩展了ExprFunctor,提供了VisitExpr_的默认实现,该实现获取每种表达式类型的公共遍历模式。拥有这些默认实现,不同行为的表达式类型提供覆盖实现。在下面的介绍中,将单独描述每个子类。
ExprVisitor
ExprVisitor用于不修改程序,执行程序分析和收集信息的过程。在这个类中,VisitExpr和私有对应项不返回任何内容。此类提供的VisitExpr_实现,只需访问表达式的所有字段即可。IfNode的默认实现如下所示。
void ExprVisitor::VisitExpr_(const IfNode* op) {
this->VisitExpr(op->cond);
this->VisitExpr(op->true_branch);
this->VisitExpr(op->false_branch);
}
在这里调用的是VisitExpr,不是VisitExpr,可以使用vtable in ExprFunctor进行路由。
现在,如果想编写一个类调用检查器,检查程序中是否出现任何函数调用,只需要扩展ExprVisitor,定义以下VisitExpr_方法:
void VisitExpr_(const CallNode* n) final {
result_ = true;
}
其中result_是一个字段。在这种情况下,不需要在CallNode的字段上进一步递归,因为result_已经为true,原始表达式包含一个调用。为了使visitor可用,将提供以下公共方法:
bool Check(const Expr& expr) final {
result_ = false;
VisitExpr(expr);
return result_;
}
这就是所需要的。在调用顶级递归前,定义一个公共接口,执行一些bookkeeping记录是非常常见的。当然,可以通过创建一个独立的pass,进一步包装API,该pass创建一个CallChecker实例调用Check,只花了很少的努力就实现了目标。
Expression Mutators
ExprMutator用于以某种方式转换程序的pass。使用该类,VisitExpr及私有对应项返回Expr。此类提供的默认VisitExpr_,实现访问表达式的所有字段,这些字段都是表达式,将这些字段设置为访问结果。TupleGetItemNode的默认实现如下所示。
Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
auto t = this->Mutate(g->tuple);
if (g->tuple == t) {
return GetRef<Expr>(g);
} else {
return TupleGetItem(t, g->index);
}
}
这里有几件事需要注意。首先,Mutate是ExprMutator中VisitExpr的别名。其次,如果Mutate调用修改了tuple字段,只返回一个新节点。这种更新方法称为功能更新,这样做可以避免不必要的分配。
ExprMutator的一个特性是ExprVisitor没有的,一个用于缓存结果的内置备注字段。ExprMutator有一个memoizer是有道理的,知道正在缓存哪些类型的结果(即Expr),ExprVisitor的访问方法不返回任何内容。通常,当想要将结果缓存在ExprVisitor的子类中时,需要定义缓存。
现在,如果想编写一个类IfCollapser,用真正分支替换每个if语句,将覆盖IfNode的VisitExpr_:
Expr ExprMutator::VisitExpr_(const IfNode* op) {
return this->Mutate(op->true_branch);
}
返回的表达式不一定是IfNode,因为返回类型是Expr。现在,创建公共接口:
Expr CollapseIfs(const Expr& expr) final {
return this->Mutate(expr);
}
有了这个mutator,不需要做任何记录,但仍然希望遵循使用描述性方法,作为接口的惯例。
示例:常量折叠
为了更好地理解编写pass,将以常量折叠pass(见src/relay/transforms/fold_constant.cc)为指导,因为是一个相对简单的过程,包含了两种类型的遍历。
常量折叠涉及计算程序中,只涉及常量值的表达式,然后用计算结果替换这些表达式。此pass的目标是预先加载所有可以进行的计算。为了实现这一点,常量折叠pass使用访客(ConstantChecker)和变异子(ConstantFolder)。
ConstantChecker Visitor
此访问者用于检查表达式是否为常量。在Relay中,如果表达式是常量节点或只有常量字段的元组节点,将定义为常量。
使用一个memo_字段,从节点映射是否为常量,缓存这些结果。以下是ConstantChecker中的VisitExpr_定义。
void VisitExpr_(const ConstantNode* n) final {
memo_[GetRef<Constant>(n)] = true;
}
void VisitExpr_(const TupleNode* n) final {
bool result = true;
for (const auto& field : n->fields) {
if (!Check(field)) {
result = false;
break;
}
}
memo_[GetRef<Tuple>(n)] = result;
}
用于协调这些定义的记录是一个检查方法,返回给定表达式是否被视为常量。
bool Check(const Expr& expr) {
const auto it = memo_.find(expr);
if (it != memo_.end())
return it->second;
VisitExpr(expr);
return memo_[expr];
}
不会为遇到的每个节点修改memo_;相反,只在遇到的节点可能是常量时修改memo_。然后,当memo_不包含expr时,依赖于默认值为false。
ConstantFolder Mutator常量折叠变异体
该mutator变异器执行大部分常量折叠pass,在内部使用ConstantChecker。在Relay中,常量折叠涉及三种节点类型:LetNode、TupleItemGetNode和CallNode。在下面的段落中,将解释pass中每个角色的作用。
Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
if (value.as<ConstantNode>()) {
memo_[op->var] = value;
return this->Mutate(op->body);
} else {
Var var = Downcast<Var>(this->Mutate(op->var));
Expr body = this->Mutate(op->body);
if (var.same_as(op->var) &&
value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return Let(var, value, body);
}
}
}
在LetNode的情况下,首先尝试对表达式中绑定的值进行常量折叠。填充memo_,返回访问主体的结果,将绑定值传播到主体中的使用站点。如果不能将绑定值常量化,将模拟默认实现。
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr res = ExprMutator::VisitExpr_(op);
op = res.as<TupleGetItemNode>();
if (const auto* tuple = op->tuple.as<TupleNode>()) {
return tuple->fields[op->index];
} else {
return res;
}
}
在TupleItemGetNode的情况下,检查op->tuple字段是否是TupleNode。用op->index指向的元组字段替换元组get。需要检查的原因是op->tuple可能计算为一个tuple,本身不是tuple。
Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
Expr res = ExprMutator::VisitExpr_(call);
call = res.as<CallNode>();
// We don't constant fold function with zero arguments.
// This is a heuristic that is useful.
// For example it is harmful to fold ones(shape=(4, 5)).
if (call->args.size() == 0) return res;
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return res;
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
bool all_const_args = true;
for (Expr arg : call->args) {
if (!checker_.Check(arg)) {
all_const_args = false;
}
}
if (all_const_args) {
return ConstEvaluate(res);
} else {
return res;
}
}
在CallNode的情况下,首先使用ExprMutator的VisitExpr_访问调用,将调用的所有字段折叠起来。使用ExprMutator::VisitExpr_uu而不是VisitExpr,因为希望绕过vtable(避免无限循环),使用ExprMutator提供的默认实现。然后,仅在所有参数都是常量时(使用ConstantChecker)计算调用。对调用求值会产生一个值,因此使用help方法ValueToExpr,将求值表达式放回AST中。
现在,为常量文件夹构造一个更方便的接口FoldConstant。FoldConstant是ConstantFolder类外的一个独立函数,接受一个表达式,在内部创建和使用ConstantFolder实例(完整定义可在src/relay/transforms/fold_constant.cc中找到)。
向pass管理器注册pass
参阅:ref:`pass infra`上的文档,了解有关此主题的更多详细信息。
编写AST遍历器后,可以使用以下代码,将pass注册为TVM API端点:
namespace transform {
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(FoldConstant(f));
};
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
} // namespace transform
如果将上述代码生成的Pass对象,提供给Pass基础设施,将确保将AST遍历应用于给定Relay模块中的每个函数,这是常量折叠Pass的预期行为(它应尽可能折叠所有常数)。
函数CreateFunctionPass允许注册pass的优化级别(在本例中为2),该级别可用于根据pass的通用工具、pass名称以及pass的任何依赖项,将pas分组。pass的依赖项以任何pass的列表的形式给出,这些pass的结果是运行当前pass所必需的。FoldConstant没有任何依赖项,但许多Relay pass确实依赖于类型信息,因此InferType是一个常见的依赖项;另一些可能依赖于程序,通过ToANormalForm pass处于A-normal形式。
注意,PassContext对象包含pass用于错误报告和配置选项的信息;FoldConstant不需要此信息,但其它pass可能会引用PassContext对象。
现在可以通过pass基础设施调用pass,不过最好为pass添加一个Python绑定,如下面的代码片段所示:
TVM_REGISTER_GLOBAL("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);
一旦以上述方式定义了Pass对象,就可以使用Pass基础设施的顺序构造调用,该构造获取一个Pass列表,按顺序应用于Relay模块,从而获得转换后的模块。例如,下面的代码将FoldConstant和ToANormalForm pass(一个接一个),应用于mod中的每个函数,获得一个新模块。
seq = transform.Sequential([
relay.transform.FoldConstant(),
relay.transform.ToANormalForm()
])
new_mod = seq(mod)
有关注册的更多详细信息,可以在TVM Runtime系统中找到,有关pass manager接口的更多信息可以在pass基础设施中找到。Relay的标准pass在include/tvm/Relay/transform.h中列出,在src/Relay/transforms/中实现。
参考链接:
https://tvm.apache.org/docs/dev/how_to/relay_add_pass.html