 if TYPE_CHECKING:
     from types import ModuleType
 
+    from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
+
     from ..ir import IRNode
 
 log = logging.getLogger(__name__)
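
The annotation-only import under TYPE_CHECKING pairs with a deferred runtime import inside the get_dtype_handler() helper added below. A minimal sketch of that pattern, using the stdlib decimal module as a stand-in for the real inductor module:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Annotation-only import: never executed at runtime, so it cannot
        # participate in an import cycle.
        from decimal import Decimal

    def get_decimal() -> Decimal:
        # Runtime import deferred until first call, mirroring get_dtype_handler().
        from decimal import Decimal

        return Decimal(0)

    print(get_decimal())  # 0
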
@@ -741,6 +743,12 @@ def update_on_args(self, name, args, kwargs):
                 break
 
 
+def get_dtype_handler() -> DtypePropagationOpsHandler:
+    # Imported at call time; the module-level import above is TYPE_CHECKING-only,
+    # which avoids a circular import at module load.
+    from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
+
+    return DtypePropagationOpsHandler()
+
+
 def maybe_upcast_float32(convert_output: bool = True):
     """
     Codegen helper to upcast arguments to float32, depending on the config and dtype.
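
For context, a self-contained sketch of the upcast/downcast codegen pattern this decorator implements. The CSE machinery, config checks, and triton_type are stubbed out, and every name below is illustrative rather than real inductor API:

    import functools

    def upcast_float32(func):
        # Wrap a string-codegen op: upcast fp16 inputs to fp32, run the op,
        # then cast the emitted expression back down.
        @functools.wraps(func)
        def wrapped(*args: str) -> str:
            upcast = [f"{a}.to(tl.float32)" for a in args]
            return func(*upcast) + ".to(tl.float16)"

        return wrapped

    @upcast_float32
    def sqrt(x: str) -> str:
        return f"libdevice.sqrt({x})"

    print(sqrt("tmp0"))  # libdevice.sqrt(tmp0.to(tl.float32)).to(tl.float16)
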
@@ -767,27 +775,17 @@ def wrapped(*args, **kwargs) -> str:
             upcast_args = [maybe_upcast_arg(arg) for arg in args]
             upcast_kwargs = {key: maybe_upcast_arg(val) for key, val in kwargs.items()}
 
-            # Infer the output dtype from the inputs.
-            # This promotes to the largest input type.
-            all_args = args + tuple(kwargs.values())
-            input_dtypes = [
-                var.dtype
-                for var in all_args
-                if isinstance(var, CSEVariable) and var.dtype is not None
-            ]
-            result_dtype = (
-                functools.reduce(torch.promote_types, input_dtypes)
-                if len(input_dtypes) > 0
-                else None
-            )
-
             # Call the decorated function, optionally downcasting the result.
             result = func(*upcast_args, **upcast_kwargs)
-            needs_downcast = (
-                convert_output
-                and any(needs_upcast(var) for var in all_args)
-                and result_dtype not in (torch.float32, None)
+            any_needs_upcast = convert_output and any(
+                needs_upcast(var) for var in itertools.chain(args, kwargs.values())
             )
+            # Only consult the dtype propagation handler when an input was
+            # actually upcast; otherwise no downcast can be needed.
+            result_dtype = (
+                None
+                if not any_needs_upcast
+                else getattr(get_dtype_handler(), func.__name__)(*args, **kwargs)
+            )
+            needs_downcast = result_dtype not in (torch.float32, None)
             downcast_string = (
                 f".to({triton_type(result_dtype)})"
                 if needs_downcast and result_dtype is not None
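
The motivation for swapping the promote_types reduction for a per-op dtype propagation handler can be seen with plain torch semantics: promoting input dtypes does not always match an op's true output dtype. A small runnable check:

    import torch

    # Old approach: reduce input dtypes with promote_types. int32, int32 -> int32.
    print(torch.promote_types(torch.int32, torch.int32))  # torch.int32

    # Actual semantics of true division diverge: int32 / int32 -> float32.
    a = torch.ones(2, dtype=torch.int32)
    print((a / a).dtype)  # torch.float32
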
@@ -910,6 +908,25 @@ def constant(cls, value, dtype):
     def abs(x):
         return f"tl_math.abs({x})"
 
+    # TODO - register these ops as having divergent dtype
+    # output if doing graph pass to remove consecutive casts
+
+    @staticmethod
+    def truediv(x, y):
+        out = f"({x} / {y})"
+        # Ask the dtype propagation handler what torch says the result dtype
+        # should be, and emit an explicit cast when it is a float type whose
+        # Triton result could diverge.
+        out_dtype = get_dtype_handler().truediv(x, y)
+        if out_dtype in (torch.float16, torch.float32):
+            out = f"{out}.to({triton_type(out_dtype)})"
+        return out
+
+    @staticmethod
+    def mod(x, y):
+        out = f"({x} % {y})"
+        out_dtype = get_dtype_handler().mod(x, y)
+        if out_dtype in (torch.float16, torch.float32):
+            out = f"{out}.to({triton_type(out_dtype)})"
+        return out
+
     @staticmethod
     @maybe_upcast_float32()
     def libdevice_abs(x):
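
As a closing, runnable illustration of the dtype behavior the truediv/mod overrides track, using eager torch as the reference for what the propagation handler is expected to report:

    import torch

    x = torch.ones(2, dtype=torch.float16)
    print((x / x).dtype)  # torch.float16 -> codegen appends .to(tl.float16)
    print((x % x).dtype)  # torch.float16
    print((torch.ones(2, dtype=torch.int32) / 2).dtype)  # torch.float32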