TVM Pass优化 -- 公共子表达式消除(Common Subexpr Elimination, CSE)

定义(What)

公共子表达式消除 就是如果表达式E的值已经计算的到了,并且自计算的到值后E的值就不再改变了,就说,表达式E在后续计算中是一个公共表达式。
简单说,该表达式上面已经执行过了,下面没必要再执行了
举个例子:

import tvm
from tvm import relay
from tvm.relay import transform

def run_opt_pass(expr, opt_pass):
    assert isinstance(opt_pass, tvm.transform.Pass)
    mod = tvm.IRModule.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body


def before():
    x = relay.var("x", shape=(1, 16))
    y1 = relay.nn.relu(x)
    y2 = relay.nn.relu(x)
    y1 = relay.add(y1, relay.const(1.0, "float32"))
    y2 = relay.add(y2, relay.const(1.0, "float32"))
    y = relay.add(y1, y2)
    f = relay.Function([x], y)
    return f

z = before()
print("before")
print(z)
z = run_opt_pass(z, transform.EliminateCommonSubexpr())
print("after")
print(z)

通过print(z)打印公共子表达式消除前IRModule对象内容,如下:
image
消除之后的IRModule对象内容如下:
image
可以发现Relay图中的y2 = relay.nn.relu(x)节点被清除
因为表达式y2 = relay.nn.relu(x)在前一个表达式y1 = relay.nn.relu(x)中已经计算过了,只需要用前面计算过的表达式结果代替即可

作用 (Why)

意义就很简单了,为了避免重新计算表达式E,浪费计算资源,影响运行效率

怎么做(How)

上面的例子可看到,公共子表达式消除主要调用的是relay.transform.EliminateCommonSubexpr()接口,这个接口是对已注册的公共子表达式消除pass的封装。可见路径:python/tvm/relay/transform/transform.py

def EliminateCommonSubexpr(fskip=None):
    """Eliminate common subexpressions.

    Parameters
    ----------
    fskip: Callable
        The callback function that decides whether an expression should be
        skipped.

    Returns
    -------
    ret : tvm.transform.Pass
        The registered pass that eliminates common subexpressions.
    """
    return _ffi_api.EliminateCommonSubexpr(fskip)

通过PackFunc机制,_ffi_api.EliminateCommonSubexpr接口最后会通过_LIB.TVMFuncGetGlobal函数获取到C++端注册的EliminateCommonSubexpr函数。
C++端EliminateCommonSubexpr注册代码如下:

Pass EliminateCommonSubexpr(PackedFunc fskip) {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(EliminateCommonSubexpr(f, fskip));
      };
  return CreateFunctionPass(pass_func, 3, "EliminateCommonSubexpr", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.EliminateCommonSubexpr")
    .set_body_typed(EliminateCommonSubexpr);

上述代码,CreateFunctionPass()函数作用是生成FunctionPass对象,FunctionPass工作在Relay模块中的每一个Relay函数对象上。

  • FunctionPass() 函数的第一个参数pass_func是TypedPackedFunc对象,真正的pass优化功能由该对象调用pass函数EliminateCommonSubexpr()完成。
  • 第二个参数是优化级别(当通过pass基础架构调用该pass时,会检查pass的优化级别,只有当该pass的优化级别不低于pass上下文配置中的优化级别时,才能启用执行该pass);
  • 第三个参数是函数pass名称;
  • 第四个参数是{} 中列出了公共子表达式消除pass依赖的其他pass,如InferType,因为需要类型信息,所以参数中列出了InferType pass名称

EliminateCommonSubexpr()函数的函数体是CommonSubexprEliminator()函数,它主要通过实现遍历Relay IR,完成Relay IR中的公共子表达式消除功能。

Relay IR遍历的C++实现类是ExprFunctor类的派生类,继承关系如下:
image

CommonSubexprEliminator()类通过重载Rewrite_()方法实现公共子表达式消除功能。该方法将处理过的表达式都存储在unordered_map变量expr _map_中。在每次通过ReWrite_方法处理当前表达式时,会先从expr_map_中查找是否有相同操作类型的已处理表达式,如果有,在判断当前表达式与已处理表达式的属性和参数是否相同,如果这些条件都满足,则返回满足条件的一处理表达式。

expr_map_定义如下:

  std::unordered_map<Expr, std::vector<Expr>, ObjectPtrHash, ObjectPtrEqual> expr_map_;

ReWrite_()方法(src/relay/transforms/eliminate_common_subexpr.cc)实现代码如下:

Expr Rewrite_(const CallNode* call, const Expr& post) final {
    static auto op_stateful = Op::GetAttrMap<TOpIsStateful>("TOpIsStateful");
    Expr new_expr = post;
    const CallNode* new_call = new_expr.as<CallNode>();
    ICHECK(new_call);
    const OpNode* op = new_call->op.as<OpNode>();
    StructuralEqual attrs_equal;

    if (new_call->args.size() == 0 || op == nullptr || op_stateful.get(GetRef<Op>(op), false)) {
      return new_expr;
    }
    if (fskip_ != nullptr && fskip_(new_expr)) {
      return new_expr;
    }

    auto it = expr_map_.find(new_call->op);
    if (it != expr_map_.end()) {
      for (const Expr& candidate_expr : it->second) {
        if (const CallNode* candidate = candidate_expr.as<CallNode>()) {
          bool is_equivalent = true;
          if (!attrs_equal(new_call->attrs, candidate->attrs)) {
            continue;
          }
          for (size_t i = 0; i < new_call->args.size(); i++) {
            if (!IsEquivalent(new_call->args[i], candidate->args[i])) {
              is_equivalent = false;
              break;
            }
          }
          if (!is_equivalent) continue;
          return GetRef<Call>(candidate);
        }
      }
    }
    expr_map_[new_call->op].push_back(new_expr);
    return new_expr;
  }

在python端调用时,通过CreateFunctionPass()函数返回FunctionPass对象,然后通过该对象调用算子,如上述例子中opt_pass(mod)
它会调用Pass类的__call__方法来调用算子

@tvm._ffi.register_object("transform.Pass")
class Pass(tvm.runtime.Object):
    """The base class of all passes. All methods here are just simple wrappers
    that are implemented in the backend. They are defined for users to
    conveniently interact with the base class.
    """

    @property
    def info(self):
        """Get the pass meta."""
        return _ffi_transform_api.Info(self)

    def __call__(self, mod):
        """Execute the pass. Note that for sequential pass, the dependency among
        different passes will be resolved in the backend.

        Parameters
        ----------
        mod : tvm.IRModule
            The module that a certain optimization is performed on.

        Returns
        -------
        mod : tvm.IRModule
            The updated module after applying this pass.
        """
        return _ffi_transform_api.RunPass(self, mod)

src/ir/transform.cc中ransform.RunPass注册代码如下:

TVM_REGISTER_GLOBAL("transform.RunPass").set_body_typed([](Pass pass, IRModule mod) {
  return pass(std::move(mod));
});

此处的pass就是通过CreateFunctionPass()创建的对象,此处会调用pass中operator()重载,最终会调到FunctionPassNode类中的operator()方法,该实现会调到CreateFunctionPass()时保存的真正公共子表达式消除的代码的实现pass_func

总体,该算子优化还算是比较简单

respect~
致敬

posted @ 2024-04-06 14:46  牛犁heart  阅读(70)  评论(0编辑  收藏  举报