triton 简要学习笔记

语法#

以最简单的向量相加为例, 通过把triton翻译成cuda的形式

@triton.jit   #需要加这行标识kernel
def add_kernel(x_ptr, 
               y_ptr, 
               output_ptr, 
               n_elements, 
               BLOCK_SIZE: tl.constexpr,
               ):
    
    pid = tl.program_id(axis=0)                      # 对应cuda中的blockIdx.x
    block_start = pid * BLOCK_SIZE                   # 对应 blockIdx.x * blockDim.x
    offsets = block_start + tl.arange(0, BLOCK_SIZE) # 对应tid = blockIdx.x * blockDim.x + threadIdx.x
    mask = offsets < n_elements                      # 对应 if (tid < n), 防止显存访问越界
    x = tl.load(x_ptr + offsets, mask=mask)          # x_ptr[tid] 访问显存获取数据
    y = tl.load(y_ptr + offsets, mask=mask)          # y_ptr[tid] 访问显存
    output = x + y                                   # auto output = x_ptr[tid] + y_ptr[tid]
    tl.store(output_ptr + offsets, output, mask=mask)  #output_ptr[tid] = output 写回全局显存

cuda对应关系#

线程/块/网格模型#

CUDA 概念 Triton 对应概念 说明
threadIdx.x tl.arange(0, BLOCK_SIZE) Triton 隐式管理线程,通过向量化操作替代显式线程索引
blockIdx.x tl.program_id(axis=0) 程序实例(逻辑线程块)的索引
blockDim.x BLOCK_SIZE(编译时常量) Triton 中通过 tl.constexpr 定义的块大小
gridDim.x tl.num_programs(axis=0) 网格中某一维度的总程序实例数

访存对应关系#

CUDA 内存类型 Triton 对应操作 说明
全局内存 (Global Memory) tl.load/tl.store 直接操作指针
共享内存 (Shared Memory) __shared__ tl.static_shared_array 静态分配的共享内存
寄存器 (Registers) Triton 自动管理 通过变量声明隐式使用

同步机制#

CUDA 同步 Triton 同步 说明
__syncthreads() tl.wait() 等待所有异步操作(如内存拷贝)完成
原子操作 (atomicAdd等) tl.atomic_add 等原子函数 Triton 提供原子操作的直接支持

kernel提交#

# triton
n = 1024
BLOCK_SIZE = 128
grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']), )
kernel[grid](input, output, n, BLOCK_SIZE=BLOCK_SIZE)

#对应的cuda
#add_kernel<<<(n + BLOCK_SIZE - 1)/BLOCK_SIZE, BLOCK_SIZE>>>(d_a, d_b, d_c, n);

CUB在triton中的对应实现#

https://triton-lang.org/main/python-api/triton.language.html

举例: 比如cub::BlockMergeSort-> tl.sort() cub::BlockScan -> tl::cumsum(), 可能不是完全等价, 待后续确认

性能调优#

nsight-compute#

可以profile triton_kernel, 牛的

ncu --kernel-name add_kernel -o add_new python3 triton_add.py

自带bench#

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  
        x_vals=[2**i for i in range(12, 28, 1)], 
        x_log=True,  
        line_arg='provider', 
        line_vals=['triton', 'torch'],  
        line_names=['Triton', 'Torch'], 
        styles=[('blue', '-'), ('green', '-')],  
        ylabel='GB/s',  
        plot_name='vector-add-performance',  
        args={}, 
    ))

def benchmark(size, provider):
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
    gbps = lambda ms: 12 * size / ms * 1e-6
    return gbps(ms), gbps(max_ms), gbps(min_ms)

#benchmark.run(print_data=True, show_plots=True, save_path='./output')

作者:sunstrikes

出处:https://www.cnblogs.com/sunstrikes/p/18730203

版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。

posted @   SunStriKE  阅读(13)  评论(0编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek “源神”启动!「GitHub 热点速览」
· 我与微信审核的“相爱相杀”看个人小程序副业
· 微软正式发布.NET 10 Preview 1:开启下一代开发框架新篇章
· C# 集成 DeepSeek 模型实现 AI 私有化(本地部署与 API 调用教程)
· spring官宣接入deepseek,真的太香了~
点击右上角即可分享
微信分享提示
more_horiz
keyboard_arrow_up dark_mode palette
选择主题
menu