
Commit 54ed13c

Revert "Update low prec codegen for div/mod (pytorch#142350)"
This reverts commit ca97306. Reverted pytorch#142350 on behalf of https://github.com/huydhn due to: Sorry for reverting your change, but I think it breaks an internal test ([comment](pytorch#142350 (comment)))
Parent: e885225

3 files changed (+18, -48 lines)


test/inductor/test_op_dtype_prop.py

Lines changed: 0 additions & 12 deletions

```diff
@@ -212,18 +212,6 @@ def test_binary_math_mixed_precision(self):
         # There should be no downcast, since the input is promoted to float32.
         self.assertNotIn(".to(tl.float16)", code)
 
-    @config.patch("test_configs.runtime_triton_dtype_assert", True)
-    @config.patch("triton.codegen_upcast_to_fp32", False)
-    def test_downcast_div_mod(self):
-        def fn(x, y):
-            return x % y, x / y
-
-        x, y = (torch.rand([8], dtype=torch.float16, device="cuda") for _ in range(2))
-
-        out, code = run_and_get_code(torch.compile(fn), x, y)
-        FileCheck().check("static_assert").check_same(".dtype").run(code[0])
-        self.assertEqual(fn(x, y), out)
-
     @config.patch("test_configs.runtime_triton_dtype_assert", True)
     def test_constant(self):
        def fn():
```
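For context, the removed test_downcast_div_mod compares the compiled kernel against eager mode on float16 inputs. In eager PyTorch, both `x % y` and `x / y` stay in float16 when both inputs are float16; a minimal standalone check (illustration only, CPU instead of the test's CUDA device, not part of this commit):

```python
# Illustration only (not part of this commit): the eager-mode dtype behavior
# that the removed test compared torch.compile output against.
import torch

x = torch.rand(8, dtype=torch.float16)
y = torch.rand(8, dtype=torch.float16) + 0.5  # keep the divisor away from zero

print((x % y).dtype)  # torch.float16
print((x / y).dtype)  # torch.float16
```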

test/inductor/test_pattern_matcher.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -175,7 +175,6 @@ def fn2(a, b, c):
     @skipIfXpu
     @skipCUDAIf(not SM80OrLater, "need sm_80")
     @inductor_config.patch(force_fuse_int_mm_with_mul=True)
-    @inductor_config.patch("test_configs.runtime_triton_dtype_assert", True)
     def test_fused_int_mm_mul_epilogue(self):
         def fn1(a, b, c):
             return (
```

torch/_inductor/codegen/triton.py

Lines changed: 18 additions & 35 deletions

```diff
@@ -104,8 +104,6 @@
 if TYPE_CHECKING:
     from types import ModuleType
 
-    from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
-
     from ..ir import IRNode
 
 log = logging.getLogger(__name__)
@@ -743,12 +741,6 @@ def update_on_args(self, name, args, kwargs):
                 break
 
 
-def get_dtype_handler() -> DtypePropagationOpsHandler:
-    from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
-
-    return DtypePropagationOpsHandler()
-
-
 def maybe_upcast_float32(convert_output: bool = True):
     """
     Codegen helper to upcast arguments to float32, depending on the config and dtype.
@@ -775,17 +767,27 @@ def wrapped(*args, **kwargs) -> str:
             upcast_args = [maybe_upcast_arg(arg) for arg in args]
             upcast_kwargs = {key: maybe_upcast_arg(val) for key, val in kwargs.items()}
 
+            # Infer the output dtype from the inputs.
+            # This promotes to the largest input type.
+            all_args = args + tuple(kwargs.values())
+            input_dtypes = [
+                var.dtype
+                for var in all_args
+                if isinstance(var, CSEVariable) and var.dtype is not None
+            ]
+            result_dtype = (
+                functools.reduce(torch.promote_types, input_dtypes)
+                if len(input_dtypes) > 0
+                else None
+            )
+
             # Call the decorated function, optionally downcasting the result.
             result = func(*upcast_args, **upcast_kwargs)
-            any_needs_upcast = convert_output and any(
-                needs_upcast(var) for var in itertools.chain(args, kwargs.values())
+            needs_downcast = (
+                convert_output
+                and any(needs_upcast(var) for var in all_args)
+                and result_dtype not in (torch.float32, None)
             )
-            result_dtype = (
-                None
-                if not any_needs_upcast
-                else getattr(get_dtype_handler(), func.__name__)(*args, **kwargs)
-            )
-            needs_downcast = result_dtype not in (torch.float32, None)
             downcast_string = (
                 f".to({triton_type(result_dtype)})"
                 if needs_downcast and result_dtype is not None
@@ -908,25 +910,6 @@ def constant(cls, value, dtype):
     def abs(x):
         return f"tl_math.abs({x})"
 
-    # TODO - register these ops as having divergent dtype
-    # output if doing graph pass to remove consecutive casts
-
-    @staticmethod
-    def truediv(x, y):
-        out = f"({x} / {y})"
-        out_dtype = get_dtype_handler().truediv(x, y)
-        if out_dtype in (torch.float16, torch.float32):
-            out = f"{out}.to({triton_type(out_dtype)})"
-        return out
-
-    @staticmethod
-    def mod(x, y):
-        out = f"({x} % {y})"
-        out_dtype = get_dtype_handler().mod(x, y)
-        if out_dtype in (torch.float16, torch.float32):
-            out = f"{out}.to({triton_type(out_dtype)})"
-        return out
-
     @staticmethod
     @maybe_upcast_float32()
     def libdevice_abs(x):
```
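With this revert, the restored `maybe_upcast_float32` wrapper infers the downcast target by reducing the input dtypes with `torch.promote_types`, instead of querying a per-op `DtypePropagationOpsHandler` as the reverted change did. A small sketch of how that reduction resolves (the `promoted` helper is made up here; it simply mirrors the `functools.reduce` expression in the diff):

```python
# Illustration only: how the restored result_dtype inference resolves dtypes.
# `promoted` is a hypothetical helper mirroring the functools.reduce call above.
import functools

import torch


def promoted(input_dtypes):
    return (
        functools.reduce(torch.promote_types, input_dtypes)
        if len(input_dtypes) > 0
        else None
    )


print(promoted([torch.float16, torch.float16]))  # torch.float16 -> a downcast is emitted
print(promoted([torch.float16, torch.float32]))  # torch.float32 -> no downcast
print(promoted([]))                              # None -> no downcast
```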
