 if TYPE_CHECKING:
     from types import ModuleType
 
-    from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
-
     from ..ir import IRNode
 
 log = logging.getLogger(__name__)
@@ -743,12 +741,6 @@ def update_on_args(self, name, args, kwargs):
                 break
 
 
-def get_dtype_handler() -> DtypePropagationOpsHandler:
-    from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
-
-    return DtypePropagationOpsHandler()
-
-
 def maybe_upcast_float32(convert_output: bool = True):
     """
     Codegen helper to upcast arguments to float32, depending on the config and dtype.
@@ -775,17 +767,27 @@ def wrapped(*args, **kwargs) -> str:
             upcast_args = [maybe_upcast_arg(arg) for arg in args]
             upcast_kwargs = {key: maybe_upcast_arg(val) for key, val in kwargs.items()}
 
+            # Infer the output dtype from the inputs.
+            # This promotes to the largest input type.
+            all_args = args + tuple(kwargs.values())
+            input_dtypes = [
+                var.dtype
+                for var in all_args
+                if isinstance(var, CSEVariable) and var.dtype is not None
+            ]
+            result_dtype = (
+                functools.reduce(torch.promote_types, input_dtypes)
+                if len(input_dtypes) > 0
+                else None
+            )
+
             # Call the decorated function, optionally downcasting the result.
             result = func(*upcast_args, **upcast_kwargs)
-            any_needs_upcast = convert_output and any(
-                needs_upcast(var) for var in itertools.chain(args, kwargs.values())
+            needs_downcast = (
+                convert_output
+                and any(needs_upcast(var) for var in all_args)
+                and result_dtype not in (torch.float32, None)
             )
-            result_dtype = (
-                None
-                if not any_needs_upcast
-                else getattr(get_dtype_handler(), func.__name__)(*args, **kwargs)
-            )
-            needs_downcast = result_dtype not in (torch.float32, None)
             downcast_string = (
                 f".to({triton_type(result_dtype)})"
                 if needs_downcast and result_dtype is not None
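
Note: result_dtype is now inferred directly from the input CSEVariable dtypes by folding torch.promote_types over them, instead of querying DtypePropagationOpsHandler. A minimal standalone sketch of the promotion behavior this relies on (plain PyTorch, not Inductor code):

    import functools
    import torch

    # Mixed half-precision inputs promote to float32, so no downcast
    # suffix is emitted for that case.
    print(functools.reduce(torch.promote_types, [torch.float16, torch.bfloat16]))
    # torch.float32

    # Uniform float16 inputs promote to float16, so the float32 result of
    # the wrapped op gets a ".to(...)" downcast appended (assuming
    # triton_type(torch.float16) renders as "tl.float16").
    print(functools.reduce(torch.promote_types, [torch.float16, torch.float16]))
    # torch.float16
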
@@ -908,25 +910,6 @@ def constant(cls, value, dtype):
     def abs(x):
         return f"tl_math.abs({x})"
 
-    # TODO - register these ops as having divergent dtype
-    # output if doing graph pass to remove consecutive casts
-
-    @staticmethod
-    def truediv(x, y):
-        out = f"({x} / {y})"
-        out_dtype = get_dtype_handler().truediv(x, y)
-        if out_dtype in (torch.float16, torch.float32):
-            out = f"{out}.to({triton_type(out_dtype)})"
-        return out
-
-    @staticmethod
-    def mod(x, y):
-        out = f"({x} % {y})"
-        out_dtype = get_dtype_handler().mod(x, y)
-        if out_dtype in (torch.float16, torch.float32):
-            out = f"{out}.to({triton_type(out_dtype)})"
-        return out
-
     @staticmethod
     @maybe_upcast_float32()
     def libdevice_abs(x):
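
With the handler-based truediv/mod overrides removed, any needed downcast is handled uniformly by the maybe_upcast_float32 decorator, as with libdevice_abs above. A rough sketch of the pattern (libdevice_sin is only illustrative here, not part of this commit):

    @staticmethod
    @maybe_upcast_float32()
    def libdevice_sin(x):
        # The decorator upcasts half-precision args to float32 before this
        # runs, then appends a ".to(...)" cast to the emitted string when
        # the promoted input dtype is narrower than float32.
        return f"libdevice.sin({x})"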