浅析 Triton 执行流程

本博客原文地址:https://www.cnblogs.com/BobHuang/p/18324040,原文阅读体验更佳。

上一篇文章介绍了MLIR及其Pass的添加,受到了很多朋友的点赞支持,非常感谢。

Triton作者关于其设计的论文发表在MAPL2019,当前项目下首次commit为6d7cf35,为2021.07 push,最初的源码在isaac

笔者关注到Triton是2023年5月,当时我在将我们自有的gpgpu芯片接入Pytorch,希望完成对大模型的训练。Triton是Pytorch2.0更新中Inductor的重要部分,Triton可以作为Python dsl快速完成算子,抛弃掉cuda,并且能较快支持一个新硬件。关于Pytorch接入新硬件可以参考自有AI芯片接入AI框架Pytorch的方案,我有点想弄个triton的OpenCL后端玩具。

另外如果没有Nvidia显卡的实体机可以看看腾讯云HAI以租代买,按时计费,关机不计费,我这边测试1.2元/小时的基础型使用的卡是Tesla T4, Turing架构,有tensor core, 3.6元/小时的进阶型使用的卡是Tesla V100-SXM2

一、什么是Triton

Triton是一个关心tile(分块)的编译器和语言设施,最初其对C做了扩展,现在是Pythonic且使用MLIR作为基础设施,即用户写kernel直接用Python就可以。灵活性非常高,用户不太了解cuda的情况下也可以写出性能还不错的kernel,可以有效降低模型开发时算子的编写时间,大力出奇迹。Triton我认为有以下三个非常重要的特性。

1、粒度为Block(tile)

在Triton代码的编写中我们更关心一个Block,也就是SM(Streaming Multiprocessor,计算单元),而不是cuda中对Thread、Block、Grid严格合理组织,用户甚至感知不到shared memory。若直接关注Tensor这个更高层次粒度的粒度,则不能有效得利用gpgpu的架构特性。这个层级的选择是非常合理的,编译器能很好得解决这层trade-off,能够减轻用户写kernel的负担,也能保证一定的性能和灵活性,可以无限接近cuda手写算子。
关于cuda和Triton的对比,单反相机和智能手机的比较是广泛流传的,见下图,引用地址

Pytorch Conference 2023 Thomas Raoux 的 Talk: Triton Compiler 也做了对比,3:06开始,pdf地址

2、优化Pass

Triton潜在表述中其还是编译器,借助一系列的优化Pass,他可以达到和手写算子接近的水平。

杨军大佬提到有些Pass很重要 用于辅助向量化访存的coalescing、用于缓解计算访存差异的pipeline/prefetch,用于避免shared memory访问bank-conflict的swizzling引用地址

杨军大佬还提到了 Triton的Layout抽象目前主要包括Blocked Encoding、Shared Encoding、MMA Encoding、DotOperand Encoding、Slice Encoding这几类定义。 内存抽象对于优化是重要的,同样重要的还有对新特性的支持。

更新了比较重要的 LinearLayout,还把LLs和cute做了对比,二球大佬的文章

智源对重要的pass也做了简单介绍,清华智源 Triton中国生态Meetup(第一期)暨首次Triton中国社区活动 Slides pages 130

3、Pytorch集成

[Pytorch 2.40]里支持的硬件依旧有限,华为的npu依旧要自己去维护Pytorch库torch_npucodegen/gen_backend_stubs.py:210true_backend = 'PrivateUse1' if yaml_backend == 'NPU' else yaml_backend 表明现在官方把PrivateUse1完善了,不需要借用XLA了,但每接入一个新硬件都要维护这样一份代码,还要维护算子并调优。更多厂商的接入见自有AI芯片接入AI框架Pytorch的方案

Pytorch 2.0 引入了Inductor,我们也可以以编译器的方式去接入到Pytorch,Triton作为一个后端出现了,当然Triton写的算子也可以以默认的eger模式接入。那我们为什么不直接用triton算子呢,如果是一个异构的gpgpu模型,写kernel关注的依旧是tile,这样最起码我们可以把项目跑起来去扩充生态,有着还不错的性能,也为后期的优化提供了接口,甚至不同的硬件可以是相同的Triton代码。那么能不能用Triton去挽救下即将死去的OpenCL生态呢,个人感觉还是蛮有意思的,当然能接到SPIR-V也可以。如果不是异构模型,triton-cpu cpu后端也可以关注下。

