Skip to content

Commit d31ccfe

Browse files
authored
[NFC] Use get_config_var('EXT_SUFFIX') instead of using so directly (#4958)
Change to improve platform independence. How it works? On Windows: ```python >>> import sysconfig >>> sysconfig.get_config_var("EXT_SUFFIX") '.cp310-win_amd64.pyd' >>> sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1] 'pyd' ``` On Linux: ```python >>> import sysconfig >>> sysconfig.get_config_var("EXT_SUFFIX") '.cpython-310-x86_64-linux-gnu.so' >>> sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1] 'so' ``` --------- Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 152ef2d commit d31ccfe

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

python/triton/compiler/compiler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import re
1616
import functools
1717
import os
18+
import sysconfig
1819

1920
# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
2021
# and any following whitespace
@@ -151,7 +152,8 @@ def triton_key():
151152

152153
# backend
153154
libtriton_hash = hashlib.sha256()
154-
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
155+
ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
156+
with open(os.path.join(TRITON_PATH, f"_C/libtriton.{ext}"), "rb") as f:
155157
while True:
156158
chunk = f.read(1024**2)
157159
if not chunk:

third_party/nvidia/backend/driver.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import functools
22
import os
3+
import sysconfig
34
import hashlib
45
import subprocess
56
import tempfile
@@ -48,15 +49,16 @@ def library_dirs():
4849
def compile_module_from_src(src, name):
4950
key = hashlib.sha256(src.encode("utf-8")).hexdigest()
5051
cache = get_cache_manager(key)
51-
cache_path = cache.get_file(f"{name}.so")
52+
ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
53+
cache_path = cache.get_file(f"{name}.{ext}")
5254
if cache_path is None:
5355
with tempfile.TemporaryDirectory() as tmpdir:
5456
src_path = os.path.join(tmpdir, "main.c")
5557
with open(src_path, "w") as f:
5658
f.write(src)
5759
so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
5860
with open(so, "rb") as f:
59-
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
61+
cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True)
6062
import importlib.util
6163
spec = importlib.util.spec_from_file_location(name, cache_path)
6264
mod = importlib.util.module_from_spec(spec)

0 commit comments

Comments
 (0)