AI 编译器 CINN vs. TVM 中 CodeGen 源码解读


一、CINN 框架

CINN 中CodeGen之后的代码编译主要交给了Compiler类来负责。核心的函数主要是:

  • Build(ir::Module&, string& code)
  • Lookup(string& fn_name)
class Compiler final {
  static std::unique_ptr<Compiler> Create(const Target& target) {
    return std::unique_ptr<Compiler>(new Compiler(target));

   * Compile and link to a CINN module.
  void Build(const ir::Module& module, const std::string& code = "");

  void ExportObject(const std::string& path);

  std::string GetSourceCode(const ir::Module& module);

  void BuildDefault(const ir::Module& module);

   * Retrieve a function by \p fn_name.
   * @return function address or null if not exists.
  void* Lookup(absl::string_view fn_name);

  void CompileCudaModule(const ir::Module& module, const std::string& code = "");

  void CompileX86Module(const ir::Module& module);

  explicit Compiler(const Target& target) : target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {}


  Target target_;
  std::unique_ptr<ExecutionEngine> engine_;

  std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;


void Compiler::Build(const Module& module, const std::string& code) {
  if (target_.arch == Target::Arch::NVGPU) {
    CompileCudaModule(module, code);      // <----- GPU 上编译逻辑
  } else if (target_.arch == Target::Arch::X86) {
    CompileX86Module(module);            // <------ X86 CPU 上编译逻辑
  } else {

我们来详细研习下 CompileCudaModule() 函数的实现逻辑:

  • step 1: 调用 SplitCudaAndHostModule() 将 ir::Module 切分成 host_module 和 device_module
  • step 2: 借助 CodeGenCUDA_Dev 模块,对 device_module 进行代码生成,得到 source_code;也支持用户直接通过code参数指定
  • step 3: 构造一个 nvrtc::Compiler 对象,将 source_code编译为 ptx 中间代码
  • step 4: 以 ptx 构造一个 runtime::cuda::CUDAModule 对象,可以选择是 CUBIN 或者 PTX 的 kind 类型
  • step 5: 根据 device_module 中的 fn_name 在 cuda_module_ 中取出对应的 fn_kernel(CUfunction 对象),并在 RuntimeSymbols 中注册(fn_name + "_ptr_"、void*)
  • step 6: 以 RuntimeSymbols 构造 ExecutionEngine,负责将 CUDA Kernel 与 Host 端 API 进行链接
void Compiler::CompileCudaModule(const Module& module, const std::string& code) {
  // step 1: 调用SplitCudaAndHostModule()将ir::Module切分成 host_module和device_module
  auto _host_module_device_module_ = SplitCudaAndHostModule(module);  // NOLINT
  auto& host_module                = std::get<0>(_host_module_device_module_);
  auto& device_module              = std::get<1>(_host_module_device_module_);
  // step 2: 借助CodeGenCUDA_Dev模块,对device_module进行代码生成,得到 source_code;也支持用户直接通过code参数指定
  std::string source_code;
  CodeGenCUDA_Dev codegen(target_);
  source_code = codegen.Compile(device_module);
  // step 3: 构造一个nvrtc::Compiler对象,将source_code编译为ptx中间代码
  backends::nvrtc::Compiler compiler;
  auto ptx = compiler(source_code);
  // step 4: 以ptx构造一个runtime::cuda::CUDAModule对象,可以选择是CUBIN 或者 PTX 的kind类型
  using runtime::cuda::CUDAModule;
      new CUDAModule(ptx, compiler.compile_to_cubin() ? CUDAModule::Kind::CUBIN : CUDAModule::Kind::PTX));
  // step 5: 根据device_module中的fn_name在cuda_module_中取出对应的fn_kernel(CUfunction对象),并在RuntimeSymbols中注册(fn_name__ptr__、void*)
  RuntimeSymbols symbols;
  for (auto& fn : device_module.functions()) {
    std::string kernel_fn_name = fn->name;
    auto fn_kernel             = cuda_module_->GetFunction(0, kernel_fn_name);
    symbols.RegisterVar(kernel_fn_name + "_ptr_", reinterpret_cast<void*>(fn_kernel));
  // step 6: 以RuntimeSymbols构造ExecutionEngine,负责将CUDA Kernel 与 Host 端API进行链接
  engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols));


二、TVM 框架

/*!
 * \brief Generate CUDA source for every PrimFunc in \p mod, compile it, and wrap the
 *        result in a CUDA runtime module.
 *
 * Compilation uses the registered "tvm_callback_cuda_compile" function when one exists,
 * otherwise NVRTCCompile().
 *
 * \param mod    Module whose functions must all be PrimFuncs with the
 *               kDeviceKernelLaunch calling convention.
 * \param target Target passed to the TargetEnterScope / TargetExitScope hooks around
 *               compilation.
 * \return The module created by CUDAModuleCreate().
 */
runtime::Module BuildCUDA(IRModule mod, Target target) {
  using tvm::runtime::Registry;
  bool output_ssa = false;
  CodeGenCUDA cg;
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
        << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
    cg.AddFunction(f);
  }

  std::string code = cg.Finish();

  // Optional user hook that rewrites the generated source before compilation.
  if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
    code = (*f)(code).operator std::string();
  }
  std::string fmt = "ptx";
  std::string ptx;
  const auto* f_enter = Registry::Get("target.TargetEnterScope");
  (*f_enter)(target);
  // <---- a function registered from the Python side can compile straight to cubin here
  if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) {
    ptx = (*f)(code).operator std::string();
    // Dirty matching to check PTX vs cubin.
    // TODO(tqchen) more reliable checks
    if (ptx[0] != '/') fmt = "cubin";
  } else {
    ptx = NVRTCCompile(code, cg.need_include_path());
  }
  const auto* f_exit = Registry::Get("target.TargetExitScope");
  (*f_exit)(target);
  return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}


