浅析 Triton 执行流程
本博客原文地址:https://www.cnblogs.com/BobHuang/p/18324040,原文阅读体验更佳。
上一篇文章介绍了MLIR及其Pass的添加,受到了很多朋友的点赞支持,非常感谢。
Triton作者关于其设计的论文发表在MAPL2019,当前项目下首次commit为6d7cf35,为2021.07 push,最初的源码在isaac。
笔者关注到Triton是2023年5月,当时我在将我们自有的gpgpu芯片接入Pytorch,希望完成对大模型的训练。Triton是Pytorch2.0更新中Inductor的重要部分,Triton可以作为Python dsl快速完成算子,抛弃掉cuda,并且能较快支持一个新硬件。关于Pytorch接入新硬件可以参考自有AI芯片接入AI框架Pytorch的方案,我有点想弄个triton的OpenCL后端玩具。
另外如果没有Nvidia显卡的实体机可以看看腾讯云HAI以租代买,按时计费,关机不计费,我这边测试1.2元/小时的基础型使用的卡是Tesla T4
, Turing架构,有tensor core, 3.6元/小时的进阶型使用的卡是Tesla V100-SXM2
。
一、什么是Triton
Triton是一个关心tile(分块)的编译器和语言设施,最初其对C做了扩展,现在是Pythonic且使用MLIR作为基础设施,即用户写kernel直接用Python就可以。灵活性非常高,用户不太了解cuda的情况下也可以写出性能还不错的kernel,可以有效降低模型开发时算子的编写时间,大力出奇迹。Triton我认为有以下三个非常重要的特性。
1、粒度为Block(tile)
在Triton代码的编写中我们更关心一个Block,也就是SM(Streaming Multiprocessor,计算单元),而不是cuda中对Thread、Block、Grid严格合理组织,用户甚至感知不到shared memory。若直接关注Tensor
这个更高层次粒度的粒度,则不能有效得利用gpgpu的架构特性。这个层级的选择是非常合理的,编译器能很好得解决这层trade-off,能够减轻用户写kernel的负担,也能保证一定的性能和灵活性,可以无限接近cuda手写算子。
关于cuda和Triton的对比,单反相机和智能手机的比较是广泛流传的,见下图,引用地址。
Pytorch Conference 2023 Thomas Raoux 的 Talk: Triton Compiler 也做了对比,3:06开始,pdf地址
2、优化Pass
Triton潜在表述中其还是编译器,借助一系列的优化Pass,他可以达到和手写算子接近的水平。
杨军大佬提到有些Pass很重要 用于辅助向量化访存的coalescing、用于缓解计算访存差异的pipeline/prefetch,用于避免shared memory访问bank-conflict的swizzling。 引用地址
杨军大佬还提到了 Triton的Layout抽象目前主要包括Blocked Encoding、Shared Encoding、MMA Encoding、DotOperand Encoding、Slice Encoding这几类定义。 内存抽象对于优化是重要的,同样重要的还有对新特性的支持。
智源对重要的pass也做了简单介绍,清华智源 Triton中国生态Meetup(第一期)暨首次Triton中国社区活动 Slides pages 130
3、Pytorch集成
[Pytorch 2.40]里支持的硬件依旧有限,华为的npu依旧要自己去维护Pytorch库torch_npu,codegen/gen_backend_stubs.py:210true_backend = 'PrivateUse1' if yaml_backend == 'NPU' else yaml_backend
表明现在官方把PrivateUse1
完善了,不需要借用XLA
了,但每接入一个新硬件都要维护这样一份代码,还要维护算子并调优。更多厂商的接入见自有AI芯片接入AI框架Pytorch的方案。
Pytorch 2.0 引入了Inductor,我们也可以以编译器的方式去接入到Pytorch,Triton作为一个后端出现了,当然Triton写的算子也可以以默认的eger
模式接入。那我们为什么不直接用triton算子呢,如果是一个异构的gpgpu模型,写kernel关注的依旧是tile,这样最起码我们可以把项目跑起来去扩充生态,有着还不错的性能,也为后期的优化提供了接口,甚至不同的硬件可以是相同的Triton代码。那么能不能用Triton去挽救下即将死去的OpenCL
生态呢,个人感觉还是蛮有意思的,当然能接到SPIR-V
也可以。如果不是异构模型,triton-cpu cpu后端也可以关注下。
当然蓝色大佬在Baby Triton 提出可以直接生成出CUDA代码方便算法工程师调优也值得关注,有这样一个插件确实挺方便的。那么问题来了,Triton要不要去掉Pytorch依赖成为更完整的语言。
二、Triton的执行流程
在Triton中我们编写的是Python代码,最后得到的是ptx。大致流程为对Python的AST(抽象语法树)遍历获得描述的Triton-IR,再经过Triton
编译器将其降级为LLVM-IR
,再通过LLVM
后端输出ptx
。我们将按这个流程以向量加为例将其分析下。
1、编译Triton的debug版本
我们需要编译Triton的debug版本方便我们调试,可以直接按照README中去运行。我这里用的Triton的commit为adee21f Aug 2, 2024 的版本,笔者运行的命令如下所示,可以参考。
# 克隆 triton项目
git clone https://github.com/triton-lang/triton
cd triton
# 可以checkout到教程所用版本方便后面对代码和行号
# git checkout adee21f52a00329e14b1bf7dc32ccbe2f5b03064
# 查看其依赖的LLVM版本
cat cmake/llvm-hash.txt
# 依赖LLVM版本号为4713bd4ccc0c0d568f92916e7851d993291742c0
cd ~
# 克隆LLVM
git clone https://github.com/llvm/llvm-project
# 切换到 cat 出的依赖的LLVM版本
git checkout 4713bd4ccc0c0d568f92916e7851d993291742c0
mkdir build;cd build
# 设置CMake 参数
cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm" -DLLVM_TARGETS_TO_BUILD="host;NVPTX;AMDGPU"
# 编译LLVM,需要蛮久的
ninja
# 进入Triton文件夹
cd <triton install>
export LLVM_BUILD_DIR=~/llvm-project/build
# 设置debug模式
export DEBUG=1
# 调用 python/setup.py 安装
LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \
LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \
LLVM_SYSPATH=$LLVM_BUILD_DIR \
pip install -e python
2、Triton 代码编写
Triton的源码需要使用@triton.jit
装饰器,用来标记这是一段Triton kernel函数,使其能够被JIT(即时编译)编译并在GPU上运行。我们需要将申请好的输入输出tensor指针传递下去,还有元素个数以及每个程序应该处理的元素数量,在Triton中,我们会把block抽象为程序,即1个block将作为1个程序。
下面写起来就像cuda,我们需要确定当前block需要进行的计算,由于我们这个是一维的,我们求出程序ID即pid
,直接乘上BLOCK_SIZE
就可以得到我们这个程序需要开始算的数据的开始block_start
,然后对[0, BLOCK_SIZE)
依次操作就可以了,可以表示为列表 block_start + tl.arange(0, BLOCK_SIZE)
即offsets
。
向量加也很简单,使用tl.load
加上访问的地址,分别是x_ptr + offsets
和y_ptr + offsets
,最后两者相加,然后再写回output_ptr + offsets
即可。
为了程序的健壮性,如果向量的长度不是BLOCK_SIZE的倍数,我们可以加上mask,也就是判断offsets
和n_elements
的大小关系,在存取的时候都加上mask就可以避免越界了。
@triton.jit
def add_kernel(x_ptr, # 指向第一个输入向量的指针
y_ptr, # 指向第二个输入向量的指针
output_ptr, # 指向输出向量的指针
n_elements, # 向量的长度
BLOCK_SIZE: tl.constexpr, # 每个程序应处理的元素数量,triton将block抽象为程序
# 注意: `constexpr` 可以用作形状值.
):
# 有多个“程序”在处理不同的数据。我们在这里确定我们是哪个程序:
pid = tl.program_id(axis=0) # 我们以1D网格启动 所以 axis 是 0.
# 该程序将处理从初始数据偏移的输入。
# 例如,如果你有一个长度为256且块大小为64的向量,程序 将分别访问元素 [0:64), [64:128), [128:192), [192:256)。
block_start = pid * BLOCK_SIZE
# 注意,offsets 是一个指针列表
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# 创建一个掩码以防止内存越界访问
mask = offsets < n_elements
# 从 DRAM 加载 x 和 y,mask用来解决输入不是块大小的倍数而多余的元素
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
# 将 x + y 写回 DRAM
tl.store(output_ptr + offsets, output, mask=mask)
这里引用下robindu
大佬对于triton的书写步骤的总结
1)分析并行性并拆分,也就是定义好grid,并明确每个program要完成的运算范围;2)根据范围计算index偏移,并将其转换为一维指针的偏移形式,然后将数据从DRAM中load;3)使用加载过的数据进行运算,如果运算范围较大,需要使用循环逐段完成。 引用地址
那如果是cuda
源码呢,其实是更简洁的,但是当性能调优的时候需要考虑得更多。由于借助了Pytorch再加上自身实现,可以隐去部分runtime细节,就像在调用一个函数,Pytorch是在aten里帮你做了。
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
3、AST 获取并转换为ttir
跟Triton语言相关的代码在python/triton/language文件夹。
通过pdb可以看到调用流程,大概如下所示。
先在JIT中python/triton/runtime/jit.py:648的kernel = self.compile(
调用编译器,再在编译器中python/triton/compiler/compiler.py:113 的return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
调用ast去生成ttir,即为codeGen 的遍历,从python/triton/compiler/code_generator.py:1180ret = super().visit(node)
跳进了Python AST,再visitor
回来。
如果是加法,那是BinOp
,将通过python/triton/compiler/code_generator.py:520def visit_BinOp(self, node)
函数到我们python/triton/language/core.py
来。通过wrapper
包装器将其对应到不同的builtin
。加法是__add__
,调用semantic
(语义分析器)来生成代码,具体的builder全在python/triton/language/semantic.py
下。比如我们代码中的output = x + y
就是一个float+float,对应为python/triton/language/semantic.py:148return tl.tensor(builder.create_fadd(input.handle, other.handle), input.type)
。当然create_fadd
对应的是c++代码python/src/ir.cc:989的return self.create<arith::AddFOp>(lhs, rhs);
。
经过ast得到的最初的ttir如下所示。
module {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0)) attributes {noinline = false} {
%0 = tt.get_program_id x : i32 loc(#loc1)
%c1024_i32 = arith.constant 1024 : i32 loc(#loc2)
%1 = arith.muli %0, %c1024_i32 : i32 loc(#loc2)
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> loc(#loc3)
%3 = tt.splat %1 : i32 -> tensor<1024xi32> loc(#loc4)
%4 = arith.addi %3, %2 : tensor<1024xi32> loc(#loc4)
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32> loc(#loc5)
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32> loc(#loc5)
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc6)
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc6)
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc7)
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc8)
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc8)
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc9)
%13 = arith.addf %9, %12 : tensor<1024xf32> loc(#loc10)
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>> loc(#loc11)
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32> loc(#loc11)
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>> loc(#loc12)
tt.return loc(#loc13)
} loc(#loc)
} loc(#loc)
接下来其还会在 python/triton/compiler/compiler.py:287 next_module = compile_ir(module, metadata)
去调用make_ttir
去运行一些优化Pass。third_party/nvidia/backend/compiler.py:162 内优化Pass集合如下所示
def make_ttir(mod, metadata, opt):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.common.add_inliner(pm)
passes.ttir.add_rewrite_tensor_pointer(pm)
passes.ttir.add_combine(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_reorder_broadcast(pm)
passes.common.add_cse(pm)
passes.common.add_licm(pm)
passes.common.add_symbol_dce(pm)
pm.run(mod)
return mod
向量加由于过于简单并没有发生变化,不然可能会有包括函数内联、合并操作、公共子表达式消除、死代码消除等作用。
4、根据runtime和平台信息生成ttgir
ttgir
即triton gpu ir
,同样是通过Pass集合去得到的。
third_party/nvidia/backend/compiler.py:177 内优化Pass集合如下所示
def make_ttgir(mod, metadata, opt, capability):
cluster_info = nvidia.ClusterInfo()
if opt.cluster_dims is not None:
cluster_info.clusterDimX = opt.cluster_dims[0]
cluster_info.clusterDimY = opt.cluster_dims[1]
cluster_info.clusterDimZ = opt.cluster_dims[2]
# Set up Diagnostic
if os.environ.get("MLIR_ENABLE_REMARK", "0") == "1":
srcMgr = llvm.source_mgr()
diag = ir.source_mgr_diag(srcMgr, mod.context)
mod.context.printOpOnDiagnostic(True)
# TTIR -> TTGIR
pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
# optimize TTGIR
passes.ttgpuir.add_coalesce(pm)
if capability // 10 >= 8:
passes.ttgpuir.add_f32_dot_tc(pm)
# TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_thread_locality(pm)
passes.ttgpuir.add_accelerate_matmul(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
passes.common.add_cse(pm)
if capability // 10 >= 8:
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.ttgpuir.add_pipeline(pm, opt.num_stages)
passes.ttgpuir.add_prefetch(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_reduce_data_duplication(pm)
passes.ttgpuir.add_reorder_instructions(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if capability // 10 >= 9:
nvidia.passes.ttnvgpuir.add_fence_insertion(pm)
nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
passes.common.add_canonicalizer(pm)
pm.run(mod)
metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
return mod
这里我们看到不同的capability会有不同的Pass,capability // 10 >= 8
即Ampere
及以上架构才有,像我实验所用的Tesla T4
是7.5
,就不会进这类Pass了,具体计算能力可以看nvidia文档。
passes.ttgpuir.add_f32_dot_tc(pm)
这个添加的Pass在lib/Dialect/TritonGPU/Transforms/F32DotTC.cpp。可以通过add_f32_dot_tc
在python/src/passes.cc:57查到Pass的create函数为createTritonGPUF32DotTC
,那么我们的Pass类就是TritonGPUF32DotTC
,再搜索TritonGPUF32DotTC
就可以找到这个Pass文件了。
那这个Pass是做什么的,注释里有如下描述,他是用来提高3xTF32
点积计算精度的,将 FP32 数字分解为高位和低位部分,进行三次 TF32 点积计算来提高精度。那么为什么架构低一些的Tesla T4
没有呢,因为没有cvt.rna.tf32.f32
这条ptx汇编,这是Ampere
架构及之后提供的 将FP32转换为TF32 的指令。具体转换方式见此Pass。
// nb. We call the trick TF32x3 as C++ disallows varaibles starting with numbers
// Implement 3xTF32 trick https://github.com/NVIDIA/cutlass/discussions/385
// For a, b f32
// dot(a, b, inputPrecision="tf32x3") ->
// let aBig = f32ToTF32(a), aSmall = a - aBig;
// let bBig = f32ToTF32(b), bSmall = b - bBig;
// dot(aSmall, bBig, inputPrecision="tf32") +
// dot(aBig, bSmall, inputPrecision="tf32") +
// dot(aBig, bBig, inputPrecision="tf32")
在最新的H100中,还提供了add_fence_insertion
和add_tma_lowering
Pass,分别是用于异步共享内存操作之间插入fence和加速Tensor Core计算中的内存访问,这个其实4月底才更新进来。杨军大佬说的Hopper 大作业 commit
经过这些Pass集合可以得到我个计算平台的gpu ir
,如下所示,向量加比较简单,只是被打上了#blocked
的Attr标记。
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
#loc = loc("python/tutorials/01-vector-add.py":28:0)
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:75", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0), %arg3: i32 {tt.divisibility = 16 : i32} loc("python/tutorials/01-vector-add.py":28:0)) attributes {noinline = false} {
%c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
%0 = tt.get_program_id x : i32 loc(#loc2)
%1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> loc(#loc4)
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> loc(#loc5)
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> loc(#loc5)
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> loc(#loc6)
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> loc(#loc6)
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc7)
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> loc(#loc7)
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc8)
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc9)
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> loc(#loc9)
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc10)
%13 = arith.addf %9, %12 : tensor<1024xf32, #blocked> loc(#loc11)
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc12)
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> loc(#loc12)
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc13)
tt.return loc(#loc14)
} loc(#loc)
} loc(#loc)
5、lower到LLVMIR
third_party/nvidia/backend/compiler.py:223 展示了降级到LLVM IR 需要的操作,关于自身op的特殊转换在third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp,这个Pass调用了一些MLIR的方法而且是Pattern,虽然代码行数不多但是比较复杂。关于内嵌汇编处理在third_party/nvidia/lib/NVGPUToLLVM/TritonGPUToLLVM.cpp。
def make_llir(src, metadata, options, capability):
# warp-specialization mutates num_warps
num_warp_groups = src.get_int_attr("triton_gpu.num-warp-groups-per-cta")
if num_warp_groups is not None:
metadata["num_warps"] *= num_warp_groups
mod = src
# TritonGPU -> LLVM-IR (MLIR)
pm = ir.pass_manager(mod.context)
pm.enable_debug()
nvidia.passes.ttgpuir.add_decompose_unsupported_conversions(pm)
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.convert.add_scf_to_cf(pm)
passes.convert.add_index_to_llvmir(pm)
passes.ttgpuir.add_allocate_shared_memory(pm)
nvidia.passes.ttgpuir.add_to_llvmir(pm, capability)
nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
passes.convert.add_arith_to_llvmir(pm)
passes.common.add_canonicalizer(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
if os.environ.get("TRITON_DISABLE_LINE_INFO", "0") == "0":
passes.llvmir.add_di_scope(pm)
pm.run(mod)
# LLVM-IR (MLIR) -> LLVM-IR (LLVM)
llvm.init_targets()
context = llvm.context()
llvm_mod = llvm.to_module(mod, context)
proc = 'sm_90a' if capability == 90 else f'sm_{capability}'
features = get_features(options)
triple = 'nvptx64-nvidia-cuda'
llvm.attach_datalayout(llvm_mod, triple, proc, features)
nvidia.set_nvvm_reflect_ftz(llvm_mod)
# Set maxnreg on all kernels, if it was provided.
if options.maxnreg is not None:
for k in llvm_mod.get_functions():
if not k.is_declaration() and k.is_external_linkage():
k.set_nvvm_maxnreg(options.maxnreg)
if options.extern_libs:
paths = [path for (name, path) in options.extern_libs]
llvm.link_extern_libs(llvm_mod, paths)
llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
# Get some metadata
metadata["shared"] = src.get_int_attr("triton_gpu.shared")
ret = str(llvm_mod)
del llvm_mod
del context
return ret
经过一系列Pass降级就可以到LLVM IR
了
6、lower到ptx
以上均通过Python
的PassManager
去管理代码,从这里开始就不是了。third_party/nvidia/backend/compiler.py:277 make_ptx
如下所示,因为这部分代码走LLVM后端,不需要把他们引入到项目。
def make_ptx(src, metadata, opt, capability):
ptx_version = opt.ptx_version
if ptx_version is None:
_, cuda_version = _path_to_binary("ptxas")
ptx_version = ptx_get_version(cuda_version)
triple = 'nvptx64-nvidia-cuda'
proc = 'sm_90a' if capability == 90 else f'sm_{capability}'
features = get_features(opt)
ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False)
# Find kernel names (there should only be one)
names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
assert len(names) == 1
metadata["name"] = names[0]
# post-process
ptx_version = f'{ptx_version//10}.{ptx_version%10}'
ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
# Remove the debug flag that prevents ptxas from optimizing the code
ret = re.sub(r",\s*debug|debug,\s*", "", ret)
if os.environ.get("NVPTX_ENABLE_DUMP", "0") == "1":
print("// -----// NVPTX Dump //----- //")
print(ret)
return ret
调用 llvm.translate_to_asm
做了lower,下面就是一些正则的替换。translate_to_asm
实际为 python/src/llvm.cc:42的translateLLVMIRToASM
函数。
7、生成cubin
最后生成二进制文件就大功告成了,用ptxas即可。代码如下所示
def make_cubin(src, metadata, opt, capability):
ptxas, _ = _path_to_binary("ptxas")
with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
fsrc.write(src)
fsrc.flush()
fbin = fsrc.name + '.o'
line_info = [] if os.environ.get('TRITON_DISABLE_LINE_INFO') else ['-lineinfo']
fmad = [] if opt.enable_fp_fusion else ['--fmad=false']
suffix = 'a' if capability == 90 else ''
opt_level = ['--opt-level', '0'] if os.environ.get("DISABLE_PTXAS_OPT", "0") == "1" else []
ptxas_cmd = [
ptxas, *line_info, *fmad, '-v', *opt_level, f'--gpu-name=sm_{capability}{suffix}', fsrc.name, '-o', fbin
]
try:
subprocess.run(ptxas_cmd, check=True, close_fds=False, stderr=flog)
if os.path.exists(fsrc.name):
os.remove(fsrc.name)
if os.path.exists(flog.name):
os.remove(flog.name)
except subprocess.CalledProcessError as e:
with open(flog.name) as log_file:
log = log_file.read()
if os.path.exists(flog.name):
os.remove(flog.name)
if e.returncode == 255:
error = 'Internal Triton PTX codegen error'
elif e.returncode == 128 + signal.SIGSEGV:
error = '`ptxas` raised SIGSEGV'
else:
error = f'`ptxas` failed with error code {e.returncode}'
raise RuntimeError(f'{error}\n'
f'`ptxas` stderr:\n{log}\n'
f'Repro command: {ptxas_cmd}\n')
with open(fbin, 'rb') as f:
cubin = f.read()
if os.path.exists(fbin):
os.remove(fbin)
return cubin
生成通过subprocess.run(ptxas_cmd, check=True, close_fds=False, stderr=flog)
转成二进制文件,编译的活终于干完了。
8、配合runtime运行
通过runtime跑起来也很简单,只需要把metadata元数据准备好再run就好了,python/triton/runtime/jit.py:676 launch kernel
代码如下所示
launch_metadata = kernel.launch_metadata(grid, stream, *non_constexpr_vals)
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
self.CompiledKernel.launch_enter_hook, self.CompiledKernel.launch_exit_hook, *non_constexpr_vals)
会调用到python/triton/compiler/compiler.py的runner,根据launch_enter_hook
和launch_exit_hook
这2个钩子来控制,最终调用的是third_party/nvidia/backend/driver.py的cuLaunchKernelExHandle
。
三、Triton运行示列
Triton自己提供了一些示例,flash-attention、xFormers、lightllm均对Triton做了支持。
1、Triton tutorials
上文的向量加就取于Triton的python/tutorials/01-vector-add.py,他还提供了包括fused-attention、fused-softmax、grouped-gemm在内的示例。
部分示例使用了@triton.autotune
,即自动优化参数针对不同的硬件选择最优的配置。其内的benchmark也和torch
内置的以及cuBLAS
等做了比较,基本达到了各有胜负的水平。此外triton还在积极得将arithmetic
和mathematics
融入进来,tl.math里有蛮多API了。
2、LightLLM
LightLLM是商汤的一个推理和服务框架,现在已经支持了不少模型。不同的模型会有自己的triton_kernel
,比如DeepSeek-V2、llama和mistral等等。
3、FlagGems
FlagGems是清华智源推出的通用算子库
,力求做到高性能
、广覆盖
和轻量级
,这里轻量级
是最重要的,他希望能够多款硬件支持,涵盖NVIDIA与国产芯片。
目前有以下厂商在适配FlagGems,图片来源 pages17
4、flash-attention
flash-attention
根据Triton
python/tutorials/06-fused-attention.py 也做了一个支持attention bias
的实现,flash_attn/flash_attn_triton.py。
诸如mpt等模型对其有使用,使用也很简单,from flash_attn.flash_attn_triton import flash_attn_func
import 进来直接使用 flash_attn_func
即可,详见源码attention.py:L172,不过给我提示了If not using a Prefix Language Model, we recommend setting "attn_impl" to "flash" instead of "triton"
,可能是其在前缀语言模型效果才好。
5、xFormers
xFormers
是facebookresearch
用于高效和灵活的 Transformer 模型的库,旨在提高深度学习框架的性能。
他们自己的dinov2一个计算机视觉自监督模型用了xFormers
的memory_efficient_attention。
参考
2.杨军 OpenAI Triton Conference参会随感兼谈Triton Hopper
3.官方文档
4.OpenAI Introducing Triton: Open-source GPU programming for neural networks
6.TanyoKwok 郭天佑 聊聊 PyTorch 2.0(Inductor)
8.科研败犬丶 OpenAI/Triton MLIR 第一章: Triton DSL
10.清华智源 基于Triton的大模型通用算子库FlagGems技术分享 Slides
11.清华智源 Triton中国生态Meetup(第一期)暨首次Triton中国社区活动 Slides
本文来自博客园,作者:暴力都不会的蒟蒻,转载请注明原文链接:https://www.cnblogs.com/BobHuang/p/18324040