@@ -221,11 +221,29 @@ def visit_For(self, node):
221221
222222
def _normalize_ty(ty) -> str:
    """Return the canonical short spelling of a type annotation.

    ``ty`` may be a string, a ``triton.language.core.pointer_type``, a
    ``triton.language.core.dtype``, a Python class, or any other object
    (which is converted with ``str``).

    Strings are stripped of surrounding whitespace and then rewritten:

    * ``"const <T>"`` normalizes ``<T>``, which must come out as a pointer
      (``"*..."``), and replaces its leading ``"*"`` with ``"*k"``.
    * A trailing or a leading ``"*"`` becomes a leading ``"*"`` in front of
      the normalized remainder.
    * A leading ``"tl."`` is dropped.

    The resulting name, with a trailing ``"_t"`` removed, is looked up in
    ``type_canonicalisation_dict``; a name that is not in the table is
    returned unchanged.

    Raises:
        AssertionError: if a ``const`` annotation does not normalize to a
            pointer type.
    """
    import triton.language.core as core

    if isinstance(ty, str):
        ty = ty.strip()
        if ty.startswith("const "):
            pointee = _normalize_ty(ty.removeprefix("const"))
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; without it a non-pointer input
            # would be sliced into a malformed "*k..." string.
            if not pointee.startswith("*"):
                raise AssertionError(
                    f"const annotation must be a pointer type, got {ty!r}")
            return "*k" + pointee[1:]
        if ty.endswith("*"):
            return "*" + _normalize_ty(ty[:-1])
        if ty.startswith("*"):
            return "*" + _normalize_ty(ty[1:])
        if ty.startswith("tl."):
            return _normalize_ty(ty.removeprefix("tl."))
    elif isinstance(ty, core.pointer_type):
        return f"*{_normalize_ty(ty.element_ty)}"
    elif isinstance(ty, core.dtype):
        ty = ty.name
    elif isinstance(ty, type):
        ty = ty.__name__
    else:
        ty = str(ty)
    # Only a trailing "_t" (as in "int32_t") is dropped for the lookup; an
    # "_t" elsewhere in the name is left alone.
    return type_canonicalisation_dict.get(ty.removesuffix("_t"), ty)
229247
230248
231249class KernelParam :
@@ -250,13 +268,13 @@ def annotation(self):
250268
@cached_property
def annotation_type(self):
    """Return the annotation if it names a canonical type, else ``""``.

    A leading ``"*k"`` or ``"*"`` is ignored when checking the annotation
    against the values of ``type_canonicalisation_dict``; the annotation
    itself is returned with its prefix intact.
    """
    full = self.annotation
    if full.startswith("*k"):
        base = full[2:]
    else:
        base = full.removeprefix("*")
    canonical_names = set(type_canonicalisation_dict.values())
    return self.annotation if base in canonical_names else ""
261279
262280 @cached_property
@@ -265,7 +283,9 @@ def is_constexpr(self):
265283
266284 @cached_property
267285 def is_const (self ):
268- return "const" in self .annotation and not self .is_constexpr
286+ if self .is_constexpr :
287+ return False
288+ return "const" in self .annotation or self .annotation .startswith ("*k" )
269289
270290 @property
271291 def default (self ):
@@ -289,7 +309,7 @@ def specialize_impl(arg, is_const=False, specialize_value=True, align=True):
289309 if arg is None :
290310 return ("constexpr" , None )
291311 elif isinstance (arg , bool ):
292- return ("i1 " , None )
312+ return ("u1 " , None )
293313 elif isinstance (arg , int ):
294314 key = specialize_extra (arg , "int" , align = align ) if specialize_value else None
295315 if arg == 1 and specialize_value :
@@ -381,7 +401,15 @@ def create_function_from_signature(sig, kparams, backend):
381401 align = 'False' if kp .do_not_specialize_on_alignment else 'True'
382402 ret = f"specialize_impl({ name } , { is_const } , { specialize } , { align } )"
383403 if kp .annotation_type :
384- specialization .append (f'("{ kp .annotation_type } ",) + { ret } [1:]' )
404+ if isinstance (kp .annotation_type , str ):
405+ if kp .annotation_type == "u1" or kp .annotation_type [:2 ] in ["fp" , "bf" ]:
406+ # we do not specialize non-constexpr floats and bools:
407+ specialize = False
408+ if specialize :
409+ specialization .append (f'("{ kp .annotation_type } ",) + { ret } [1:]' )
410+ else :
411+ # skip runtime specialization:
412+ specialization .append (f'("{ kp .annotation_type } ", None)' )
385413 else :
386414 specialization .append (f"{ ret } " )
387415
@@ -412,7 +440,12 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options
412440
413441
414442type_canonicalisation_dict = {
415- "bool" : "i1" ,
443+ # we canonicalise all bools to be unsigned:
444+ "bool" : "u1" ,
445+ "int1" : "u1" ,
446+ "uint1" : "u1" ,
447+ "i1" : "u1" ,
448+ # floating-point dtypes:
416449 "float8e4nv" : "fp8e4nv" ,
417450 "float8e5" : "fp8e5" ,
418451 "float8e4b15" : "fp8e4b15" ,
@@ -422,14 +455,20 @@ def dynamic_func({", ".join(list(map(arg, sig.parameters.items())) + ["**options
422455 "float8_e5m2" : "fp8e5" ,
423456 "float8e5b16" : "fp8e5b16" ,
424457 "float8_e5m2fnuz" : "fp8e5b16" ,
458+ "half" : "fp16" ,
425459 "float16" : "fp16" ,
426460 "bfloat16" : "bf16" ,
461+ "float" : "fp32" ,
427462 "float32" : "fp32" ,
463+ "double" : "fp64" ,
428464 "float64" : "fp64" ,
465+ # signed integers:
429466 "int8" : "i8" ,
430467 "int16" : "i16" ,
468+ "int" : "i32" ,
431469 "int32" : "i32" ,
432470 "int64" : "i64" ,
471+ # unsigned integers:
433472 "uint8" : "u8" ,
434473 "uint16" : "u16" ,
435474 "uint32" : "u32" ,
0 commit comments