其中 NVRTCCompile 的实现:

std::string NVRTCCompile(const std::string& code, bool include_path = false) {
  std::vector<std::string> compile_params;
  std::vector<const char*> param_cstrings{};
  nvrtcProgram prog;
  std::string cc = "30";
  int major, minor;
  cudaError_t e1 = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, 0);
  cudaError_t e2 = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, 0);

  if (e1 == cudaSuccess && e2 == cudaSuccess) {
    cc = std::to_string(major) + std::to_string(minor);
  } else {
    LOG(WARNING) << "cannot detect compute capability from your device, "
                 << "fall back to compute_30.";

  compile_params.push_back("-arch=compute_" + cc);

  if (include_path) {
    std::string include_option = "--include-path=" + FindCUDAIncludePath();


  for (const auto& string : compile_params) {
  NVRTC_CALL(nvrtcCreateProgram(&prog, code.c_str(), nullptr, 0, nullptr, nullptr));
  nvrtcResult compile_res = nvrtcCompileProgram(prog, param_cstrings.size(),;

  size_t log_size;
  NVRTC_CALL(nvrtcGetProgramLogSize(prog, &log_size));
  std::string log;
  NVRTC_CALL(nvrtcGetProgramLog(prog, &log[0]));
  ICHECK_EQ(compile_res, NVRTC_SUCCESS) << log;
  size_t ptx_size;
  NVRTC_CALL(nvrtcGetPTXSize(prog, &ptx_size));

  std::string ptx;
  NVRTC_CALL(nvrtcGetPTX(prog, &ptx[0]));

  return ptx;

可以在 python 端注册一个自定义函数,直接编译为 cubin 文件,用法见TVM中的单测:

def tvm_callback_cuda_compile(code):
    """Compile ``code`` with nvcc into fatbin format (for better optimization) and return it."""
    return compile_cuda(code, target_format="fatbin")

def compile_cuda(code, target_format="ptx", arch=None, options=None, path_target=None):
    """Compile cuda code with NVCC from env.

    Parameters
    ----------
    code : str
        The cuda code.

    target_format : str
        The target format of the nvcc compiler: "cubin", "ptx" or "fatbin".

    arch : str or list of str, optional
        The cuda architecture, either a string like "sm_xx" or a list of
        nvcc arguments. When None it is derived from the current target.

    options : str or list of str, optional
        Additional options appended to the nvcc command line.

    path_target : str, optional
        Output file. When not given, a file in a temporary directory is used.

    Returns
    -------
    data : bytearray
        The content of the file nvcc produced.

    Raises
    ------
    ValueError
        If ``target_format`` is unsupported, or ``options`` is neither a str
        nor a list.
    RuntimeError
        If nvcc exits with a non-zero status or produces an empty file.
    """
    if arch is None:
        # If None, then it will use `tvm.target.Target.current().arch`.
        # Target arch could be a str like "sm_xx", or a list, such as
        # [
        #   "-gencode", "arch=compute_52,code=sm_52",
        #   "-gencode", "arch=compute_70,code=sm_70"
        # ]
        # NOTE(review): `get_target_compute_version` and `Target` come from the
        # TVM module this excerpt was taken from (tvm/contrib/nvcc.py).
        compute_version = "".join(
            get_target_compute_version(Target.current(allow_none=True)).split(".")
        )
        arch = ["-gencode", f"arch=compute_{compute_version},code=sm_{compute_version}"]

    temp = utils.tempdir()
    if target_format not in ["cubin", "ptx", "fatbin"]:
        raise ValueError("target_format must be in cubin, ptx, fatbin")
    temp_code = temp.relpath("my_kernel.cu")
    temp_target = temp.relpath("my_kernel.%s" % target_format)

    # nvcc compiles from a file, so the source is written out first.
    with open(temp_code, "w") as out_file:
        out_file.write(code)

    file_target = path_target if path_target else temp_target
    cmd = ["nvcc"]
    cmd += ["--%s" % target_format, "-O3"]
    if isinstance(arch, list):
        cmd += arch
    elif isinstance(arch, str):
        cmd += ["-arch", arch]

    if options:
        if isinstance(options, str):
            cmd += [options]
        elif isinstance(options, list):
            cmd += options
        else:
            raise ValueError("options must be str or list of str")

    cmd += ["-o", file_target]
    cmd += [temp_code]

    # NOTE: ccbin option can be used to tell nvcc where to find the c++ compiler
    # just in case it is not in the path. On Windows it is not in the path by default.
    # However, we cannot use TVM_CXX_COMPILER_PATH because the runtime env.
    # Because it is hard to do runtime compiler detection, we require nvcc is configured
    # correctly by default.
    # if cxx_compiler_path != "":
    #    cmd += ["-ccbin", cxx_compiler_path]

    # stderr is merged into stdout so the error message carries all nvcc output.
    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)

    (out, _) = proc.communicate()

    if proc.returncode != 0:
        msg = code
        msg += "\nCompilation error:\n"
        msg += py_str(out)
        raise RuntimeError(msg)

    with open(file_target, "rb") as f:
        data = bytearray(f.read())
        if not data:
            raise RuntimeError("Compilation error: empty result is generated")
        return data


