2323)
2424from kernels .compat import has_torch , has_tvm_ffi
2525
26- BUILD_VARIANT_REGEX = re .compile (r"^(torch\d+\d+|torch-(cpu|cuda|metal|neuron|rocm|xpu)|tvm-ffi\d+\d+)" )
26+ BUILD_VARIANT_REGEX = re .compile (
27+ r"^(torch\d+\d+"
28+ r"|torch-(cpu|cuda|metal|neuron|rocm|xpu|metal)"
29+ r"|tvm-ffi\d+\d+"
30+ r"|tvm-ffi-(cpu|cuda|metal|neuron|rocm|xpu))"
31+ )
2732
2833
2934@dataclass (unsafe_hash = True )
@@ -134,6 +139,23 @@ def variant_str(self) -> str:
134139 return "torch"
135140
136141
142+ @strict
143+ @dataclass (unsafe_hash = True )
144+ class TvmFfiNoarch :
145+ """Versionless tvm-ffi framework (noarch variants)."""
146+
147+ @staticmethod
148+ def possible_variants () -> list ["TvmFfiNoarch" ]:
149+ if has_tvm_ffi :
150+ return [TvmFfiNoarch ()]
151+ else :
152+ return []
153+
154+ @property
155+ def variant_str (self ) -> str :
156+ return "tvm-ffi"
157+
158+
137159@strict
138160@dataclass (unsafe_hash = True )
139161class Arch :
@@ -220,12 +242,14 @@ def variant_str(self) -> str:
220242class NoarchVariant :
221243 """Noarch kernel build variant."""
222244
223- framework : TorchNoarch
245+ framework : TorchNoarch | TvmFfiNoarch
224246 arch : Noarch
225247
226248 @staticmethod
227249 def possible_variants () -> list ["NoarchVariant" ]:
228- frameworks = TorchNoarch .possible_variants ()
250+ frameworks : list [TorchNoarch | TvmFfiNoarch ] = (
251+ TorchNoarch .possible_variants () + TvmFfiNoarch .possible_variants ()
252+ )
229253 archs = Noarch .possible_variants ()
230254 return [NoarchVariant (framework = fw , arch = arch ) for fw , arch in itertools .product (frameworks , archs )]
231255
@@ -264,6 +288,9 @@ def parse_variant(variant_str: str) -> Variant:
264288 framework_str = parts [0 ]
265289 arch_parts = parts [1 :]
266290 return ArchVariant (framework = Torch .parse (framework_str ), arch = Arch .parse (arch_parts ))
291+ elif parts [0 ] == "tvm" and len (parts ) >= 2 and parts [1 ] == "ffi" :
292+ # noarch: e.g. "tvm-ffi-metal"
293+ return NoarchVariant (framework = TvmFfiNoarch (), arch = Noarch .parse ("-" .join (parts [2 :])))
267294 elif parts [0 ] == "tvm" and len (parts ) >= 2 and parts [1 ].startswith ("ffi" ):
268295 return ArchVariant (framework = TvmFfi .parse (f"tvm-{ parts [1 ]} " ), arch = Arch .parse (parts [2 :]))
269296 else :
@@ -432,7 +459,8 @@ def _sort_variants(
432459 1. Torch arch kernels with with the highest compatible CUDA version.
433460 2. tvm-ffi arch kernels with with the highest compatible CUDA version.
434461 3. Torch noarch kernels.
435- 4. Old Torch universal kernels.
462+ 4. tvm-ffi noarch kernels.
463+ 5. Old universal kernels.
436464 """
437465
438466 def sort_key (v : Variant ) -> tuple :
@@ -447,6 +475,7 @@ def sort_key(v: Variant) -> tuple:
447475 else :
448476 assert isinstance (v , NoarchVariant )
449477 universal_order = 1 if v .arch .backend_name == "universal" else 0
450- return (2 , universal_order )
478+ framework_order = 0 if isinstance (v .framework , TorchNoarch ) else 1
479+ return (2 , universal_order , framework_order )
451480
452481 return sorted (variants , key = sort_key )
0 commit comments