Skip to content

Commit cb78503

Browse files
authored
Annotations for elementary dtypes and pointers (#6152)
This introduces thorough support for annotations of elementary dtypes and pointers. Under certain circumstances, annotated kernel arguments can skip `specialize_impl` entirely (for example, non-constexpr arguments annotated as bools or floats, which are never value-specialised), leading to faster launch times. This is inspired by the patch provided by @saagarjha in the comments on #6064, but we go further by resolving the logic statically -- when constructing `dynamic_func` -- rather than at runtime. We also take this opportunity to do some general code simplification and cleanup, including consistently using "u1" rather than "i1" to denote booleans in the specialisation key.
1 parent f8d5d1e commit cb78503

File tree

1 file changed

+55
-16
lines changed

1 file changed

+55
-16
lines changed

python/triton/runtime/jit.py

Lines changed: 55 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -221,11 +221,29 @@ def visit_For(self, node):
221221

222222

223223
def _normalize_ty(ty) -> str:
224-
if isinstance(ty, type):
225-
return ty.__name__
226-
elif isinstance(ty, str):
227-
return ty
228-
return repr(ty)
224+
import triton.language.core as core
225+
if isinstance(ty, str):
226+
ty = ty.strip()
227+
if ty.startswith("const "):
228+
ty = ty.removeprefix("const")
229+
ty = _normalize_ty(ty)
230+
assert ty.startswith("*")
231+
return "*k" + ty[1:]
232+
if ty.endswith("*"):
233+
return "*" + _normalize_ty(ty[:-1])
234+
if ty.startswith("*"):
235+
return "*" + _normalize_ty(ty[1:])
236+
if ty.startswith("tl."):
237+
return _normalize_ty(ty.removeprefix("tl."))
238+
elif isinstance(ty, core.pointer_type):
239+
return f"*{_normalize_ty(ty.element_ty)}"
240+
elif isinstance(ty, core.dtype):
241+
ty = ty.name
242+
elif isinstance(ty, type):
243+
ty = ty.__name__
244+
else:
245+
ty = str(ty)
246+
return type_canonicalisation_dict.get(ty.replace("_t", ""), ty)
229247

230248

231249
class KernelParam:
@@ -250,13 +268,13 @@ def annotation(self):
250268

251269
@cached_property
252270
def annotation_type(self):
253-
annotation = self.annotation
254-
for ty1, ty2 in [("uint", 'u'), ("int", 'i')]:
255-
width = annotation[annotation.find(ty1) + len(ty1):]
256-
if width and ty1 in annotation:
257-
return f"{ty2}{width}"
258-
if annotation == "bool":
259-
return "u1"
271+
a = self.annotation
272+
if a.startswith("*k"):
273+
a = a[2:]
274+
elif a.startswith("*"):
275+
a = a[1:]
276+
if a in set(type_canonicalisation_dict.values()):
277+
return self.annotation
260278
return ""
261279

262280
@cached_property
@@ -265,7 +283,9 @@ def is_constexpr(self):
265283

266284
@cached_property
267285
def is_const(self):
268-
return "const" in self.annotation and not self.is_constexpr
286+
if self.is_constexpr:
287+
return False
288+
return "const" in self.annotation or self.annotation.startswith("*k")
269289

270290
@property
271291
def default(self):
@@ -289,7 +309,7 @@ def specialize_impl(arg, is_const=False, specialize_value=True, align=True):
289309
if arg is None:
290310
return ("constexpr", None)
291311
elif isinstance(arg, bool):
292-
return ("i1", None)
312+
return ("u1", None)
293313
elif isinstance(arg, int):
294314
key = specialize_extra(arg, "int", align=align) if specialize_value else None
295315
if arg == 1 and specialize_value:
@@ -381,7 +401,15 @@ def create_function_from_signature(sig, kparams, backend):
381401
align = 'False' if kp.do_not_specialize_on_alignment else 'True'
382402
ret = f"specialize_impl({name}, {is_const}, {specialize}, {align})"
383403
if kp.annotation_type:
384-
specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]')
404+
if isinstance(kp.annotation_type, str):
405+
if kp.annotation_type == "u1" or kp.annotation_type[:2] in ["fp", "bf"]:
406+
# we do not specialize non-constexpr floats and bools:
407+
specialize = False
408+
if specialize:
409+
specialization.append(f'("{kp.annotation_type}",) + {ret}[1:]')
410+
else:
411+
# skip runtime specialization:
412+
specialization.append(f'("{kp.annotation_type}", None)')
385413
else:
386414
specialization.append(f"{ret}")
387415

@@ -412,7 +440,12 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options
412440

413441

414442
type_canonicalisation_dict = {
415-
"bool": "i1",
443+
# we canonicalise all bools to be unsigned:
444+
"bool": "u1",
445+
"int1": "u1",
446+
"uint1": "u1",
447+
"i1": "u1",
448+
# floating-point dtypes:
416449
"float8e4nv": "fp8e4nv",
417450
"float8e5": "fp8e5",
418451
"float8e4b15": "fp8e4b15",
@@ -422,14 +455,20 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options
422455
"float8_e5m2": "fp8e5",
423456
"float8e5b16": "fp8e5b16",
424457
"float8_e5m2fnuz": "fp8e5b16",
458+
"half": "fp16",
425459
"float16": "fp16",
426460
"bfloat16": "bf16",
461+
"float": "fp32",
427462
"float32": "fp32",
463+
"double": "fp64",
428464
"float64": "fp64",
465+
# signed integers:
429466
"int8": "i8",
430467
"int16": "i16",
468+
"int": "i32",
431469
"int32": "i32",
432470
"int64": "i64",
471+
# unsigned integers:
433472
"uint8": "u8",
434473
"uint16": "u16",
435474
"uint32": "u32",

0 commit comments

Comments
 (0)