当然蓝色大佬在Baby Triton 提出可以直接生成出CUDA代码方便算法工程师调优也值得关注,有这样一个插件确实挺方便的。那么问题来了,Triton要不要去掉Pytorch依赖成为更完整的语言。

二、Triton的执行流程

在Triton中我们编写的是Python代码,最后得到的是ptx(ptx后用的就是nv工具了)。大致流程为对Python的AST(抽象语法树)遍历获得描述的Triton-IR,再经过Triton编译器将其降级为LLVM-IR,再通过LLVM后端输出ptx。我们将按这个流程以向量加为例将其分析下。

1、编译Triton的debug版本

我们需要编译Triton的debug版本方便我们调试,可以直接按照README中去运行。我这里用的Triton的commit为adee21f Aug 2, 2024 的版本,笔者运行的命令如下所示,可以参考。

# 克隆 triton项目
git clone https://github.com/triton-lang/triton
cd triton
# 可以checkout到教程所用版本方便后面对代码和行号
# git checkout adee21f52a00329e14b1bf7dc32ccbe2f5b03064
# 查看其依赖的LLVM版本
cat cmake/llvm-hash.txt
# 依赖LLVM版本号为4713bd4ccc0c0d568f92916e7851d993291742c0

cd ~
# 克隆LLVM
git clone https://github.com/llvm/llvm-project
# 切换到 cat 出的依赖的LLVM版本
git checkout 4713bd4ccc0c0d568f92916e7851d993291742c0
mkdir build;cd build
# 设置CMake 参数
cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm" -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU"
# 编译LLVM,需要蛮久的
ninja

# 进入Triton文件夹
cd <triton install>
export LLVM_BUILD_DIR=~/llvm-project/build
# 设置debug模式
export DEBUG=1
# 调用 python/setup.py 安装
LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \
  LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \
  LLVM_SYSPATH=$LLVM_BUILD_DIR \
  pip install -e python

2、Triton 代码编写

Triton的源码需要使用@triton.jit装饰器,用来标记这是一段Triton kernel函数,使其能够被JIT(即时编译)编译并在GPU上运行。我们需要将申请好的输入输出tensor指针传递下去,还有元素个数以及每个程序应该处理的元素数量,在Triton中,我们会把block抽象为程序,即1个block将作为1个程序。

下面写起来就像cuda,我们需要确定当前block需要进行的计算,由于我们这个是一维的,我们求出程序ID即pid,直接乘上BLOCK_SIZE就可以得到我们这个程序需要开始算的数据的开始block_start,然后对[0, BLOCK_SIZE)依次操作就可以了,可以表示为列表 block_start + tl.arange(0, BLOCK_SIZE)offsets
向量加也很简单,使用tl.load加上访问的地址,分别是x_ptr + offsetsy_ptr + offsets,最后两者相加,然后再写回output_ptr + offsets即可。

为了程序的健壮性,如果向量的长度不是BLOCK_SIZE的倍数,我们可以加上mask,也就是判断offsetsn_elements的大小关系,在存取的时候都加上mask就可以避免越界了。

@triton.jit
def add_kernel(x_ptr,  # 指向第一个输入向量的指针
               y_ptr,  # 指向第二个输入向量的指针
               output_ptr,  # 指向输出向量的指针
               n_elements,  # 向量的长度
               BLOCK_SIZE: tl.constexpr,  # 每个程序应处理的元素数量,triton将block抽象为程序
               # 注意: `constexpr` 可以用作形状值.
               ):
    # 有多个“程序”在处理不同的数据。我们在这里确定我们是哪个程序:
    pid = tl.program_id(axis=0)  # 我们以1D网格启动 所以 axis 是 0.
    # 该程序将处理从初始数据偏移的输入。
    # 例如,如果你有一个长度为256且块大小为64的向量,程序 将分别访问元素 [0:64), [64:128), [128:192), [192:256)。
    block_start = pid * BLOCK_SIZE
    # 注意,offsets 是一个指针列表
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # 创建一个掩码以防止内存越界访问
    mask = offsets < n_elements
    # 从 DRAM 加载 x 和 y,mask用来解决输入不是块大小的倍数而多余的元素
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # 将 x + y 写回 DRAM
    tl.store(output_ptr + offsets, output, mask=mask)

