Skip to content

Commit 154bb02

Browse files
committed
add tvm-ffi
1 parent 6e7004e commit 154bb02

2 files changed

Lines changed: 107 additions & 5 deletions

File tree

kernels/src/kernels/variants.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,12 @@
2323
)
2424
from kernels.compat import has_torch, has_tvm_ffi
2525

26-
BUILD_VARIANT_REGEX = re.compile(r"^(torch\d+\d+|torch-(cpu|cuda|metal|neuron|rocm|xpu)|tvm-ffi\d+\d+)")
26+
BUILD_VARIANT_REGEX = re.compile(
27+
r"^(torch\d+\d+"
28+
r"|torch-(cpu|cuda|metal|neuron|rocm|xpu|metal)"
29+
r"|tvm-ffi\d+\d+"
30+
r"|tvm-ffi-(cpu|cuda|metal|neuron|rocm|xpu))"
31+
)
2732

2833

2934
@dataclass(unsafe_hash=True)
@@ -134,6 +139,23 @@ def variant_str(self) -> str:
134139
return "torch"
135140

136141

142+
@strict
143+
@dataclass(unsafe_hash=True)
144+
class TvmFfiNoarch:
145+
"""Versionless tvm-ffi framework (noarch variants)."""
146+
147+
@staticmethod
148+
def possible_variants() -> list["TvmFfiNoarch"]:
149+
if has_tvm_ffi:
150+
return [TvmFfiNoarch()]
151+
else:
152+
return []
153+
154+
@property
155+
def variant_str(self) -> str:
156+
return "tvm-ffi"
157+
158+
137159
@strict
138160
@dataclass(unsafe_hash=True)
139161
class Arch:
@@ -220,12 +242,14 @@ def variant_str(self) -> str:
220242
class NoarchVariant:
221243
"""Noarch kernel build variant."""
222244

223-
framework: TorchNoarch
245+
framework: TorchNoarch | TvmFfiNoarch
224246
arch: Noarch
225247

226248
@staticmethod
227249
def possible_variants() -> list["NoarchVariant"]:
228-
frameworks = TorchNoarch.possible_variants()
250+
frameworks: list[TorchNoarch | TvmFfiNoarch] = (
251+
TorchNoarch.possible_variants() + TvmFfiNoarch.possible_variants()
252+
)
229253
archs = Noarch.possible_variants()
230254
return [NoarchVariant(framework=fw, arch=arch) for fw, arch in itertools.product(frameworks, archs)]
231255

@@ -264,6 +288,9 @@ def parse_variant(variant_str: str) -> Variant:
264288
framework_str = parts[0]
265289
arch_parts = parts[1:]
266290
return ArchVariant(framework=Torch.parse(framework_str), arch=Arch.parse(arch_parts))
291+
elif parts[0] == "tvm" and len(parts) >= 2 and parts[1] == "ffi":
292+
# noarch: e.g. "tvm-ffi-metal"
293+
return NoarchVariant(framework=TvmFfiNoarch(), arch=Noarch.parse("-".join(parts[2:])))
267294
elif parts[0] == "tvm" and len(parts) >= 2 and parts[1].startswith("ffi"):
268295
return ArchVariant(framework=TvmFfi.parse(f"tvm-{parts[1]}"), arch=Arch.parse(parts[2:]))
269296
else:
@@ -432,7 +459,8 @@ def _sort_variants(
432459
1. Torch arch kernels with with the highest compatible CUDA version.
433460
2. tvm-ffi arch kernels with with the highest compatible CUDA version.
434461
3. Torch noarch kernels.
435-
4. Old Torch universal kernels.
462+
4. tvm-ffi noarch kernels.
463+
5. Old universal kernels.
436464
"""
437465

438466
def sort_key(v: Variant) -> tuple:
@@ -447,6 +475,7 @@ def sort_key(v: Variant) -> tuple:
447475
else:
448476
assert isinstance(v, NoarchVariant)
449477
universal_order = 1 if v.arch.backend_name == "universal" else 0
450-
return (2, universal_order)
478+
framework_order = 0 if isinstance(v.framework, TorchNoarch) else 1
479+
return (2, universal_order, framework_order)
451480

452481
return sorted(variants, key=sort_key)

kernels/tests/test_variants.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@
5858
"torch-xpu",
5959
"torch-npu",
6060
"torch-universal",
61+
"tvm-ffi-cpu",
62+
"tvm-ffi-cuda",
63+
"tvm-ffi-metal",
64+
"tvm-ffi-rocm",
65+
"tvm-ffi-xpu",
66+
"tvm-ffi-universal",
6167
]
6268

6369
SUPERSET_VARIANT_STRINGS = [
@@ -259,6 +265,73 @@ def test_resolve_noarch_fallback():
259265
assert result[0].variant_str == "torch-cuda"
260266

261267

268+
RESOLVE_VARIANTS_TVM_FFI = [
269+
parse_variant(s)
270+
for s in [
271+
"tvm-ffi01-metal-aarch64-darwin",
272+
"tvm-ffi01-cu128-x86_64-linux",
273+
"tvm-ffi-metal",
274+
"tvm-ffi-cuda",
275+
]
276+
]
277+
278+
279+
def test_resolve_tvm_ffi_metal_darwin():
280+
# tvm-ffi arch metal build should be picked on Darwin/aarch64 when
281+
# tvm-ffi 0.1 is installed.
282+
from kernels.backends import Metal
283+
284+
result = _resolve_variant_for_system(
285+
variants=RESOLVE_VARIANTS_TVM_FFI,
286+
selected_backend=Metal(),
287+
cpu="aarch64",
288+
os="darwin",
289+
torch_version=None,
290+
torch_cxx11_abi=None,
291+
tvm_ffi_version=Version("0.1"),
292+
)
293+
assert result != []
294+
assert result[0].variant_str == "tvm-ffi01-metal-aarch64-darwin"
295+
296+
297+
def test_resolve_tvm_ffi_metal_noarch_fallback():
298+
# With no matching arch variant for the system, should fall back to
299+
# the tvm-ffi metal noarch variant.
300+
from kernels.backends import Metal
301+
302+
variants = [parse_variant(s) for s in ["tvm-ffi01-cu128-x86_64-linux", "tvm-ffi-metal"]]
303+
result = _resolve_variant_for_system(
304+
variants=variants,
305+
selected_backend=Metal(),
306+
cpu="aarch64",
307+
os="darwin",
308+
torch_version=None,
309+
torch_cxx11_abi=None,
310+
tvm_ffi_version=Version("0.1"),
311+
)
312+
assert result != []
313+
assert result[0].variant_str == "tvm-ffi-metal"
314+
315+
316+
def test_resolve_torch_noarch_preferred_over_tvm_ffi_noarch():
317+
# When both are available, Torch noarch is preferred over tvm-ffi noarch
318+
# (mirrors the arch precedence: Torch > tvm-ffi).
319+
variants = [parse_variant(s) for s in ["tvm-ffi-metal", "torch-metal"]]
320+
from kernels.backends import Metal
321+
322+
result = _resolve_variant_for_system(
323+
variants=variants,
324+
selected_backend=Metal(),
325+
cpu="aarch64",
326+
os="darwin",
327+
torch_version=Version("2.10"),
328+
torch_cxx11_abi=None,
329+
tvm_ffi_version=Version("0.1"),
330+
)
331+
assert result != []
332+
assert result[0].variant_str == "torch-metal"
333+
334+
262335
def test_resolve_no_match():
263336
result = _resolve_variant_for_system(
264337
variants=RESOLVE_VARIANTS,

0 commit comments

Comments
 (0)