MLIR通过.td定义新Dialect与新Op

MLIR通过.td定义新Dialect与新Op
MLIR:新建一个Dialect,通过.td定义新Dialect
MLIR 项目的核心是 Dialect,MLIR 自身就拥有例如linalg,tosa,affine 这些 Dialect。各种不同的 Dialect 使不同类型的优化或转换得以完成。
新建一个 Dialect的内容。通过.td定义新 Dialect “Hello”。
1复习
工具链、总览等等知识请自行翻看历史 MLIR 标签的相关文章
前文介绍了mlir-hello[1] 项目的目标就是使用自建的 Dialect 通过 MLIR 生态实现一个 hello world,具体做法为:
1. 创建 hello-opt 将原始 print.mlir (可以理解成 hello world 的 main.cpp)转换为 print.ll 文件
2. 使用 LLVM 的 lli 解释器直接运行 print.ll 文件
2HelloDialect.td
hello.print 作为一个 Op,显而易见,hello Dialect、print Op 都需要被定义。
本文来看看如何定义一个新的 Dialect 以及一个 Op。
通过声明式的 .td 文件以及 TableGen 工具可以便捷的生成相应的 C++ 代码。
代码来自 [mlir-hello]/include/Hello/HelloDialect.td,
#ifndef HELLO_DIALECT
#define HELLO_DIALECT

// 引入大基类 Dialect
include "mlir/IR/OpBase.td"

// 定义新 Hello_Dialect
def Hello_Dialect : Dialect {
// 定义名字空间 namespace,对应 C++ 的 getDialectNamespace 方法返回值
let name = "hello";
// 一行关于这个 Dialect 的介绍
let summary = "A hello out-of-tree MLIR dialect.";
// 更详细的关于这个 Dialect 的介绍
let description = [{
This dialect is minimal example to implement hello-world kind of sample code
for MLIR.
}];
// 产生一个返回名字空间名称的接口
let cppNamespace = "::hello";
// 该设置用于激活 materializeConstant 方法,这使得可以例如 Canonicalize 优化
let hasConstantMaterializer = 1;
}

// 定义一个 Op 作为后续其他具体 Op 的“基类”
class Hello_Op<string mnemonic, list<Trait> traits = []> :
Op<Hello_Dialect, mnemonic, traits>;

