A First Look at the Triton Source Code
1. The Core Interface
def jit(
    fn: Optional[T] = None,
    *,
    version=None,
    do_not_specialize: Optional[Iterable[int]] = None,
    debug: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
The interface returns a JITFunction object, which inherits from KernelInterface:
class JITFunction(KernelInterface[T]):
    # omitted

class KernelInterface(Generic[T]):
    run: T

    def __getitem__(self, grid) -> T:
        return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
A JITFunction is invoked with an extra grid argument, as in fn[grid](*args, **kwargs).
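Before going further, a minimal usage sketch (the standard vector-add pattern from the Triton tutorials) shows where the grid argument enters the picture:

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

x = torch.rand(1024, device="cuda")
y = torch.rand(1024, device="cuda")
out = torch.empty_like(x)
# add_kernel[grid] goes through KernelInterface.__getitem__, which partially
# applies `run` with grid=...; the actual call then lands in the launcher
# generated by _make_launcher() below.
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=256)

Here BLOCK_SIZE is a tl.constexpr, so it ends up in the constexpr_key part of the cache key built by the launcher below, and the grid callable is evaluated against the kernel's meta-parameters at launch time.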
Within the JITFunction class, the core logic is the _make_launcher() method, which builds the following launcher function from a source template and exec()s it:
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None):
    sig_key = {sig_keys},
    constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
    spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
    key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_stages, self.debug)
    if not extern_libs is None:
        key = (key, tuple(extern_libs.items()))
    assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
    if callable(grid):
        grid = grid({{{grid_args}}})
    grid_size = len(grid)
    grid_0 = grid[0]
    grid_1 = grid[1] if grid_size > 1 else 1
    grid_2 = grid[2] if grid_size > 2 else 1
    if device is None:
        device = get_current_device()
        set_current_device(device)
    if stream is None and not warmup:
        stream = get_cuda_stream(device)
    try:
        bin = cache[device][key]  # <-------- first, try to fetch an already-compiled bin for this key from self.cache
        if not warmup:
            bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
        return bin
    # kernel not cached -- compile
    except KeyError:
        # build dict of constant values
        args = [{args}]
        all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
        configs = self._get_config(*all_args),
        constants = self._make_constants(constexpr_key)
        constants.update({{i: None for i, arg in enumerate(all_args) if arg is None}})
        constants.update({{i: 1 for i in configs[0].equal_to_1}})
        # build kernel signature -- doesn't include specialized arguments
        signature = {{ i: self._type_of(_key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs }}
        # build stub signature -- includes arguments that are specialized
        for i, arg in constants.items():
            if callable(arg):
                raise TypeError(f"Callable constexpr at index {{i}} is not supported")
        if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):  # <-------- on a cache miss, this is where triton.compile runs for the first time
            bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug)
            if not warmup:
                bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
            self.cache[device][key] = bin
            return bin
        return None
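To make the "execute a templated function" idea concrete, here is a toy sketch (illustrative only, not the actual Triton internals) of the pattern _make_launcher relies on: build the launcher source as a string, exec() it, and pull the resulting function object out of the exec scope:

def make_launcher_sketch(fn_name, arg_names):
    # Hypothetical, heavily simplified version of the exec-a-template idea:
    # format a source string, exec() it, and return the new function object.
    src = f"def {fn_name}_launcher({', '.join(arg_names)}, grid):\n"
    src += "    key = (" + ", ".join(f"type({a}).__name__" for a in arg_names) + ")\n"
    src += "    print('cache key:', key, 'grid:', grid)\n"
    scope = {}
    exec(src, scope)
    return scope[f"{fn_name}_launcher"]

launcher = make_launcher_sketch("add_kernel", ["x", "y", "n"])
launcher(1.0, 2, 1024, grid=(8,))   # prints the per-call cache key and the grid

The real launcher does the same thing at a larger scale: the generated function computes the cache key, resolves the grid, and either reuses or compiles the kernel.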
Next, let's look at the implementation of the compile function. Its first key step is make_stub, which returns a so_path. This part has two steps:
- step 1: create a main.c file and write the CUDA-launcher code into it
- step 2: call _build to compile it into a .so file, cache it, and return that so_path
def make_stub(name, signature, constants):
    # name of files that are cached
    so_cache_key = make_so_cache_key(version_key(), signature, constants)
    so_cache_manager = CacheManager(so_cache_key)
    so_name = f"{name}.so"
    # retrieve stub from cache if it exists
    if not so_cache_manager.has_file(so_name):
        with tempfile.TemporaryDirectory() as tmpdir:
            src = generate_launcher(constants, signature)
            src_path = os.path.join(tmpdir, "main.c")  # <------- step 1
            with open(src_path, "w") as f:
                f.write(src)
            so = _build(name, src_path, tmpdir)  # <------- step 2
            with open(so, "rb") as f:
                so_cache_manager.put(f.read(), so_name, binary=True)
    return so_cache_manager._make_path(so_name)
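As an aside, CacheManager keys everything under Triton's on-disk cache directory, so the stub is only rebuilt when the signature/constants/version change. A quick way to peek at what has been cached, assuming the default ~/.triton/cache location and the TRITON_CACHE_DIR override (both assumptions about this Triton version):

import os

cache_dir = os.environ.get("TRITON_CACHE_DIR",
                           os.path.join(os.path.expanduser("~"), ".triton", "cache"))
# Each subdirectory is one cache key; typical entries are the launcher stub .so
# plus the IR artifacts of compiled kernels.
for root, _, files in os.walk(cache_dir):
    for name in files:
        print(os.path.join(root, name))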
Let's first look at what step 1 writes into main.c. The main content is shown below; as the code makes clear, this is essentially launch plumbing around cuLaunchKernel(function, ...):
#include \"cuda.h\"
#include <stdbool.h>
#include <Python.h>
static inline void gpuAssert(CUresult code, const char *file, int line)
{{
if (code != CUDA_SUCCESS)
{{
const char* prefix = "Triton Error [CUDA]: ";
const char* str;
cuGetErrorString(code, &str);
char err[1024] = {{0}};
strcat(err, prefix);
strcat(err, str);
PyErr_SetString(PyExc_RuntimeError, err);
}}
}}
#define CUDA_CHECK(ans) {{ gpuAssert((ans), __FILE__, __LINE__); }}
static void _launch(int gridX, int gridY, int gridZ, int num_warps, int shared_memory, CUstream stream, CUfunction function, {arg_decls}) {{
void *params[] = {{ {', '.join(f"&arg{i}" for i in signature.keys() if i not in constants)} }};
if(gridX*gridY*gridZ > 0){{
CUDA_CHECK(cuLaunchKernel(function, gridX, gridY, gridZ, 32*num_warps, 1, 1, shared_memory, stream, params, 0)); // <-------- note: `function` here should be the native kernel function we compiled
}}
}}
typedef struct _DevicePtrInfo {{
CUdeviceptr dev_ptr;
bool valid;
}} DevicePtrInfo;
static inline DevicePtrInfo getPointer(PyObject *obj, int idx) {{
DevicePtrInfo ptr_info;
ptr_info.dev_ptr = 0;
ptr_info.valid = true;
if (PyLong_Check(obj)) {{
ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(obj);
return ptr_info;
}}
if (obj == Py_None) {{
// valid nullptr
return ptr_info;
}}
PyObject *ptr = PyObject_GetAttrString(obj, "data_ptr");
if(ptr){{
PyObject *empty_tuple = PyTuple_New(0);
PyObject *ret = PyObject_Call(ptr, empty_tuple, NULL);
Py_DECREF(empty_tuple);
Py_DECREF(ptr);
if (!PyLong_Check(ret)) {{
PyErr_SetString(PyExc_TypeError, "data_ptr method of Pointer object must return 64-bit int");
ptr_info.valid = false;
return ptr_info;
}}
ptr_info.dev_ptr = PyLong_AsUnsignedLongLong(ret);
if(!ptr_info.dev_ptr)
return ptr_info;
uint64_t dev_ptr;
int status = cuPointerGetAttribute(&dev_ptr, CU_POINTER_ATTRIBUTE_DEVICE_POINTER, ptr_info.dev_ptr);
if (status == CUDA_ERROR_INVALID_VALUE) {{
PyErr_Format(PyExc_ValueError,
"Pointer argument (at %d) cannot be accessed from Triton (cpu tensor?)", idx);
ptr_info.valid = false;
}}
ptr_info.dev_ptr = dev_ptr;
Py_DECREF(ret); // Thanks ChatGPT!
return ptr_info;
}}
PyErr_SetString(PyExc_TypeError, "Pointer argument must be either uint64 or have data_ptr method");
return ptr_info;
}}
static PyObject* launch(PyObject* self, PyObject* args) {{
int gridX, gridY, gridZ;
uint64_t _stream;
uint64_t _function;
int num_warps;
int shared_memory;
PyObject *launch_enter_hook = NULL;
PyObject *launch_exit_hook = NULL;
PyObject *compiled_kernel = NULL;
{' '.join([f"{_extracted_type(ty)} _arg{i}; " for i, ty in signature.items()])}
if(!PyArg_ParseTuple(args, \"{format}\", &gridX, &gridY, &gridZ, &num_warps, &shared_memory, &_stream, &_function, &launch_enter_hook, &launch_exit_hook, &compiled_kernel, {', '.join(f"&_arg{i}" for i, ty in signature.items())})) {{
return NULL;
}}
if (launch_enter_hook != Py_None) {{
PyObject_CallObject(launch_enter_hook, args);
}}
// raise exception asap
{"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
_launch(gridX, gridY, gridZ, num_warps, shared_memory, (CUstream)_stream, (CUfunction)_function, {', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items())});
if (launch_exit_hook != Py_None) {{
PyObject_CallObject(launch_exit_hook, args);
}}
if(PyErr_Occurred()) {{
return NULL;
}}
// return None
Py_INCREF(Py_None);
return Py_None;
}}
static PyMethodDef ModuleMethods[] = {{
{{"launch", launch, METH_VARARGS, "Entry point for all kernels with this signature"}}, # <-------- 与 CompiledKernel 中的 self.c_wrapper = getattr(mod, "launch") 相呼应
{{NULL, NULL, 0, NULL}} // sentinel
}};
static struct PyModuleDef ModuleDef = {{
PyModuleDef_HEAD_INIT,
\"__triton_launcher\",
NULL, //documentation
-1, //size
ModuleMethods
}};
PyMODINIT_FUNC PyInit___triton_launcher(void) {{
PyObject *m = PyModule_Create(&ModuleDef);
if(m == NULL) {{
return NULL;
}}
PyModule_AddFunctions(m, ModuleMethods);
return m;
}}
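One practical consequence of getPointer above: a pointer argument may be a plain integer device address or any Python object exposing data_ptr() returning a 64-bit address (this is how torch tensors get through), and cuPointerGetAttribute then verifies the address is actually device-accessible. A hypothetical illustration:

class RawDeviceBuffer:
    # Hypothetical wrapper: anything with a data_ptr() -> int method is accepted
    # by the stub's getPointer(); the address must be a valid CUDA device pointer.
    def __init__(self, address: int):
        self._address = address

    def data_ptr(self) -> int:
        return self._address

Passing a host address instead would trip the "cannot be accessed from Triton (cpu tensor?)" error seen in the code above.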
Step 2 then spawns a subprocess that invokes the CC compiler (with setuptools as backup) to build the stub into a .so file. (Note: setuptools appears to be only a fallback path, taken when the CC compilation does not succeed.)
def _build(name, src, srcdir):
    cuda_lib_dirs = libcuda_dirs()
    base_dir = os.path.join(os.path.dirname(__file__), os.path.pardir)
    cuda_path = os.path.join(base_dir, "third_party", "cuda")
    cu_include_dir = os.path.join(cuda_path, "include")
    triton_include_dir = os.path.join(os.path.dirname(__file__), "include")
    cuda_header = os.path.join(cu_include_dir, "cuda.h")
    triton_cuda_header = os.path.join(triton_include_dir, "cuda.h")
    if not os.path.exists(cuda_header) and os.path.exists(triton_cuda_header):
        cu_include_dir = triton_include_dir
    # ... (resolution of `so` (the output path) and `py_include_dir`, plus the
    #      gcc/clang fallback when $CC is unset, omitted here)
    cc = os.environ.get("CC")
    cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so]
    cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs]
    ret = subprocess.check_call(cc_cmd)
    if ret == 0:
        return so
    # fallback on setuptools
    extra_compile_args = []
    library_dirs = cuda_lib_dirs
    include_dirs = [srcdir, cu_include_dir]
    libraries = ['cuda']
    # extra arguments
    extra_link_args = []
    # create extension module
    ext = setuptools.Extension(
        name=name,
        language='c',
        sources=[src],
        include_dirs=include_dirs,
        extra_compile_args=extra_compile_args + ['-O3'],
        extra_link_args=extra_link_args,
        library_dirs=library_dirs,
        libraries=libraries,
    )
    # build extension module
    args = ['build_ext']
    args.append('--build-temp=' + srcdir)
    args.append('--build-lib=' + srcdir)
    args.append('-q')
    args = dict(
        name=name,
        ext_modules=[ext],
        script_args=args,
    )
    with quiet():
        setuptools.setup(**args)
    return so
Then the second phase of compile begins: lowering the kernel through several IR stages:
def compile(xxx):
    # ... (arguments and earlier setup omitted)
    asm = dict()
    module = fn
    first_stage = list(stages.keys()).index(ext)
    # run compilation pipeline and populate metadata
    for ir, (parse, compile_kernel) in list(stages.items())[first_stage:]:
        # first stage: ast_to_ttir
        if ir == ext:
            next_module = parse(fn)
        else:
            # subsequent stages: ttir -> ttgir -> llir -> ptx -> cubin
            next_module = compile_kernel(module)
        if ir == "llir" and "shared" not in metadata:
            metadata["shared"] = _triton.get_shared_memory_size(module)
        if ir == "ptx":
            metadata["name"] = get_kernel_name(next_module, pattern='// .globl')
        if ir == "cubin":
            asm[ir] = next_module
        else:
            asm[ir] = str(next_module)
        module = next_module
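The stages mapping that drives this loop is, in spirit, a dict from IR name to a (parse, compile) pair, chained so that each entry lowers the previous stage's output. A simplified sketch of its shape (not the literal Triton code; parse_mlir, read_text and read_bytes are placeholder helpers for re-loading cached artifacts):

# Illustrative only: each value is (parse_cached_artifact, lower_from_previous_stage).
stages = {
    "ast":   (lambda path: fn, None),
    "ttir":  (parse_mlir, lambda src: ast_to_ttir(src, signature, specialization, constants, debug)),
    "ttgir": (parse_mlir, lambda src: ttir_to_ttgir(src, num_warps)),
    "llir":  (read_text,  lambda src: ttgir_to_llir(src, extern_libs, arch)),
    "ptx":   (read_text,  lambda src: llir_to_ptx(src, arch)),
    "cubin": (read_bytes, lambda src: ptx_to_cubin(src, arch)),
}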
The IR conversions of the important stages map to the following functions:
- ast_to_ttir: CodeGenerator(ast.NodeVisitor) does the heavy lifting and returns an mlir::ModuleOp object exposed from C++ via pybind
- ttir_to_ttgir: the pm variable is an mlir::PassManager object exposed from C++ via pybind
- ttgir_to_llir: the core worker gpu_to_llvmir is the C++ translateTritonGPUToLLVMIR function, again exposed via pybind
- llir_to_ptx: the core worker translate_llvmir_to_ptx is mainly implemented by the C++ translateLLVMIRToPTX function
- ptx_to_cubin: the core worker compile_ptx_to_cubin is likewise implemented mainly on the C++ side
def ast_to_ttir(fn, signature, specialization, constants, debug):
    context = ir.context()
    context.load_triton()
    prototype = language.function_type([], arg_types)
    generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants,
                              function_name=function_name, attributes=new_attrs,
                              is_kernel=True, debug=debug)
    generator.visit(fn.parse())
    ret = generator.module  # <------- created by ir.builder(context).create_module()
    # module takes ownership of the context
    ret.context = context
    return ret

def ttir_to_ttgir(mod, num_warps):
    pm = _triton.ir.pass_manager(mod.context)
    pm.add_convert_triton_to_tritongpu_pass(num_warps)
    pm.run(mod)
    return mod

def ttgir_to_llir(mod, extern_libs, arch):
    if extern_libs:
        _add_external_libs(mod, extern_libs)
    # TODO: separate tritongpu_to_llvmir for different backends
    if _is_cuda(arch):
        return _triton.translate_triton_gpu_to_llvmir(mod, arch, False)
    else:
        return _triton.translate_triton_gpu_to_llvmir(mod, 0, True)

def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str:
    if ptx_version is None:
        _, cuda_version = path_to_ptxas()
        ptx_version = ptx_get_version(cuda_version)
    return _triton.translate_llvmir_to_ptx(mod, arch, ptx_version)

def ptx_to_cubin(ptx: str, arch: int):
    ptxas, _ = path_to_ptxas()
    return _triton.compile_ptx_to_cubin(ptx, ptxas, arch)
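A convenient way to look at the output of these stages, without digging through the cache directory, is the asm dict that compile() populates above and hands to CompiledKernel. Assuming the add_kernel example from earlier and a Triton 2.x install, something like the following should work (warmup=True compiles without launching, as in the generated launcher):

# Compile without launching, then inspect the intermediate representations.
compiled = add_kernel[(1,)](x, y, out, x.numel(), BLOCK_SIZE=256, warmup=True)
print(compiled.asm.keys())     # typically includes ttir, ttgir, llir, ptx, cubin
print(compiled.asm["ttgir"])   # Triton-GPU IR after ttir_to_ttgir
print(compiled.asm["ptx"])     # PTX produced by llir_to_ptx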
For reference, here is the core C++ function behind ttgir_to_llir:
std::unique_ptr<llvm::Module>
translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
mlir::ModuleOp module, int computeCapability,
bool isROCM) {
mlir::PassManager pm(module->getContext());
applyPassManagerCLOptions(pm);
auto printingFlags = mlir::OpPrintingFlags();
printingFlags.elideLargeElementsAttrs(16);
pm.enableIRPrinting(
/*shouldPrintBeforePass=*/nullptr,
/*shouldPrintAfterPass=*/
[](mlir::Pass *pass, mlir::Operation *) {
return ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
},
/*printModuleScope=*/false,
/*printAfterOnlyOnChange=*/true,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createConvertIndexToLLVMPass());
pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability, isROCM));
pm.addPass(mlir::createArithToLLVMConversionPass());
pm.addPass(mlir::createCanonicalizerPass());
// Simplify the IR
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createSymbolDCEPass());
if (failed(pm.run(module))) {
llvm::errs() << "Pass execution failed";
return nullptr;
}
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module, isROCM);
if (!llvmIR) {
llvm::errs() << "Translate to LLVM IR failed";
return nullptr;
}
if (::triton::tools::getBoolEnv("LLVM_IR_ENABLE_DUMP")) {
std::string mod_string;
std::unique_ptr<llvm::raw_string_ostream> ir_ss(
new llvm::raw_string_ostream(mod_string));
llvmIR->print(*ir_ss, nullptr);
std::cout << "// -----// LLVM IR Dump //----- //\n"
<< mod_string << std::endl;
}
return llvmIR;
}
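The two getBoolEnv checks in this pass pipeline mean the intermediate dumps can be switched on from the environment, e.g. before the first kernel is compiled:

import os

# Read by the C++ hooks shown above; must be set before compilation runs.
os.environ["MLIR_ENABLE_DUMP"] = "1"      # print the module after each MLIR pass
os.environ["LLVM_IR_ENABLE_DUMP"] = "1"   # print the final LLVM IR module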
The core of translateLLVMIRToPTX is:
std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
// LLVM version in use may not officially support target hardware.
// Supported versions for LLVM 14 are here:
// https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/clang/include/clang/Basic/BuiltinsNVPTX.def
int maxPTX = std::min(80, version);
int maxCC = std::min(90, cc);
// options
auto options = llvm::cl::getRegisteredOptions();
auto *shortPtr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(shortPtr);
shortPtr->setValue(true);
std::string sm = cc == 90 ? "sm_90a" : "sm_" + std::to_string(cc);
// max PTX version
int ptxMajor = maxPTX / 10;
int ptxMinor = maxPTX % 10;
// create
std::string triple = "nvptx64-nvidia-cuda";
std::string proc = "sm_" + std::to_string(maxCC);
std::string layout = "";
std::string features = "";
// std::string features = "+ptx" + std::to_string(maxPTX);
initLLVM();
// verify and store llvm
llvm::legacy::PassManager pm;
pm.add(llvm::createVerifierPass());
pm.run(module);
// module->print(llvm::outs(), nullptr);
// create machine
module.setTargetTriple(triple);
std::string error;
auto target =
llvm::TargetRegistry::lookupTarget(module.getTargetTriple(), error);
llvm::TargetOptions opt;
opt.AllowFPOpFusion = llvm::FPOpFusion::Fast;
opt.UnsafeFPMath = false;
opt.NoInfsFPMath = false;
opt.NoNaNsFPMath = true;
llvm::TargetMachine *machine = target->createTargetMachine(
module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_,
std::nullopt, llvm::CodeGenOpt::Aggressive);
// set data layout
if (layout.empty())
module.setDataLayout(machine->createDataLayout());
else
module.setDataLayout(layout);
// emit machine code
std::string result;
{
llvm::raw_string_ostream stream(result);
llvm::buffer_ostream pstream(stream);
for (llvm::Function &f : module.functions())
f.addFnAttr(llvm::Attribute::AlwaysInline);
llvm::legacy::PassManager pass;
// emit
machine->addPassesToEmitFile(pass, pstream, nullptr,
llvm::CodeGenFileType::CGFT_AssemblyFile);
pass.run(module);
}
// post-process
findAndReplace(result, ".version", "\n",
".version " + std::to_string(ptxMajor) + "." +
std::to_string(ptxMinor) + "\n");
findAndReplace(result, ".target", "\n", ".target " + sm + "\n");
while (findAndReplace(result, "\t// begin inline asm", "\n", ""))
;
while (findAndReplace(result, "\t// end inline asm", "\n", ""))
;
return result;
}
With the cubin in hand, the last step of compile is to return the compiled handle, a CompiledKernel(fn, so_path, metadata, asm) object; the remaining logic lives in CompiledKernel.
Back in the jit path, the statement in the generated launcher that ultimately executes the kernel is:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
i.e. the CompiledKernel.c_wrapper function:
class CompiledKernel:
    def __init__(self, fn, so_path, metadata, asm):
        # initialize launcher
        import importlib.util
        spec = importlib.util.spec_from_file_location("__triton_launcher", so_path)
        mod = importlib.util.module_from_spec(spec)
        self.fn = fn
        spec.loader.exec_module(mod)
        self.c_wrapper = getattr(mod, "launch")  # <-------- the launch() entry point defined in the stub .so
The so_path here is exactly the path of the launcher .so that make_stub compiled earlier.