这里引用下robindu大佬对于triton的书写步骤的总结

1)分析并行性并拆分,也就是定义好grid,并明确每个program要完成的运算范围;2)根据范围计算index偏移,并将其转换为一维指针的偏移形式,然后将数据从DRAM中load;3)使用加载过的数据进行运算,如果运算范围较大,需要使用循环逐段完成。 引用地址

那如果是cuda源码呢,其实是更简洁的,但是当性能调优的时候需要考虑得更多。由于借助了Pytorch再加上自身实现,可以隐去部分runtime细节,就像在调用一个函数,Pytorch是在aten里帮你做了。

    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)

3、AST 获取并转换为ttir

跟Triton语言相关的代码在python/triton/language文件夹。
通过pdb可以看到调用流程,大概如下所示。
先在JIT中python/triton/runtime/jit.py:648kernel = self.compile(调用编译器,再在编译器中python/triton/compiler/compiler.py:113return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)调用ast去生成ttir,即为codeGen 的遍历,从python/triton/compiler/code_generator.py:1180ret = super().visit(node)跳进了Python AST,再visitor回来。
如果是加法,那是BinOp,将通过python/triton/compiler/code_generator.py:520def visit_BinOp(self, node)函数到我们python/triton/language/core.py来。通过wrapper包装器将其对应到不同的builtin。加法是__add__,调用semantic(语义分析器)来生成代码,具体的builder全在python/triton/language/semantic.py下。比如我们代码中的output = x + y就是一个float+float,对应为python/triton/language/semantic.py:148return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)。当然create_fadd对应的是c++代码python/src/ir.cc:989return self.create<arith::AddFOp>(lhs, rhs);

经过ast得到的最初的ttir如下所示。

