
Commit 135a2d4

eellison authored and pytorchmergebot committed
Update low prec codegen for div/mod (pytorch#142350)
Div/mod in fp16/bf16 requires a downcast to preserve its inputs' dtypes.

Pull Request resolved: pytorch#142350
Approved by: https://github.com/blaine-rister
1 parent 15aee8e commit 135a2d4
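
For context on the change: with `triton.codegen_upcast_to_fp32` disabled, the result of `/` and `%` on fp16/bf16 operands would otherwise come back in a wider dtype, so Inductor now casts it back down. A minimal sketch of the behavior the new test below exercises (an illustration, not part of the commit; assumes a CUDA device):

import torch
from torch._inductor import config as inductor_config

def div_mod(x, y):
    # Eager mode keeps fp16 for both % and /; the compiled kernel must match.
    return x % y, x / y

x = torch.rand(8, dtype=torch.float16, device="cuda")
y = torch.rand(8, dtype=torch.float16, device="cuda")

# Disable the usual upcast-everything-to-fp32 codegen path so the kernel
# really operates on fp16 values.
with inductor_config.patch("triton.codegen_upcast_to_fp32", False):
    out_mod, out_div = torch.compile(div_mod)(x, y)

# With the downcast added by this commit, the compiled outputs stay fp16.
assert out_mod.dtype == torch.float16 and out_div.dtype == torch.float16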

File tree

3 files changed: +48 −18 lines changed

    test/inductor/test_op_dtype_prop.py
    test/inductor/test_pattern_matcher.py
    torch/_inductor/codegen/triton.py

test/inductor/test_op_dtype_prop.py
Lines changed: 12 additions & 0 deletions

@@ -212,6 +212,18 @@ def test_binary_math_mixed_precision(self):
         # There should be no downcast, since the input is promoted to float32.
         self.assertNotIn(".to(tl.float16)", code)

+    @config.patch("test_configs.runtime_triton_dtype_assert", True)
+    @config.patch("triton.codegen_upcast_to_fp32", False)
+    def test_downcast_div_mod(self):
+        def fn(x, y):
+            return x % y, x / y
+
+        x, y = (torch.rand([8], dtype=torch.float16, device="cuda") for _ in range(2))
+
+        out, code = run_and_get_code(torch.compile(fn), x, y)
+        FileCheck().check("static_assert").check_same(".dtype").run(code[0])
+        self.assertEqual(fn(x, y), out)
+
     @config.patch("test_configs.runtime_triton_dtype_assert", True)
     def test_constant(self):
         def fn():
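
For anyone reproducing the new test by hand, here is a small sketch (mirroring the test above rather than any official recipe) that prints the generated Triton source so the `.to(tl.float16)` downcast and the dtype `static_assert` that FileCheck looks for can be inspected directly:

import torch
from torch._inductor import config as inductor_config
from torch._inductor.utils import run_and_get_code

def fn(x, y):
    return x % y, x / y

x, y = (torch.rand([8], dtype=torch.float16, device="cuda") for _ in range(2))

with inductor_config.patch(
    {
        "test_configs.runtime_triton_dtype_assert": True,
        "triton.codegen_upcast_to_fp32": False,
    }
):
    out, code = run_and_get_code(torch.compile(fn), x, y)

# code[0] is the generated kernel source; it should contain the downcasts
# emitted by this commit plus static_assert checks on the result dtypes.
print(code[0])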

test/inductor/test_pattern_matcher.py
Lines changed: 1 addition & 0 deletions

@@ -175,6 +175,7 @@ def fn2(a, b, c):
     @skipIfXpu
     @skipCUDAIf(not SM80OrLater, "need sm_80")
     @inductor_config.patch(force_fuse_int_mm_with_mul=True)
+    @inductor_config.patch("test_configs.runtime_triton_dtype_assert", True)
     def test_fused_int_mm_mul_epilogue(self):
         def fn1(a, b, c):
             return (

torch/_inductor/codegen/triton.py
Lines changed: 35 additions & 18 deletions

@@ -104,6 +104,8 @@
 if TYPE_CHECKING:
     from types import ModuleType

+    from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
+
     from ..ir import IRNode

 log = logging.getLogger(__name__)
@@ -741,6 +743,12 @@ def update_on_args(self, name, args, kwargs):
             break


+def get_dtype_handler() -> DtypePropagationOpsHandler:
+    from torch._inductor.dtype_propagation import DtypePropagationOpsHandler
+
+    return DtypePropagationOpsHandler()
+
+
 def maybe_upcast_float32(convert_output: bool = True):
     """
     Codegen helper to upcast arguments to float32, depending on the config and dtype.
@@ -767,27 +775,17 @@ def wrapped(*args, **kwargs) -> str:
             upcast_args = [maybe_upcast_arg(arg) for arg in args]
             upcast_kwargs = {key: maybe_upcast_arg(val) for key, val in kwargs.items()}

-            # Infer the output dtype from the inputs.
-            # This promotes to the largest input type.
-            all_args = args + tuple(kwargs.values())
-            input_dtypes = [
-                var.dtype
-                for var in all_args
-                if isinstance(var, CSEVariable) and var.dtype is not None
-            ]
-            result_dtype = (
-                functools.reduce(torch.promote_types, input_dtypes)
-                if len(input_dtypes) > 0
-                else None
-            )
-
             # Call the decorated function, optionally downcasting the result.
             result = func(*upcast_args, **upcast_kwargs)
-            needs_downcast = (
-                convert_output
-                and any(needs_upcast(var) for var in all_args)
-                and result_dtype not in (torch.float32, None)
+            any_needs_upcast = convert_output and any(
+                needs_upcast(var) for var in itertools.chain(args, kwargs.values())
             )
+            result_dtype = (
+                None
+                if not any_needs_upcast
+                else getattr(get_dtype_handler(), func.__name__)(*args, **kwargs)
+            )
+            needs_downcast = result_dtype not in (torch.float32, None)
             downcast_string = (
                 f".to({triton_type(result_dtype)})"
                 if needs_downcast and result_dtype is not None
@@ -910,6 +908,25 @@ def constant(cls, value, dtype):
     def abs(x):
         return f"tl_math.abs({x})"

+    # TODO - register these ops as having divergent dtype
+    # output if doing graph pass to remove consecutive casts
+
+    @staticmethod
+    def truediv(x, y):
+        out = f"({x} / {y})"
+        out_dtype = get_dtype_handler().truediv(x, y)
+        if out_dtype in (torch.float16, torch.bfloat16):
+            out = f"{out}.to({triton_type(out_dtype)})"
+        return out
+
+    @staticmethod
+    def mod(x, y):
+        out = f"({x} % {y})"
+        out_dtype = get_dtype_handler().mod(x, y)
+        if out_dtype in (torch.float16, torch.bfloat16):
+            out = f"{out}.to({triton_type(out_dtype)})"
+        return out
+
     @staticmethod
     @maybe_upcast_float32()
     def libdevice_abs(x):
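
To make the new truediv/mod overrides concrete, here is a standalone sketch of the string they produce; the helper below is a hypothetical stand-in for Inductor's internal triton_type() mapping and dtype propagation, not the real API:

import torch

# Hypothetical stand-in for triton_type(): map a low-precision torch dtype to
# its Triton spelling; anything else means "no downcast needed".
_TL_NAME = {torch.float16: "tl.float16", torch.bfloat16: "tl.bfloat16"}

def sketch_truediv(x_expr: str, y_expr: str, out_dtype: torch.dtype) -> str:
    # Mirrors the new override: emit the division, then append a .to(...)
    # downcast only when dtype propagation reports an fp16/bf16 output.
    out = f"({x_expr} / {y_expr})"
    tl_name = _TL_NAME.get(out_dtype)
    if tl_name is not None:
        out = f"{out}.to({tl_name})"
    return out

print(sketch_truediv("tmp0", "tmp1", torch.float16))  # (tmp0 / tmp1).to(tl.float16)
print(sketch_truediv("tmp0", "tmp1", torch.float32))  # (tmp0 / tmp1)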
