
Commit 6bc0661

danzimm authored and zwu-2025 committed
[typehint][python] Unify compile_module_from_src from nvidia + amd backends, make more defensive (triton-lang#6775)
I started looking at adding typehints to python/triton/backends and noticed there is a fair amount of duplicated code across the nvidia and amd implementations. To start, I think we can unify `compile_module_from_src`, since the two implementations appear to be identical. Additionally, I added some extra defensive checks in case the cache returns a faulty artifact. Ideally this should never happen, but I figure a cache miss is better than a crash; we can remove the checks if they seem superfluous. I plan to add some tests, but wanted to open the PR ahead of time for visibility (the tests will verify that compilation/loading succeeds, and that falling back to recompiling from a bad cached artifact also succeeds).
1 parent f28654f commit 6bc0661
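
For reference, the defensive behaviour described in the commit message boils down to "try the cached artifact, fall back to a rebuild". The sketch below is an illustrative restatement of that pattern, not the verbatim diff: the helper name load_or_rebuild and its rebuild callback are made up here, while _load_module_from_path is the loader this commit adds to python/triton/runtime/build.py.

import logging

from triton.runtime.build import _load_module_from_path  # added by this commit


def load_or_rebuild(name, cache_path, rebuild):
    """Hypothetical helper: cache_path comes from the cache manager (or is None),
    rebuild is a callable that recompiles the module from source."""
    if cache_path is not None:
        try:
            # Happy path: the cached .so loads cleanly.
            return _load_module_from_path(name, cache_path)
        except (RuntimeError, ImportError):
            # A faulty artifact costs a recompile, not a crash.
            logging.getLogger(__name__).warning("Triton cache error: compiled module %s could not be loaded", name)
    return rebuild()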

File tree

7 files changed: +168 -69 lines


pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@ files = [
     "python/triton/runtime/build.py",
     "python/triton/_utils.py",
     "python/test/unit/test_knobs.py",
+    "python/test/unit/runtime/test_build.py",
     "python/test/unit/runtime/test_compilation_listener.py",
 ]
 exclude = ["/build/"]
python/test/unit/runtime/test_build.py

Lines changed: 91 additions & 0 deletions

@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+import pytest
+import tempfile
+
+from pathlib import Path
+
+import triton
+
+from triton.runtime.build import compile_module_from_src
+
+TEST_MODULE_C = """
+#include <Python.h>
+#include <string.h>
+
+static PyObject* go(PyObject* self, PyObject* args) {
+  const char *command;
+  if (!PyArg_ParseTuple(args, "s", &command))
+    return NULL;
+
+  const char* res;
+  if (strcmp(command, "hello") == 0) {
+    res = "hiya";
+  } else {
+    res = "huh";
+  }
+  return PyUnicode_FromString(res);
+}
+
+static PyMethodDef ModuleMethods[] = {
+  {"go", go, METH_VARARGS, "test_module.go for testing"},
+  {NULL, NULL, 0, NULL}
+};
+
+static struct PyModuleDef ModuleDef = {
+  PyModuleDef_HEAD_INIT,
+  "test_module",
+  NULL, //documentation
+  -1, //size
+  ModuleMethods
+};
+
+PyMODINIT_FUNC PyInit_test_module(void) {
+  PyObject *m = PyModule_Create(&ModuleDef);
+  if(m == NULL) {
+    return NULL;
+  }
+  PyModule_AddFunctions(m, ModuleMethods);
+  return m;
+}
+"""
+
+
+def test_compile_module(fresh_triton_cache):
+    mod = compile_module_from_src(TEST_MODULE_C, "test_module")
+
+    with pytest.raises(Exception):
+        mod.go()
+
+    assert mod.go("huh") == "huh"
+    assert mod.go("hello") == "hiya"
+
+    # Make sure the module is cached
+    mod2 = compile_module_from_src(TEST_MODULE_C, "test_module")
+    assert mod2.__file__ == mod.__file__
+
+
+def test_compile_module_bad_cache(fresh_knobs_except_libraries):
+    with tempfile.TemporaryDirectory() as tmpd:
+        tmp = Path(tmpd)
+        called_get_file = False
+
+        class InvalidFileCacheManager(triton.runtime.cache.FileCacheManager):
+
+            def get_file(self, filename: str) -> str | None:
+                nonlocal called_get_file
+                called_get_file = True
+                (tmp / filename).write_text("not an so")
+                return str(tmp / filename)
+
+        # First corrupt the cache
+        fresh_knobs_except_libraries.cache.manager_class = InvalidFileCacheManager
+
+        mod = compile_module_from_src(TEST_MODULE_C, "test_module")
+        assert called_get_file
+
+        with pytest.raises(Exception):
+            mod.go()
+
+        assert mod.go("huh") == "huh"
+        assert mod.go("hello") == "hiya"
python/test/unit/tools/test_aot.py

Lines changed: 3 additions & 3 deletions

@@ -8,7 +8,7 @@
 
 import triton
 from triton.backends.compiler import GPUTarget
-from triton.backends.nvidia.driver import include_dir, library_dirs
+from triton.backends.nvidia.driver import include_dirs, library_dirs
 
 kernel_utils_src = """
 import triton
@@ -100,7 +100,7 @@ def kernel(C, A, B, M, N, K,
 def gen_kernel_library(dir, libname):
     c_files = glob.glob(os.path.join(dir, "*.c"))
     subprocess.run(
-        ["gcc"] + c_files + ["-I", include_dir[0], "-c", "-fPIC"],
+        ["gcc"] + c_files + ["-I", include_dirs[0], "-c", "-fPIC"],
         check=True,
         cwd=dir,
     )
@@ -175,7 +175,7 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0):
         file.write(src)
 
     command = ["gcc", "test.c"]
-    for inc_dir in include_dir:
+    for inc_dir in include_dirs:
         command.extend(["-I", inc_dir])
     for lib_dir in library_dirs():
         command.extend(["-L", lib_dir])

python/triton/backends/driver.py

Lines changed: 0 additions & 7 deletions

@@ -1,14 +1,7 @@
-import functools
 from abc import ABCMeta, abstractmethod
 from typing import Callable, List, Protocol, Sequence
 
 
-@functools.lru_cache()
-def platform_key():
-    from platform import machine, system, architecture
-    return ",".join([machine(), system(), *architecture()])
-
-
 class Benchmarker(Protocol):
 
     def __call__(self, kernel_call: Callable, *, quantiles: List[float], **kwargs) -> Sequence[float]:

python/triton/runtime/build.py

Lines changed: 51 additions & 1 deletion

@@ -1,8 +1,18 @@
-import sysconfig
+from __future__ import annotations
+
+import functools
+import hashlib
+import importlib.util
+import logging
 import os
 import shutil
 import subprocess
+import sysconfig
+import tempfile
+
+from types import ModuleType
 
+from .cache import get_cache_manager
 from .. import knobs
 
 
@@ -40,3 +50,43 @@ def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_di
     cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
     subprocess.check_call(cc_cmd, stdout=subprocess.DEVNULL)
     return so
+
+
+@functools.lru_cache
+def platform_key() -> str:
+    from platform import machine, system, architecture
+    return ",".join([machine(), system(), *architecture()])
+
+
+def _load_module_from_path(name: str, path: str) -> ModuleType:
+    spec = importlib.util.spec_from_file_location(name, path)
+    if not spec or not spec.loader:
+        raise RuntimeError(f"Failed to load newly compiled {name} from {path}")
+    mod = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    return mod
+
+
+def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None,
+                            include_dirs: list[str] | None = None, libraries: list[str] | None = None) -> ModuleType:
+    key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
+    cache = get_cache_manager(key)
+    suffix = sysconfig.get_config_var("EXT_SUFFIX")
+    cache_path = cache.get_file(f"{name}{suffix}")
+
+    if cache_path is not None:
+        try:
+            return _load_module_from_path(name, cache_path)
+        except (RuntimeError, ImportError):
+            log = logging.getLogger(__name__)
+            log.warning(f"Triton cache error: compiled module {name}.so could not be loaded")
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        src_path = os.path.join(tmpdir, name + ".c")
+        with open(src_path, "w") as f:
+            f.write(src)
+        so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [])
+        with open(so, "rb") as f:
+            cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)
+
+    return _load_module_from_path(name, cache_path)
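
To illustrate how a backend now consumes the unified helper (the real call sites are in the amd and nvidia driver.py hunks below), here is a hypothetical end-to-end use; the module name ping_ext and its C source are invented for this example, and it assumes a working C toolchain and Triton cache, as in the new test.

from triton.runtime.build import compile_module_from_src

# Invented minimal CPython extension source; the PyInit_* name must match the module name.
PING_SRC = """
#include <Python.h>

static PyObject* ping(PyObject* self, PyObject* args) {
  return PyUnicode_FromString("pong");
}

static PyMethodDef Methods[] = {{"ping", ping, METH_NOARGS, "returns pong"}, {NULL, NULL, 0, NULL}};
static struct PyModuleDef Def = {PyModuleDef_HEAD_INIT, "ping_ext", NULL, -1, Methods};

PyMODINIT_FUNC PyInit_ping_ext(void) { return PyModule_Create(&Def); }
"""

# Keyword form mirrors the updated driver.py call sites; include_dirs, library_dirs and
# libraries effectively default to empty lists when omitted.
mod = compile_module_from_src(src=PING_SRC, name="ping_ext")
assert mod.ping() == "pong"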

third_party/amd/backend/driver.py

Lines changed: 5 additions & 29 deletions

@@ -1,20 +1,16 @@
 import functools
 import os
-import hashlib
 import subprocess
-import sysconfig
-import tempfile
 import re
 from pathlib import Path
-from triton.runtime.build import _build
 from triton import knobs
-from triton.runtime.cache import get_cache_manager
 from triton.backends.compiler import GPUTarget
-from triton.backends.driver import GPUDriver, platform_key
+from triton.backends.driver import GPUDriver
+from triton.runtime.build import compile_module_from_src
 from triton.tools.tensor_descriptor import TensorDescriptor
 
 dirname = os.path.dirname(os.path.realpath(__file__))
-include_dir = [os.path.join(dirname, "include")]
+include_dirs = [os.path.join(dirname, "include")]
 
 
 def _find_already_mmapped_dylib_on_linux(lib_name):
@@ -133,26 +129,6 @@ def _get_path_to_hip_runtime_dylib():
     raise RuntimeError(f"cannot locate {lib_name} after attempted paths {paths}")
 
 
-def compile_module_from_src(src, name):
-    key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
-    cache = get_cache_manager(key)
-    suffix = sysconfig.get_config_var("EXT_SUFFIX")
-    cache_path = cache.get_file(f"{name}{suffix}")
-    if cache_path is None:
-        with tempfile.TemporaryDirectory() as tmpdir:
-            src_path = os.path.join(tmpdir, "main.c")
-            with open(src_path, "w") as f:
-                f.write(src)
-            so = _build(name, src_path, tmpdir, [], include_dir, [])
-            with open(so, "rb") as f:
-                cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)
-    import importlib.util
-    spec = importlib.util.spec_from_file_location(name, cache_path)
-    mod = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(mod)
-    return mod
-
-
 class HIPUtils(object):
 
     def __new__(cls):
@@ -167,7 +143,7 @@ def __init__(self):
         # This way we don't need to escape-quote C code curly brackets and we can replace
         # exactly once.
         src = src.replace('/*py_libhip_search_path*/', libhip_path, 1)
-        mod = compile_module_from_src(src, "hip_utils")
+        mod = compile_module_from_src(src=src, name="hip_utils", include_dirs=include_dirs)
        self.load_binary = mod.load_binary
        self.get_device_properties = mod.get_device_properties
 
@@ -560,7 +536,7 @@ def __init__(self, src, metadata):
         constants = {arg_idx(idx): value for idx, value in constants.items()}
         signature = {idx: value for idx, value in src.signature.items()}
         src = make_launcher(constants, signature, metadata.warp_size)
-        mod = compile_module_from_src(src, "__triton_launcher")
+        mod = compile_module_from_src(src=src, name="__triton_launcher", include_dirs=include_dirs)
         has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
 
         self.launch = wrap_handle_tensor_descriptor(mod.launch) if has_tensor_desc_arg else mod.launch

third_party/nvidia/backend/driver.py

Lines changed: 17 additions & 29 deletions

@@ -1,24 +1,20 @@
 import functools
 import operator
 import os
-import sysconfig
-import hashlib
 import subprocess
-import tempfile
 import triton
 import re
 from pathlib import Path
 from triton import knobs
-from triton.runtime.build import _build
-from triton.runtime.cache import get_cache_manager
+from triton.runtime.build import compile_module_from_src
 from triton.runtime import _allocation
 from triton.backends.compiler import GPUTarget
-from triton.backends.driver import GPUDriver, platform_key
+from triton.backends.driver import GPUDriver
 
 from triton.tools.tensor_descriptor import TensorDescriptor
 
 dirname = os.path.dirname(os.path.realpath(__file__))
-include_dir = [os.path.join(dirname, "include")]
+include_dirs = [os.path.join(dirname, "include")]
 libdevice_dir = os.path.join(dirname, "lib")
 libraries = ['cuda']
 
@@ -52,26 +48,6 @@ def library_dirs():
     return [libdevice_dir, *libcuda_dirs()]
 
 
-def compile_module_from_src(src, name):
-    key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
-    cache = get_cache_manager(key)
-    suffix = sysconfig.get_config_var("EXT_SUFFIX")
-    cache_path = cache.get_file(f"{name}{suffix}")
-    if cache_path is None:
-        with tempfile.TemporaryDirectory() as tmpdir:
-            src_path = os.path.join(tmpdir, "main.c")
-            with open(src_path, "w") as f:
-                f.write(src)
-            so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
-            with open(so, "rb") as f:
-                cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)
-    import importlib.util
-    spec = importlib.util.spec_from_file_location(name, cache_path)
-    mod = importlib.util.module_from_spec(spec)
-    spec.loader.exec_module(mod)
-    return mod
-
-
 # ------------------------
 # Utils
 # ------------------------
@@ -85,7 +61,13 @@ def __new__(cls):
         return cls.instance
 
     def __init__(self):
-        mod = compile_module_from_src(Path(os.path.join(dirname, "driver.c")).read_text(), "cuda_utils")
+        mod = compile_module_from_src(
+            src=Path(os.path.join(dirname, "driver.c")).read_text(),
+            name="cuda_utils",
+            library_dirs=library_dirs(),
+            include_dirs=include_dirs,
+            libraries=libraries,
+        )
         self.load_binary = mod.load_binary
         self.get_device_properties = mod.get_device_properties
         self.cuOccupancyMaxActiveClusters = mod.cuOccupancyMaxActiveClusters
@@ -643,7 +625,13 @@ def __init__(self, src, metadata):
         signature = {idx: value for idx, value in src.signature.items()}
         tensordesc_meta = getattr(metadata, "tensordesc_meta", None)
         src = make_launcher(constants, signature, tensordesc_meta)
-        mod = compile_module_from_src(src, "__triton_launcher")
+        mod = compile_module_from_src(
+            src=src,
+            name="__triton_launcher",
+            library_dirs=library_dirs(),
+            include_dirs=include_dirs,
+            libraries=libraries,
+        )
         has_tensor_desc_arg = any(isinstance(sig, str) and sig.startswith("tensordesc") for sig in signature.values())
 
         self.num_ctas = functools.reduce(operator.mul, metadata.cluster_dims, 1)