#endif // HELLO_DIALECT
3TableGen
来看看这个 .td 能生成什么样子的代码?
$MLIR_TBLGEN -gen-dialect-decls HelloDialect.td -I$LOCAL_MLIR/include >> HelloDialect.h
/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\
|* *|
|* Dialect Declarations *|
|* *|
|* Automatically generated file, do not edit! *|
|* *|
\*===----------------------------------------------------------------------===*/
// 名字空间
namespace hello {

class HelloDialect : public ::mlir::Dialect {
explicit HelloDialect(::mlir::MLIRContext *context);
// 用以注册 attributes, operations, types 等等
void initialize();
friend class ::mlir::MLIRContext;
public:
~HelloDialect() override;
// 定义的名字空间 name
static constexpr ::llvm::StringLiteral getDialectNamespace() {
return ::llvm::StringLiteral("hello");
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
// 设置的 hasConstantMaterializer 位
::mlir::Operation *materializeConstant(::mlir::OpBuilder &builder,
::mlir::Attribute value,
::mlir::Type type,
::mlir::Location loc) override;
};
} // namespace hello
MLIR_DECLARE_EXPLICIT_TYPE_ID(::hello::HelloDialect)

对 mlir-hello 项目的源代码文件 HelloDialect.td 进行了学习,通过自定义的 .td 文件声明式的语法可以便捷的定义一个新的 Dialect。
MLIR:新建一个Dialect,通过.td定义新Op
Multi-Level Intermediate Representation(MLIR)是创建可重用、可扩展编译器基础设施的新途径。介绍一个简单的 MLIR Dialect。
MLIR 项目的核心是 Dialect,MLIR 自身就拥有例如linalg,tosa,affine 这些 Dialect。各种不同的 Dialect 使不同类型的优化或转换得以完成。
新建一个 Dialect的内容。通过.td定义新 “Hello_Op”。
1复习
工具链、总览等等知识请自行翻看历史 MLIR 标签的相关文章
介绍了mlir-hello项目的目标就是使用自建的 Dialect 通过 MLIR 生态实现一个 hello world,具体做法为:
1. 创建 hello-opt 将原始 print.mlir (可以理解成 hello world 的 main.cpp)转换为 print.ll 文件
2. 使用 LLVM 的 lli 解释器直接运行 print.ll 文件
2HelloOps.td
hello.print 作为一个 Op,显而易见,hello Dialect、print Op 都需要被定义。
本文来看看如何定义一个保存变量的ConstantOp和执行打印操作的PrintOp,也就是实际 MLIR 使用中的 hello.constant 和 hello.print。
通过声明式的 .td 文件以及 TableGen[2] 工具可以便捷的生成相应的 C++ 代码。
更详细的语法可以在 Operation Definition Specification (ODS)[3]找到。
代码来自 [mlir-hello]/include/Hello/HelloOps.td,
#ifndef HELLO_OPS
#define HELLO_OPS

include "HelloDialect.td"
// 包含 NoSideEffect 的 trait,不主动做某些消除优化
include "mlir/Interfaces/SideEffectInterfaces.td"

// 第一个 Op,用以转换输入为内部使用的 SSA 值
// 类似 Dialect 中定义的 class HelloOp (对象)
// 实际名字为 constant(CRTP)
def ConstantOp : Hello_Op<"constant", [Pure]> {
// 一行关于这个 Op 的介绍
let summary = "constant";
// 更详细的关于这个 Op 的介绍
let description = [{
Constant operation turns a literal into an SSA value. The data is attached
to the operation as an attribute. For example:

```mlir
%0 = "hello.constant"()
{ value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf64> }
: () -> tensor<2x3xf64>
```
}];

// 每个 Op 都会有的 builder 方法们 https://mlir.llvm.org/docs/OpDefinitions/#builder-methods
let builders = [
// 重载的 build 函数,定义参数。例如下面会生成类似
// `static void build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, mlir::DenseElementsAttr value);`
// 样的代码。前面的 ins 指示 dag-type
OpBuilder<(ins "mlir::DenseElementsAttr":$value), [{
build($_builder, $_state, value.getType(), value);
}]>,
OpBuilder<(ins "double":$value)>
];

// let parser = [{ return ::parseConstantOp(parser, result); }];
// 定义输入,类似上面的 builder。可以是 operands 或 attributes,这里是后者。前者意思是由其他 operation 产生的值
let arguments = (ins F64ElementsAttr:$value);
// 定义输出
let results = (outs F64Tensor);
}

// 第二个 Op,用以表明打印操作
def PrintOp : Hello_Op<"print", [Pure]> {
let summary = "print operation";
let description = [{
The "print" builtin operation prints a given input tensor, and produces
no results.
}];

// The print operation takes an input tensor to print.
let arguments = (ins AnyTypeOf<[F64Tensor, F64MemRef]>:$input);
// 手动写明这个 Op 的输出
let assemblyFormat = "$input attr-dict `:` type($input)";
}

#endif // HELLO_OPS
3TableGen
来看看这个 .td 能生成什么样子的代码?
$MLIR_TBLGEN -gen-op-decls include/Hello/HelloOps.td -I$LOCAL_MLIR/include -Iinclude/Hello >> HelloOps.decls.h
/*===- TableGen'erated file -------------------------------------*- C++ -*-===*\
|* *|
|* Op Declarations *|
|* *|
|* Automatically generated file, do not edit! *|
|* *|
\*===----------------------------------------------------------------------===*/

#if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES)
#undef GET_OP_FWD_DEFINES
namespace hello {
class ConstantOp;
} // namespace hello
namespace hello {
class PrintOp;
} // namespace hello
#endif

#ifdef GET_OP_CLASSES
#undef GET_OP_CLASSES


//===----------------------------------------------------------------------===//
// Local Utility Method Definitions
//===----------------------------------------------------------------------===//

namespace hello {

//===----------------------------------------------------------------------===//
// ::hello::ConstantOp declarations
//===----------------------------------------------------------------------===//

class ConstantOpAdaptor {
public:
ConstantOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});

ConstantOpAdaptor(ConstantOp op);

::mlir::ValueRange getOperands();
std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
::mlir::ValueRange getODSOperands(unsigned index);
::mlir::DictionaryAttr getAttributes();
::mlir::DenseElementsAttr getValueAttr();
::mlir::DenseElementsAttr getValue();
::mlir::LogicalResult verify(::mlir::Location loc);
private:
::mlir::ValueRange odsOperands;
::mlir::DictionaryAttr odsAttrs;
::mlir::RegionRange odsRegions;
::llvm::Optional<::mlir::OperationName> odsOpName;
};
class ConstantOp : public ::mlir::Op<ConstantOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::TensorType>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::ZeroOperands, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> {
public:
using Op::Op;
using Op::print;
using Adaptor = ConstantOpAdaptor;
public:
static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
static ::llvm::StringRef attrNames[] = {::llvm::StringRef("value")};
return ::llvm::makeArrayRef(attrNames);
}

