TVM,Relay,Pass

TVM,Relay,Pass

Relay介绍

主要结合TVM的文档(https://tvm.apache.org/docs/dev/relay_intro.html),介绍一下NNVM的第二代Relay。Relay的设计目标有以下几点:

支持传统的数据流(DataFlow)风格编程。支持functional-style scoping,并融合了编程语言领域的一些知识,带了一些新的特性(支持Let表达式,支持递归等等)支持数据流风格和函数式风格混合编程。

使用Relay建立一个计算图

传统的深度学习框架使用计算图作为的中间表示。计算图(或数据流图)是代表计算过程的有向无环图(DAG)。尽管由于缺少控制流,数据流图在计算能力方面受到限制,但简单性使其易于实现自动微分,并针对异构执行环境进行编译(例如,在专用硬件上执行计算图的某些部分,即子图)。

 

 

 使用Relay构建一个简单的计算图示例代码,对应的文本形式和AST抽象语法树,可以使用Relay来构建一个计算(DataFlow)图。具体来说,上面的代码显示了如何构造一个简单的两个节点的计算图,可以发现这个示例的代码和现有的Garph IR如NNVMv1没有太大区别,唯一的区别是在术语方面:

现有框架通常使用图和子图Relay使用函数,例如 – fn(%x),表示图每个数据流节点,都是Relay中的一个CallNode。通过Relay的Python DSL,可以快速构建计算图。上面的代码需要注意,这里显示构造了一个Add节点,两个输入都指向%1。当一个深度学习框架。对上面的计算图进行推理时,将会按照拓扑序进行计算,并且%1只会被计算一次。虽然这个事实对于深度学习框架的开发者,一件很自然的事情,但这或许会使得只关心算法的研究员困惑。如果实现一个简单的vistor打印结果,将结果视为嵌套的Call表达式,将是log(%x) + log(%x)。

当DAG中存在共享节点时,这种歧义是由程序语义的解释不同引起的。在正常的函数式编程IR中,嵌套表达式被视为表达式树,没有考虑%1,实际上在%2中被重用了2次的事实。

Relay IR注意到了这个区别。其实深度学习框架用户,经常使用这种方式构建计算图,其中经常发生DAG节点重用。然后以文本格式打印Relay程序时,每行打印一个CallNode,并为每个CallNode分配一个临时ID(%1, %2),以便可以在程序的后续部分中引用每个公共节点。

Module:支持多个函数(Graphs)

上面介绍了如何构建一个数据流图为一个函数。然后一个很自然的问题是可以做到构建多个函数并相互调用吗?Relay允许将多个函数组合在一个Module中,下面的代码展示了一个函数调用另外一个函数的例子。

def @muladd(%x, %y, %z) { %1 = mul(%x, %y) %2 = add(%1, %z) %2}def @myfunc(%x) { %1 = @muladd(%x, 1, 2) %2 = @muladd(%1, 2, 3) %2}

Module可以被看作Map<GlobalVar, Function>,GlobalVar仅仅是一个表示函数名的ID,上面的程序中GlobalVar是@muladd和@myfunc。当一个CallNode调用另外一个函数时,相应的GlobalVar被存在CallNode的OP中。包含了一个间接的等级关系---需要使用相应的GlobalVar,从Module中查找调用函数的主体。也可以直接将引用的函数存储为CallNode中的OP。为什么需要引入GlobalVar呢?主要原因是为了解耦定义和声明,并支持了函数的递归和延迟声明。

def @myfunc(%x) { %1 = equal(%x, 1)if (%1) { %x } else { %2 = sub(%x, 1) %3 = @myfunc(%2) %4 = add(%3, %3) %4 }}在上面的例子中,@myfunc递归调用。使用GlobalVar @myfunc表示函数,避免了数据结构中的循环依赖性。至此,已经介绍完了Relay中的基本概念。相比NNVM,Relay在如下方面进行了改进:

有文本形式中间表示,便于开发和 debug支持子图函数、联合模块,便于联合优化前端用户友好,便于调优0x2.3 Let Binding and Scopes

至此,已经介绍了如何用深度学习框架中的旧方法,构建计算图。这一节将讨论一个Relay的一个新的构造-let bindings。

Let binding被每一种高级的编程语言应用。在Relay中,一个拥有三个字段Let(var, value, body)的数据结构。计算一个Let表达式时,首先计算value部分,然后将其绑定到var,最后在body表达式中返回计算结果。

可以使用一系列的Let绑定,构造一个逻辑上等效于数据流程序的程序,下面的代码示例显示了这个用法:

 

 

 Let表达式构造和数据流程序等价的,计算图嵌套的Let Binding,称作A-normal形式,作为函数式编程语言中的常用IR。通过上面的图,可以发现虽然这两个程序的语义完全等价,文本表示也一样(除了A-norm形式有let的前缀),但AST抽象语法树却不一样。

由于程序的优化,使用了这些AST数据结构进行了变换,这两种不同的结构,影响到最终编译器生成的代码。比如,想要检测add(log(x), y)这个模式。在数据流程序中,可以首先进入add节点,然后直接检查第一个参数是不是log。在A-form的程序中,不能直接检查任何东西,因为add节点的输入是%v1-需要维护一个映射表,将变量和绑定的值进行映射,然后查表才知道%v1代表的是log。

为什么可能需要Let Binding

Let Binding的一种关键用法,可以指定计算的scope。看一下下面这个没有使用Let Binding的例子:

 

 

 没有使用Let Binding编程的一个例子,当尝试在该在哪里计算%1节点时,问题就来了。特别的是,虽然文本格式似乎建议,应该在if的scope之外,计算节点%1,但AST却不建议这样做。实际上数据流图,永远不会定义计算scope,这在语义上产生了一些歧义。

当有闭包时,这种歧义更加有趣,考虑下面的程序,该程序返回一个闭包。不知道在哪里计算%1,可以在闭包的内部和外部。

fn (%x) { %1 = log(%x) %2 = fn(%y) { add(%y, %1) } %2}Let Binding解决了这些问题,因为值的计算发生在let节点上。在这两个程序中,如果将%1 = log(%x)改成let %v1 = log(%x),将计算位置明确指定为if scope和闭包之外。Let Binding为计算端提供了更精确的范围,在生成后端代码时会很有用(因为这种范围在IR中)。

另一方面,没有指定计算scope的数据流形式,也有其自身的优势,不需要担心在生成代码时,将let放到哪里。数据流格式还为后面决定将计算节点放到哪里的Passes,提供了更大的自由度。因此,在优化的初始阶段,如果发现数据流形式,还是挺方便的,那么,使用数据流图的编码方法,可能不是一个坏主意。目前在Relay中也实现了很多针对数据流图的优化方式。

但是,当将IR lower到实际的运行时程序时,需要精确的计算scope。特别是当使用子函数和闭包时,要明确指定计算scope,应在哪里发生。在后期执行特定的优化中,可以使用Let Binding来解决此问题。

对IR转换的影响

希望到目前为止,已经熟悉两种表示形式。大多数函数式编程语言都以A-normal形式进行分析,分析人员无需注意表达式是DAG。

Relay选择同时支持数据流形式和Let Binding。TVM相信让框架开发者选择熟悉的表达形式很重要。但是这确实对写通用的Passes产生了一些影响。这里还没介绍Passes,对Passes理解不深,没有使用过Let表达式来构建网络,就不继续介绍具体有哪些影响了。

详细内容可以参考:https://tvm.apache.org/docs/dev/relay_intro.html#let-binding-and-scopes

基于Relay构建一个自定义的神经网络示例

基于Relay的接口定义一个Conv+BN+ReLU的小网络,展示一下Relay接口应该如何使用,这里TVM版本是0.8.0.dev,代码如下:

#coding=utf-8import tvmfrom tvm import relayimport numpy as npfrom tvm.contrib import graph_executor# 构造BNdefbatch_norm(data, gamma=None, beta=None, moving_mean=None, moving_var=None, **kwargs): name = kwargs.get("name") kwargs.pop("name")ifnot gamma: gamma = relay.var(name + "_gamma")ifnot beta: beta = relay.var(name + "_beta")ifnot moving_mean: moving_mean = relay.var(name + "_moving_mean")ifnot moving_var: moving_var = relay.var(name + "_moving_var")return relay.nn.batch_norm(data, gamma=gamma, beta=beta, moving_mean=moving_mean, moving_var=moving_var, **kwargs)[0]# 构造卷积defconv2d(data, weight=None, **kwargs): name = kwargs.get("name") kwargs.pop("name")ifnot weight: weight = relay.var(name + "_weight")return relay.nn.conv2d(data, weight, **kwargs)# 构造卷积+BN+ReLU的simpleNetdefsimplenet(data, name, channels, kernel_size=(3, 3), strides=(1, 1), padding=(1, 1), epsilon=1e-5): conv = conv2d( data=data, channels=channels, kernel_size=kernel_size, strides=strides, padding=padding, data_layout='NCHW', name=name+'_conv') bn = batch_norm(data=conv, epsilon=epsilon, name=name + '_bn') act = relay.nn.relu(data=bn)return actdata_shape = (1, 3, 224, 224)kernel_shape = (32, 3, 3, 3)dtype = "float32"data = relay.var("data", shape=data_shape, dtype=dtype)act = simplenet(data, "graph", 32, strides=(2, 2))func = relay.Function(relay.analysis.free_vars(act), act)print(func)np_data = np.random.uniform(-1, 1, (1, 3, 224, 224))params = {"graph_conv_weight": tvm.nd.array(np.random.uniform(-1, 1, (32, 3, 3, 3)).astype(dtype)),"graph_bn_gamma": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),"graph_bn_beta": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),"graph_bn_moving_mean": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),"graph_bn_moving_var": tvm.nd.array(np.random.uniform(-1, 1, (32)).astype(dtype)),}with tvm.transform.PassContext(opt_level=3): lib = relay.build(func, "llvm", params=params)dev = tvm.cpu(0)dtype = "float32"m = graph_executor.GraphModule(lib["default"](dev))# set inputsm.set_input("data", tvm.nd.array(np_data.astype(dtype)))# executem.run()# get outputstvm_output = m.get_output(0)

就是一个很常规的过程,创建Relay Function,然后将所有的OP的权重信息用params这个字典存起来,注意这里的权重信息是随机初始化的。在编译Relay IR之前可以先看一下优化前的IR长什么样:

fn (%data: Tensor[(1, 3, 224, 224), float32], %graph_conv_weight, %graph_bn_gamma, %graph_bn_beta, %graph_bn_moving_mean, %graph_bn_moving_var) { %0 = nn.conv2d(%data, %graph_conv_weight, strides=[2, 2], padding=[1, 1, 1, 1], channels=32, kernel_size=[3, 3]); %1 = nn.batch_norm(%0, %graph_bn_gamma, %graph_bn_beta, %graph_bn_moving_mean, %graph_bn_moving_var); %2 = %1.0; nn.relu(%2)}符合第二节介绍的规则,Relay IR时一个函数。

初识Pass

上面构造simplenet的代码中,relay.build外部包了一层tvm.transform.PassContext,如下:

with tvm.transform.PassContext(opt_level=3): lib = relay.build(func, "llvm", params=params)实际上tvm.transform.PassContext这个接口就定义了Pass,如文档所示:

 

 

 tvm.transform.PassContext用来控制对relay IR使用哪些Pass进行优化,Pass是TVM中基于Relay IR进行的一系列优化,类似于onnx-simplifier里面用到的onnxoptimizer,可以简化计算图,去除一些冗余的算子,提高模型的推理效率。TVM将所有的pass都抽象到了tvm/include/tvm/ir/transform.h这个文件中,主要包含PassContext,PassInfo,Pass,以及Sequential。

这里的PassContext是上面Python接口对应的C++实现,包含了Pass执行依赖的一些参数,如优化level,依赖特定Pass,以及设置不使用某种指定Pass等。PassInfo是用来记录Pass信息的类,包含Pass的opy_level,name,以及当前Pass需要哪些前置Pass。而Pass这个类,就执行pass的主体,这是一个基类,每种Pass具体的C++代码实现在tvm/src/relay/transforms中,都会继承Pass这个基类。最后,Sequential是一个container,装载所有Pass。

需要说明一下,不是所有的Pass都定义在tvm/src/relay/transforms,比如下面的第一个例子,就在tvm/src/relay/backend/vm文件夹里。接下来将几个Pass的例子,到底对Relay IR做了什么?

RemoveUnusedFunctions首先来看一下定义在tvm/src/relay/backend/vm/removed_unused_funcs.cc这里的RemoveUnusedFunctions 这个pass,核心的代码实现如下:

voidVisitExpr_(const FunctionNode* func_node)final{auto func = GetRef<Function>(func_node);if (visiting_.find(func) == visiting_.end()) { visiting_.insert(func);for (auto param : func_node->params) { ExprVisitor::VisitExpr(param); } ExprVisitor::VisitExpr(func_node->body); } }IRModule RemoveUnusedFunctions(const IRModule& module, Array<runtime::String> entry_funcs){std::unordered_set<std::string> called_funcs{};for (auto entry : entry_funcs) {auto funcs = CallTracer(module).Trace(entry); called_funcs.insert(funcs.cbegin(), funcs.cend()); }auto existing_functions = module->functions;for (auto f : existing_functions) {auto it = called_funcs.find(f.first->name_hint);if (it == called_funcs.end()) {module->Remove(f.first); } }returnmodule;}

这个pass就是去除Relay IR中的冗余节点,VisitExpr_这个函数就是完成了一个图的遍历,然后把没有遍历到的节点删掉。删除发生在RemoveUnusedFunctions这个函数中。

ToBasicBlockNormalForm这个Pass实现在tvm/src/relay/transforms/to_basic_block_normal_form.cc,代码实现如下:

Expr ToBasicBlockNormalFormAux(const Expr& e){// calculate all the dependency between nodes. support::Arena arena; DependencyGraph dg = DependencyGraph::Create(&arena, e);/* The scope of the whole expr is global. * The scope of any subexpr, is the lowest common ancestor of all incoming edge. * We also record the set of expressions whose scope is lifted. */std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);return Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);}IRModule ToBasicBlockNormalForm(const IRModule& mod){ DLOG(INFO) << "ToBBlock:" << std::endl << mod; tvm::Map<GlobalVar, Function> updates;auto funcs = mod->functions;for (constauto& it : funcs) { ICHECK_EQ(FreeVars(it.second).size(), 0) << "Expected no free variables";if (constauto* n = it.second.as<FunctionNode>()) {if (n->GetAttr<String>(attr::kCompiler).defined()) continue; } Expr ret = TransformF([&](const Expr& e) { return ToBasicBlockNormalFormAux(e); }, it.second); updates.Set(it.first, Downcast<Function>(ret)); }for (auto pair : updates) { mod->Add(pair.first, pair.second, true); } DLOG(INFO) << "ToBBlock: transformed" << std::endl << mod;return mod;}boolBasicBlockNormalFormCheck(const Expr& e){// calculate all the dependency between nodes. support::Arena arena; DependencyGraph dg = DependencyGraph::Create(&arena, e);std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);for (auto expr : scopes.second) { LOG(FATAL) << "The expression below violates the basic block normal form in that " << "its scope should be lifted:\n" << expr; }return scopes.second.size() == 0;}ToBasicBlockNormalForm

这个函数通过遍历Relay IR中的function,将每个function转换为基本块形式(即ToBasicBlockNormalFormAux这个函数),ToBasicBlockNormalFormAux这个函数分成以下几个部分:

调用DependencyGraph dg = DependencyGraph::Create(&arena, e)创建一个DependencyGraph,这个数据结构是一个表达式相互依赖的图结构。通过std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg)计算每个节点的scope,这个scope可以简单理解为由跳转指令如Ifnode,FunctionNode,LetNode等隔开的那些子图,因为一旦碰到这些节点在上面通过Relay Function创建DependencyGraph就会为这种节点分配一个new_scope标志。然后CalcScope这个函数具体做了哪些事情,需要跟进去看一下:std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg){ NodeScopeMap expr_scope; ExprSet lifted_exprs;std::unordered_map<DependencyGraph::Node*, Expr> node_to_expr;// 首先让每个节点都属于一个单独的scopefor (auto expr_node : dg.expr_node) { node_to_expr[expr_node.second] = expr_node.first; }bool global_scope_used = false; Scope global_scope = std::make_shared<ScopeNode>();// 使用LCA算法来更新每个节点的真正scopefor (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) { DependencyGraph::Node* n = *it;auto iit = n->parents.head; Scope s;if (iit == nullptr) { ICHECK(!global_scope_used); s = global_scope; global_scope_used = true; } else { s = expr_scope.at(iit->value);constauto original_s = s; iit = iit->next;for (; iit != nullptr; iit = iit->next) { s = LCA(s, expr_scope.at(iit->value)); }if (s != original_s && node_to_expr.find(n) != node_to_expr.end()) {// filter out exprs whose scope do not matter Expr expr = node_to_expr[n];if (!expr.as<OpNode>()) { lifted_exprs.insert(expr); } } }if (n->new_scope) {auto child_scope = std::make_shared<ScopeNode>(s); expr_scope.insert({n, child_scope}); } else { expr_scope.insert({n, s}); } } ICHECK(global_scope_used);returnstd::make_pair(expr_scope, lifted_exprs);}

这个函数首先让每个节点都属于一个单独的scope,然后使用LCA算法来更新每个节点的真正scope。这里简单介绍一下LCA算法以及这里具体是如何求取每个节点的scope的。

最近公共祖先简称 LCA(Lowest Common Ancestor)。两个节点的最近公共祖先,就是这两个点的公共祖先里面,离根最远的那个。为了方便,记某点集 的最近公共祖先为 或 。LCA有以下性质,引自OI-wiki:

 

 

 其实不看这个性质也没关系,了解LCA,可以求图中两个节点的最近公共祖先即可。然后CalcScope这个函数的具体思路,先将每个节点初始化为一个单独的scope,然后按照后DFS序遍历这些节点,对于每一个遍历到的节点(这里记作n),看一下它的父亲节点iit是否存在,如果不存在则说明当前节点是根节点,scope应该为global_scope。如果iit存在,那么遍历iit的子节点,看一下这些节点的scope的LCA表达式,如果这个通过LCA求出来的表达式和iit节点的表达式完全相同,说明这个子图和当前节点是属于同一个scope的,否则就将当前节点插入到lifted_exprs,lifted_exprs是一个集合用来保存这个DependencyGraph里面的那些跳转指令节点,这也是为什么上面再插入节点到lifted_exprs之前,需要判断一下这个节点的类型是否为OpNode。另外如果当前枚举的节点有new_scope标志,说明当前节点属于一个新的scope,需要为当前节点分配新的类型为ScopeNode的一个智能指针。

通过上面的算法,DependencyGraph中的节点和scope节点的关系就被映射到了一个map中,并且scope节点也被建立起了一个树结构。最后调用这个Fill::ToBasicBlockNormalForm(e, dg, &scopes.first, &scopes.second);来创建一个Fill类,这个类包含了DependencyGraph以及scope相关的信息,通过ToBasicBlockNormalForm成员函数实现基本块转换。实现在tvm/src/relay/transforms/to_a_normal_form.cc这个文件中,知乎对这个Pass也做了解释,这里引用一下:

它(ToBasicBlockNormalForm)的基本逻辑通过VisitExpr函数遍历dependency节点,将具有相同scope的节点压入到同一个let_list中。Let_list文档中是这样解释的:

/*! * \file let_list.h * \brief LetList record let binding and insert let expression implicitly. * using it, one can treat AST as value instead of expression, * and pass them around freely without fear of AST explosion (or effect duplication). * for example, if one write 'b = a + a; c = b + b; d = c + c', the AST will contain 8 'a'. * if one instead write 'b = ll.Push(a + a); c = ll.Push(b + b); d = ll.Get(c + c);', * the AST will contain 2 'a', as b and c are now variables.

Let_list使得抽象语法树简洁化,不会因为变量的复制导致树的爆炸。具有相同的scope的expr被约束到相同的let_list中,用一个var来表达,这样就将表达式转化为var的形式。一个var也就对应了一个基本块。

EliminateCommonSubexpr最后再看一个消除公共子表达式的Pass,所谓公共子表达式指的就是具有相同的OP类型以及相同的参数,参数的顺序都是完全相同的,这些表达式就可以合成一个公共子表达式。举个例子:

a = b + cd = b + c

可以看到这两个表达式时完全一致的,经过这个Pass之后,计算图就会消除其中一个表达式。代码实现在:tvm/src/relay/transforms/eliminate_common_subexpr.cc。这里定义了一个CommonSubexprEliminator类,这个类重载了两个Rewrite_函数,对expr进行遍历和重写。代码实现如下:

Expr Rewrite_(const CallNode* call, const Expr& post)final{staticauto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful"); Expr new_expr = post;const CallNode* new_call = new_expr.as<CallNode>(); ICHECK(new_call);const OpNode* op = new_call->op.as<OpNode>(); StructuralEqual attrs_equal;if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {return new_expr; }if (fskip_ != nullptr && fskip_(new_expr)) {return new_expr; }auto it = expr_map_.find(new_call->op);if (it != expr_map_.end()) {for (const Expr& candidate_expr : it->second) {if (const CallNode* candidate = candidate_expr.as<CallNode>()) {bool is_equivalent = true;// attrs匹配if (!attrs_equal(new_call->attrs, candidate->attrs)) {continue; }// args匹配for (size_t i = 0; i < new_call->args.size(); i++) {if (!new_call->args[i].same_as(candidate->args[i]) && !IsEqualScalar(new_call->args[i], candidate->args[i])) { is_equivalent = false;break; } }if (!is_equivalent) continue;return GetRef<Call>(candidate); } } } expr_map_[new_call->op].push_back(new_expr);return new_expr; }可以看到大概的思路就是利用expr_map_这个std::unordered_map<Expr, std::vector<Expr>, ObjectPtrHash, ObjectPtrEqual> expr_map_;

映射遍历过的具有相同op的expr,然后每次碰到相同op的表达式,都会对已经记录的expr进行匹配,匹配不仅包含OP的attrs属性,还包含参数列表,如果完全一样,说明这两个表达式就是公共表达式,就不返回新的表达式。这样就可以去掉Relay Function中的公共表达式了。

到这里可能还不是特别清楚最开始加载的那个simplenet的Relay Function,经过一些Pass之后,具体变成什么样,其实目前也还没搞清楚这个问题,这个问题应该就需要留到后面再解答了。

小结

本文介绍了一下TVM的Relay,介绍了如何基于Relay构建一个Conv+BN+ReLU的小网络,介绍了一下TVM中的Pass的工作机制,详细的介绍了RemoveUnusedFunctions,ToBasicBlockNormalForm,EliminateCommonSubexpr三种Pass。其中Relay部分的详细介绍大部分引用自官方文档:https://tvm.apache.org/docs/tutorials/get_started/introduction.html。

0x6. 参考资料

https://zhuanlan.zhihu.com/p/358437531https://zhuanlan.zhihu.com/p/91283238https://tvm.apache.org/docs/tutorials/get_started/introduction.html

 

https://baijiahao.baidu.com/s?id=1700872402469787364&wfr=spider&for=pc

 

posted @ 2021-09-17 06:13  吴建明wujianming  阅读(755)  评论(0编辑  收藏  举报