module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0)) attributes {noinline = false} {
    %0 = tt.get_program_id x : i32 loc(#loc1)
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc2)
    %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc2)
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc3)
    %3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc4)
    %4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc4)
    %5 = tt.splat %arg3 : i32 -> tensor<1024xi32> loc(#loc5)
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc5)
    %7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc6)
    %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc6)
    %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc7)
    %10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc8)
    %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc8)
    %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc9)
    %13 = arith.addf %9, %12 : tensor<1024xf32> loc(#loc10)
    %14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc11)
    %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc11)
    tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc12)
    tt.return loc(#loc13)
  } loc(#loc)
} loc(#loc)

接下来其还会在 python/triton/compiler/compiler.py:287 next_module = compile_ir(module, metadata)去调用make_ttir去运行一些优化Pass。third_party/nvidia/backend/compiler.py:162 内优化Pass集合如下所示

    def make_ttir(mod, metadata, opt):
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        passes.ttir.add_combine(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_licm(pm)
        passes.common.add_symbol_dce(pm)
        pm.run(mod)
        return mod

向量加由于过于简单并没有发生变化,不然可能会有包括函数内联、合并操作、公共子表达式消除、死代码消除等作用。

4、根据runtime和平台信息生成ttgir

ttgirtriton gpu ir,同样是通过Pass集合去得到的。
third_party/nvidia/backend/compiler.py:177 内优化Pass集合如下所示

    def make_ttgir(mod, metadata, opt, capability):
        cluster_info = nvidia.ClusterInfo()
        if opt.cluster_dims is not None:
            cluster_info.clusterDimX = opt.cluster_dims[0]
            cluster_info.clusterDimY = opt.cluster_dims[1]
            cluster_info.clusterDimZ = opt.cluster_dims[2]
        # Set up Diagnostic
        if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
            srcMgr = llvm.source_mgr()
            diag = ir.source_mgr_diag(srcMgr, mod.context)
            mod.context.printOpOnDiagnostic(True)
        # TTIR -> TTGIR
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
        # optimize TTGIR
        passes.ttgpuir.add_coalesce(pm)
        if capability // 10 >= 8:
            passes.ttgpuir.add_f32_dot_tc(pm)
        # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
        nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_thread_locality(pm)
        passes.ttgpuir.add_accelerate_matmul(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        passes.common.add_cse(pm)
        if capability // 10 >= 8:
            passes.ttgpuir.add_combine_tensor_select_and_if(pm)
            passes.ttgpuir.add_pipeline(pm, opt.num_stages)
        passes.ttgpuir.add_prefetch(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_reduce_data_duplication(pm)
        passes.ttgpuir.add_reorder_instructions(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        if capability // 10 >= 9:
            nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
            nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
        passes.common.add_canonicalizer(pm)
        pm.run(mod)
        metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
        return mod

这里我们看到不同的capability会有不同的Pass,capability // 10 >= 8Ampere及以上架构才有,像我实验所用的Tesla T47.5,就不会进这类Pass了,具体计算能力可以看nvidia文档

passes.ttgpuir.add_f32_dot_tc(pm) 这个添加的Pass在lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp。可以通过add_f32_dot_tcpython/src/passes.cc:57查到Pass的create函数为createTritonGPUF32DotTC,那么我们的Pass类就是TritonGPUF32DotTC,再搜索TritonGPUF32DotTC就可以找到这个Pass文件了。

那这个Pass是做什么的,注释里有如下描述,他是用来提高3xTF32点积计算精度的,将 FP32 数字分解为高位和低位部分,进行三次 TF32 点积计算来提高精度。那么为什么架构低一些的Tesla T4没有呢,因为没有cvt.rna.tf32.f32这条ptx汇编,这是Ampere架构及之后提供的 将FP32转换为TF32 的指令。具体转换方式见此Pass

// nb. We call the trick TF32x3 as C++ disallows varaibles starting with numbers
// Implement 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385
// For a, b f32
// dot(a, b, inputPrecision="tf32x3") ->
//  let aBig = f32ToTF32(a), aSmall = a - aBig;
//  let bBig = f32ToTF32(b), bSmall = b - bBig;
//  dot(aSmall, bBig, inputPrecision="tf32") +
//  dot(aBig, bSmall, inputPrecision="tf32") +
//  dot(aBig, bBig, inputPrecision="tf32")

在最新的H100中,还提供了add_fence_insertionadd_tma_loweringPass,分别是用于异步共享内存操作之间插入fence和加速Tensor Core计算中的内存访问,这个其实4月底才更新进来。杨军大佬说的Hopper 大作业 commit

经过这些Pass集合可以得到我个计算平台的gpu ir,如下所示,向量加比较简单,只是被打上了#blocked的Attr标记。

#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#loc = loc("python/tutorials/01-vector-add.py":28:0)
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:75", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0)) attributes {noinline = false} {
    %c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> loc(#loc4)
    %3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> loc(#loc5)
    %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> loc(#loc5)
    %5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> loc(#loc6)
    %6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> loc(#loc6)
    %7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc7)
    %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> loc(#loc7)
    %9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc8)
    %10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc9)
    %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> loc(#loc9)
    %12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc10)
    %13 = arith.addf %9, %12 : tensor<1024xf32, #blocked> loc(#loc11)
    %14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc12)
    %15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> loc(#loc12)
    tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc13)
    tt.return loc(#loc14)
  } loc(#loc)
} loc(#loc)

5、lower到LLVMIR

third_party/nvidia/backend/compiler.py:223 展示了降级到LLVM IR 需要的操作,关于自身op的特殊转换在third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp,这个Pass调用了一些MLIR的方法而且是Pattern,虽然代码行数不多但是比较复杂。关于内嵌汇编处理在third_party/nvidia/lib/NVGPUToLLVM/TritonGPUToLLVM.cpp

    def make_llir(src, metadata, options, capability):
        # warp-specialization mutates num_warps
        num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta")
        if num_warp_groups is not None:
            metadata["num_warps"] *= num_warp_groups
        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm)
        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
        passes.convert.add_scf_to_cf(pm)
        passes.convert.add_index_to_llvmir(pm)
        passes.ttgpuir.add_allocate_shared_memory(pm)
        nvidia.passes.ttgpuir.add_to_llvmir(pm, capability)
        nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
        passes.convert.add_arith_to_llvmir(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
            passes.llvmir.add_di_scope(pm)
        pm.run(mod)
        # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
        llvm.init_targets()
        context = llvm.context()

        llvm_mod = llvm.to_module(mod, context)
        proc = 'sm_90a' if capability == 90 else f'sm_{capability}'
        features = get_features(options)
        triple = 'nvptx64-nvidia-cuda'
        llvm.attach_datalayout(llvm_mod, triple, proc, features)
        nvidia.set_nvvm_reflect_ftz(llvm_mod)

        # Set maxnreg on all kernels, if it was provided.
        if options.maxnreg is not None:
            for k in llvm_mod.get_functions():
                if not k.is_declaration() and k.is_external_linkage():
                    k.set_nvvm_maxnreg(options.maxnreg)

        if options.extern_libs:
            paths = [path for (name, path) in options.extern_libs]
            llvm.link_extern_libs(llvm_mod, paths)

        llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)

        # Get some metadata
        metadata["shared"] = src.get_int_attr("triton_gpu.shared")
        ret = str(llvm_mod)
        del llvm_mod
        del context
        return ret

经过一系列Pass降级就可以到LLVM IR

6、lower到ptx

以上均通过PythonPassManager去管理代码,从这里开始就不是了。third_party/nvidia/backend/compiler.py:277 make_ptx如下所示,因为这部分代码走LLVM后端,不需要把他们引入到项目。

    def make_ptx(src, metadata, opt, capability):
        ptx_version = opt.ptx_version
        if ptx_version is None:
            _, cuda_version = _path_to_binary("ptxas")
            ptx_version = ptx_get_version(cuda_version)

        triple = 'nvptx64-nvidia-cuda'
        proc = 'sm_90a' if capability == 90 else f'sm_{capability}'
        features = get_features(opt)
        ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False)
        # Find kernel names (there should only be one)
        names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
        assert len(names) == 1
        metadata["name"] = names[0]
        # post-process
        ptx_version = f'{ptx_version//10}.{ptx_version%10}'
        ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
        # Remove the debug flag that prevents ptxas from optimizing the code
        ret = re.sub(r",\s*debug|debug,\s*", "", ret)
        if os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1":
            print("// -----// NVPTX Dump //----- //")
            print(ret)
        return ret

调用 llvm.translate_to_asm做了lower,下面就是一些正则的替换。translate_to_asm实际为 python/src/llvm.cc:42translateLLVMIRToASM函数。

7、生成cubin

最后生成二进制文件就大功告成了,用ptxas即可。代码如下所示

    def make_cubin(src, metadata, opt, capability):
        ptxas, _ = _path_to_binary("ptxas")
        with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
            tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
            fsrc.write(src)
            fsrc.flush()
            fbin = fsrc.name + '.o'

            line_info = [] if os.environ.get('TRITON_DISABLE_LINE_INFO') else ['-lineinfo']
            fmad = [] if opt.enable_fp_fusion else ['--fmad=false']
            suffix = 'a' if capability == 90 else ''
            opt_level = ['--opt-level', '0'] if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1" else []
            ptxas_cmd = [
                ptxas, *line_info, *fmad, '-v', *opt_level, f'--gpu-name=sm_{capability}{suffix}', fsrc.name, '-o', fbin
            ]
            try:
                subprocess.run(ptxas_cmd, check=True, close_fds=False, stderr=flog)
                if os.path.exists(fsrc.name):
                    os.remove(fsrc.name)
                if os.path.exists(flog.name):
                    os.remove(flog.name)
            except subprocess.CalledProcessError as e:
                with open(flog.name) as log_file:
                    log = log_file.read()
                if os.path.exists(flog.name):
                    os.remove(flog.name)

                if e.returncode == 255:
                    error = 'Internal Triton PTX codegen error'
                elif e.returncode == 128 + signal.SIGSEGV:
                    error = '`ptxas` raised SIGSEGV'
                else:
                    error = f'`ptxas` failed with error code {e.returncode}'

                raise RuntimeError(f'{error}\n'
                                   f'`ptxas` stderr:\n{log}\n'
                                   f'Repro command: {ptxas_cmd}\n')

            with open(fbin, 'rb') as f:
                cubin = f.read()
            if os.path.exists(fbin):
                os.remove(fbin)
        return cubin

生成通过subprocess.run(ptxas_cmd, check=True, close_fds=False, stderr=flog)转成二进制文件,编译的活终于干完了。

8、配合runtime运行

通过runtime跑起来也很简单,只需要把metadata元数据准备好再run就好了,python/triton/runtime/jit.py:676 launch kernel代码如下所示

            launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals)
            kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
                       self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals)

会调用到python/triton/compiler/compiler.py的runner,根据launch_enter_hooklaunch_exit_hook这2个钩子来控制,最终调用的是third_party/nvidia/backend/driver.pycuLaunchKernelExHandle

三、Triton运行示列

Triton自己提供了一些示例,flash-attention、xFormers、lightllm均对Triton做了支持。

1、Triton tutorials

上文的向量加就取于Triton的python/tutorials/01-vector-add.py,他还提供了包括fused-attentionfused-softmaxgrouped-gemm在内的示例。

部分示例使用了@triton.autotune,即自动优化参数针对不同的硬件选择最优的配置。其内的benchmark也和torch内置的以及cuBLAS等做了比较,基本达到了各有胜负的水平。此外triton还在积极得将arithmeticmathematics融入进来,tl.math里有蛮多API了。

2、LightLLM

LightLLM是商汤的一个推理和服务框架,现在已经支持了不少模型。不同的模型会有自己的triton_kernel,比如DeepSeek-V2llamamistral等等。

3、FlagGems

FlagGems是清华智源推出的通用算子库,力求做到高性能广覆盖轻量级,这里轻量级是最重要的,他希望能够多款硬件支持,涵盖NVIDIA与国产芯片。
目前有以下厂商在适配FlagGems,图片来源 pages17

4、flash-attention

flash-attention根据Tritonpython/tutorials/06-fused-attention.py 也做了一个支持attention bias的实现,flash_attn/flash_attn_triton.py

诸如mpt等模型对其有使用,使用也很简单,from flash_attn.flash_attn_triton import flash_attn_func import 进来直接使用 flash_attn_func 即可,详见源码attention.py:L172,不过给我提示了If not using a Prefix Language Model, we recommend setting "attn_impl" to "flash" instead of "triton",可能是其在前缀语言模型效果才好。

5、xFormers

xFormersfacebookresearch用于高效和灵活的 Transformer 模型的库,旨在提高深度学习框架的性能。
他们自己的dinov2一个计算机视觉自监督模型用了xFormersmemory_efficient_attention

6、linkedin/Liger-Kernel

高性能的用于LLM 训练的Triton kernel

7、inccat/Awesome-Triton-Kernels

是一个汇总

参考

1.杨军 谈谈对OpenAI Triton的一些理解

2.杨军 OpenAI Triton Conference参会随感兼谈Triton Hopper

3.官方文档

4.OpenAI Introducing Triton: Open-source GPU programming for neural networks

5.蓝色 一起实现一个Baby Triton

6.TanyoKwok 郭天佑​ 聊聊 PyTorch 2.0(Inductor)

7.王钧 Triton学习笔记

8.科研败犬丶 OpenAI/Triton MLIR 第一章: Triton DSL

9.清华智源 FlagGems

10.清华智源 基于Triton的大模型通用算子库FlagGems技术分享 Slides

11.清华智源 Triton中国生态Meetup(第一期)暨首次Triton中国社区活动 Slides

12.微软 triton-shared

13.寒武纪 triton-linalg

14.robindu lightllm代码解读番外篇——triton kernel撰写

15.二球 CuTe and Triton (施工中)

posted @ 2024-07-25 20:01  暴力都不会的蒟蒻  阅读(2194)  评论(0编辑  收藏  举报