::mlir::StringAttr getValueAttrName() {
return getAttributeNameForIndex(0);
}

static ::mlir::StringAttr getValueAttrName(::mlir::OperationName name) {
return getAttributeNameForIndex(name, 0);
}

static constexpr ::llvm::StringLiteral getOperationName() {
return ::llvm::StringLiteral("hello.constant");
}

std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
::mlir::Operation::operand_range getODSOperands(unsigned index);
std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
::mlir::Operation::result_range getODSResults(unsigned index);
::mlir::DenseElementsAttr getValueAttr();
::mlir::DenseElementsAttr getValue();
void setValueAttr(::mlir::DenseElementsAttr attr);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, mlir::DenseElementsAttr value);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, double value);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type resultType0, ::mlir::DenseElementsAttr value);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::DenseElementsAttr value);
static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
::mlir::LogicalResult verifyInvariantsImpl();
::mlir::LogicalResult verifyInvariants();
void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
private:
::mlir::StringAttr getAttributeNameForIndex(unsigned index) {
return getAttributeNameForIndex((*this)->getName(), index);
}

static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) {
assert(index < 1 && "invalid attribute index");
assert(name.getStringRef() == getOperationName() && "invalid operation name");
return name.getRegisteredInfo()->getAttributeNames()[index];
}

public:
};
} // namespace hello
MLIR_DECLARE_EXPLICIT_TYPE_ID(::hello::ConstantOp)

namespace hello {

//===----------------------------------------------------------------------===//
// ::hello::PrintOp declarations
//===----------------------------------------------------------------------===//

class PrintOpAdaptor {
public:
PrintOpAdaptor(::mlir::ValueRange values, ::mlir::DictionaryAttr attrs = nullptr, ::mlir::RegionRange regions = {});

PrintOpAdaptor(PrintOp op);

::mlir::ValueRange getOperands();
std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
::mlir::ValueRange getODSOperands(unsigned index);
::mlir::Value getInput();
::mlir::DictionaryAttr getAttributes();
::mlir::LogicalResult verify(::mlir::Location loc);
private:
::mlir::ValueRange odsOperands;
::mlir::DictionaryAttr odsAttrs;
::mlir::RegionRange odsRegions;
::llvm::Optional<::mlir::OperationName> odsOpName;
};
class PrintOp : public ::mlir::Op<PrintOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::ZeroResults, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait> {
public:
using Op::Op;
using Op::print;
using Adaptor = PrintOpAdaptor;
public:
static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() {
return {};
}

static constexpr ::llvm::StringLiteral getOperationName() {
return ::llvm::StringLiteral("hello.print");
}

std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index);
::mlir::Operation::operand_range getODSOperands(unsigned index);
::mlir::Value getInput();
::mlir::MutableOperandRange getInputMutable();
std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index);
::mlir::Operation::result_range getODSResults(unsigned index);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value input);
static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value input);
static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
::mlir::LogicalResult verifyInvariantsImpl();
::mlir::LogicalResult verifyInvariants();
static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
void print(::mlir::OpAsmPrinter &_odsPrinter);
void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects);
public:
};
} // namespace hello
MLIR_DECLARE_EXPLICIT_TYPE_ID(::hello::PrintOp)
#endif // GET_OP_CLASSES
TLDR
这些只是 Op 的声明,还有定义,太多了就不放这里了。可以通过下面指令生成,
$MLIR_TBLGEN -gen-op-defs include/Hello/HelloOps.td -I$LOCAL_MLIR/include -Iinclude/Hello >> HelloOps.defs.h
手写这些代码还是写个 .td 自动生成。
对 mlir-hello 项目的源代码文件 HelloOps.td,通过自定义的 .td 文件声明式的语法可以在新的 Dialect 中便捷的定义新的 Op。

参考文献链接
https://mp.weixin.qq.com/s/IpU6KP6igEgP4mCYUnYTyQ
https://mp.weixin.qq.com/s/8B9A5Pu9mb2ooOOzt3quaQ
mlir-hello: https://github.com/Lewuathe/mlir-hello
TableGen: https://llvm.org/docs/TableGen/ProgRef.html
mlir-hello: https://github.com/Lewuathe/mlir-hello
TableGen: https://llvm.org/docs/TableGen/ProgRef.html
Operation Definition Specification (ODS): https://mlir.llvm.org/docs/OpDefinitions

posted @ 2023-03-03 04:31  吴建明wujianming  阅读(480)  评论(0编辑  收藏  举报