Skip to content

Commit 9d87bfc

Browse files
jammmwoct0rdho
authored andcommitted
Add Windows/clang-cl support for AMD HIP backend
- Use LoadLibrary/GetProcAddress on Windows instead of dlopen/dlsym - Use rocm_sdk.find_libraries() to locate amdhip64 - Add platform-specific macros for dynamic library loading - Escape Windows paths for C string embedding - Treat clang-cl as MSVC-compatible compiler in build.py - Fix NamedTemporaryFile handling on Windows in compiler.py
1 parent 71840cf commit 9d87bfc

File tree

4 files changed

+232
-67
lines changed

4 files changed

+232
-67
lines changed

python/triton/runtime/build.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,26 @@ def is_msvc(cc):
5151
return cc == "cl" or cc == "cl.exe"
5252

5353

54+
def is_clang_cl(cc):
55+
cc = os.path.basename(cc).lower()
56+
return cc == "clang-cl" or cc == "clang-cl.exe"
57+
58+
5459
def is_clang(cc):
5560
cc = os.path.basename(cc).lower()
5661
return cc == "clang" or cc == "clang.exe"
5762

5863

5964
def _cc_cmd(cc: str, src: str, out: str, include_dirs: list[str], library_dirs: list[str], libraries: list[str],
6065
ccflags: list[str]) -> list[str]:
61-
if is_msvc(cc):
66+
if is_msvc(cc) or is_clang_cl(cc):
6267
out_base = os.path.splitext(out)[0]
63-
cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/std:c11", "/wd4819"]
68+
cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/wd4819"]
69+
# clang-cl doesn't support /std:c11, use -std=c11 instead
70+
if is_clang_cl(cc):
71+
cc_cmd += ["-std=c11"]
72+
else:
73+
cc_cmd += ["/std:c11"]
6474
cc_cmd += [f"/I{dir}" for dir in include_dirs if dir is not None]
6575
cc_cmd += [f"/Fo{out_base + '.obj'}"]
6676
cc_cmd += ["/link"]
@@ -110,7 +120,7 @@ def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_di
110120
if sysconfig.get_config_var("Py_GIL_DISABLED"):
111121
version += "t"
112122
libraries = libraries + [f"python{version}"]
113-
if is_msvc(cc):
123+
if is_msvc(cc) or is_clang_cl(cc):
114124
_, msvc_winsdk_inc_dirs, msvc_winsdk_lib_dirs = find_msvc_winsdk()
115125
include_dirs = include_dirs + msvc_winsdk_inc_dirs
116126
library_dirs = library_dirs + msvc_winsdk_lib_dirs

third_party/amd/backend/compiler.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,17 @@
55
from typing import Any, Dict, Tuple
66
from types import ModuleType
77
import hashlib
8+
import os
9+
import platform
810
import tempfile
911
import re
1012
import functools
1113
import warnings
1214
from pathlib import Path
1315

16+
def _is_windows():
17+
return platform.system() == 'Windows'
18+
1419

1520
def get_min_dot_size(target: GPUTarget):
1621
# We fallback to use FMA and cast arguments if certain configurations is
@@ -437,13 +442,35 @@ def make_hsaco(src, metadata, options):
437442
if knobs.compilation.enable_asan:
438443
target_features = '+xnack'
439444
hsaco = amd.assemble_amdgcn(src, options.arch, target_features)
440-
with tempfile.NamedTemporaryFile() as tmp_out:
441-
with tempfile.NamedTemporaryFile() as tmp_in:
442-
with open(tmp_in.name, "wb") as fd_in:
443-
fd_in.write(hsaco)
445+
# On Windows, NamedTemporaryFile cannot be reopened while open, so we
446+
# use delete=False and manually clean up.
447+
if _is_windows():
448+
tmp_in = tempfile.NamedTemporaryFile(delete=False, suffix='.o')
449+
tmp_out = tempfile.NamedTemporaryFile(delete=False, suffix='.hsaco')
450+
try:
451+
tmp_in.write(hsaco)
452+
tmp_in.close()
453+
tmp_out.close()
444454
amd.link_hsaco(tmp_in.name, tmp_out.name)
445-
with open(tmp_out.name, "rb") as fd_out:
446-
ret = fd_out.read()
455+
with open(tmp_out.name, "rb") as fd_out:
456+
ret = fd_out.read()
457+
finally:
458+
try:
459+
os.unlink(tmp_in.name)
460+
except OSError:
461+
pass
462+
try:
463+
os.unlink(tmp_out.name)
464+
except OSError:
465+
pass
466+
else:
467+
with tempfile.NamedTemporaryFile() as tmp_out:
468+
with tempfile.NamedTemporaryFile() as tmp_in:
469+
with open(tmp_in.name, "wb") as fd_in:
470+
fd_in.write(hsaco)
471+
amd.link_hsaco(tmp_in.name, tmp_out.name)
472+
with open(tmp_out.name, "rb") as fd_out:
473+
ret = fd_out.read()
447474
return ret
448475

449476
def add_stages(self, stages, options, language):

third_party/amd/backend/driver.c

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,42 @@
33
#include <hip/hip_runtime_api.h>
44
#define PY_SSIZE_T_CLEAN
55
#include <Python.h>
6-
#include <dlfcn.h>
76
#include <stdbool.h>
87
#include <stdio.h>
98
#include <stdlib.h>
109

10+
#ifdef _WIN32
11+
#include <windows.h>
12+
// Windows compatibility layer for dlopen/dlsym/dlclose/dlerror
13+
#define RTLD_NOW 0
14+
#define RTLD_LAZY 0
15+
#define RTLD_LOCAL 0
16+
static char dlerror_buf[512];
17+
static inline void *dlopen(const char *filename, int flags) {
18+
(void)flags;
19+
HMODULE h = LoadLibraryA(filename);
20+
if (!h) {
21+
snprintf(dlerror_buf, sizeof(dlerror_buf), "LoadLibrary failed with error %lu", GetLastError());
22+
}
23+
return (void *)h;
24+
}
25+
static inline void *dlsym(void *handle, const char *symbol) {
26+
void *p = (void *)GetProcAddress((HMODULE)handle, symbol);
27+
if (!p) {
28+
snprintf(dlerror_buf, sizeof(dlerror_buf), "GetProcAddress failed for %s with error %lu", symbol, GetLastError());
29+
}
30+
return p;
31+
}
32+
static inline int dlclose(void *handle) {
33+
return FreeLibrary((HMODULE)handle) ? 0 : -1;
34+
}
35+
static inline const char *dlerror(void) {
36+
return dlerror_buf[0] ? dlerror_buf : NULL;
37+
}
38+
#else
39+
#include <dlfcn.h>
40+
#endif
41+
1142
// The list of paths to search for the HIP runtime library. The caller Python
1243
// code should substitute the search path placeholder.
1344
static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"};

0 commit comments

Comments
 (0)