triton 简要学习笔记
语法#
以最简单的向量相加为例, 通过把triton翻译成cuda的形式
@triton.jit #需要加这行标识kernel
def add_kernel(x_ptr,
y_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0) # 对应cuda中的blockIdx.x
block_start = pid * BLOCK_SIZE # 对应 blockIdx.x * blockDim.x
offsets = block_start + tl.arange(0, BLOCK_SIZE) # 对应tid = blockIdx.x * blockDim.x + threadIdx.x
mask = offsets < n_elements # 对应 if (tid < n), 防止显存访问越界
x = tl.load(x_ptr + offsets, mask=mask) # x_ptr[tid] 访问显存获取数据
y = tl.load(y_ptr + offsets, mask=mask) # y_ptr[tid] 访问显存
output = x + y # auto output = x_ptr[tid] + y_ptr[tid]
tl.store(output_ptr + offsets, output, mask=mask) #output_ptr[tid] = output 写回全局显存
cuda对应关系#
线程/块/网格模型#
CUDA 概念 | Triton 对应概念 | 说明 |
---|---|---|
threadIdx.x |
tl.arange(0, BLOCK_SIZE) |
Triton 隐式管理线程,通过向量化操作替代显式线程索引 |
blockIdx.x |
tl.program_id(axis=0) |
程序实例(逻辑线程块)的索引 |
blockDim.x |
BLOCK_SIZE (编译时常量) |
Triton 中通过 tl.constexpr 定义的块大小 |
gridDim.x |
tl.num_programs(axis=0) |
网格中某一维度的总程序实例数 |
访存对应关系#
CUDA 内存类型 | Triton 对应操作 | 说明 |
---|---|---|
全局内存 (Global Memory) | tl.load /tl.store |
直接操作指针 |
共享内存 (Shared Memory) __shared__ |
tl.static_shared_array |
静态分配的共享内存 |
寄存器 (Registers) | Triton 自动管理 | 通过变量声明隐式使用 |
同步机制#
CUDA 同步 | Triton 同步 | 说明 |
---|---|---|
__syncthreads() |
tl.wait() |
等待所有异步操作(如内存拷贝)完成 |
原子操作 (atomicAdd 等) |
tl.atomic_add 等原子函数 |
Triton 提供原子操作的直接支持 |
kernel提交#
# triton
n = 1024
BLOCK_SIZE = 128
grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']), )
kernel[grid](input, output, n, BLOCK_SIZE=BLOCK_SIZE)
#对应的cuda
#add_kernel<<<(n + BLOCK_SIZE - 1)/BLOCK_SIZE, BLOCK_SIZE>>>(d_a, d_b, d_c, n);
CUB在triton中的对应实现#
https://triton-lang.org/main/python-api/triton.language.html
举例: 比如cub::BlockMergeSort-> tl.sort() cub::BlockScan -> tl::cumsum(), 可能不是完全等价, 待后续确认
性能调优#
nsight-compute#
可以profile triton_kernel, 牛的
ncu --kernel-name add_kernel -o add_new python3 triton_add.py
自带bench#
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['size'],
x_vals=[2**i for i in range(12, 28, 1)],
x_log=True,
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'Torch'],
styles=[('blue', '-'), ('green', '-')],
ylabel='GB/s',
plot_name='vector-add-performance',
args={},
))
def benchmark(size, provider):
x = torch.rand(size, device='cuda', dtype=torch.float32)
y = torch.rand(size, device='cuda', dtype=torch.float32)
quantiles = [0.5, 0.2, 0.8]
if provider == 'torch':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
if provider == 'triton':
ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
gbps = lambda ms: 12 * size / ms * 1e-6
return gbps(ms), gbps(max_ms), gbps(min_ms)
#benchmark.run(print_data=True, show_plots=True, save_path='./output')
作者:sunstrikes
出处:https://www.cnblogs.com/sunstrikes/p/18730203
版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】凌霞软件回馈社区,博客园 & 1Panel & Halo 联合会员上线
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】博客园社区专享云产品让利特惠,阿里云新客6.5折上折
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· DeepSeek “源神”启动!「GitHub 热点速览」
· 我与微信审核的“相爱相杀”看个人小程序副业
· 微软正式发布.NET 10 Preview 1:开启下一代开发框架新篇章
· C# 集成 DeepSeek 模型实现 AI 私有化(本地部署与 API 调用教程)
· spring官宣接入deepseek,真的太香了~