
Commit 9a42f01

mlazos authored and pytorchmergebot committed
[Cutlass] EVT dynamic shapes support (pytorch#154835)
Pull Request resolved: pytorch#154835
Approved by: https://github.com/henrylhtsang
ghstack dependencies: pytorch#154829
1 parent 5911f87 commit 9a42f01

5 files changed: +40 -18 lines

test/inductor/test_cutlass_backend.py

Lines changed: 20 additions & 11 deletions
@@ -1660,7 +1660,10 @@ def forward(self, a, b, extra_args):
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @use_evt_config
     @evt_all_ops
-    def test_evt_multi_output(self, op):
+    @parametrize(
+        "dynamic", (False, True)
+    )  # To not drastically increase test time we only test dynamic on this test
+    def test_evt_multi_output(self, op, dynamic):
         class TestModel(torch.nn.Module):
             def forward(self, a, b, extra_args):
                 acc = a @ b
@@ -1671,18 +1674,24 @@ def forward(self, a, b, extra_args):

         M = 1024
         N = 512
-        a = torch.ones(M, N).cuda().half()
-        b = torch.ones(N, N).cuda().half()
-        extra_args = gen_args(op, (M, N))
-        model = TestModel().cuda()
+        shapes = [(512, 512)] if not dynamic else [(1024, 64), (128, 256)]
+        for i, shape in enumerate(shapes):
+            M, N = shape
+            a = torch.ones(M, N).cuda().half()
+            b = torch.ones(N, N).cuda().half()
+            extra_args = gen_args(op, (M, N))
+            model = TestModel().cuda()

-        result = torch.compile(model)(a, b, extra_args)
-        ref_result = model(a, b, extra_args)
+            result = torch.compile(model)(a, b, extra_args)
+            ref_result = model(a, b, extra_args)

-        self.assertEqual(
-            torch._dynamo.utils.counters["inductor"]["cuda_epilogue_fusion_counter"], 2
-        )
-        torch.testing.assert_close(result, ref_result)
+            self.assertEqual(
+                torch._dynamo.utils.counters["inductor"][
+                    "cuda_epilogue_fusion_counter"
+                ],
+                2 * (i + 1),
+            )
+            torch.testing.assert_close(result, ref_result)

     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @use_evt_config
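Note on the test change above: torch._dynamo.utils.counters accumulates across the loop iterations, which is why the expected epilogue-fusion count becomes 2 * (i + 1), and compiling the same model for a second, differently shaped input is what exercises the dynamic-shape path. A minimal, self-contained sketch of that recompilation pattern (illustrative only, not part of the diff; the function names are made up and a CUDA device is assumed):

import torch

def f(a, b):
    # stand-in for TestModel: a matmul followed by a simple epilogue
    return torch.relu(a @ b)

compiled = torch.compile(f)
for m, n in [(1024, 64), (128, 256)]:
    a = torch.ones(m, n, device="cuda", dtype=torch.half)
    b = torch.ones(n, n, device="cuda", dtype=torch.half)
    # by default, the second call with different shapes recompiles with dynamic dims
    torch.testing.assert_close(compiled(a, b), f(a, b))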

test/inductor/test_cutlass_evt.py

Lines changed: 7 additions & 2 deletions
@@ -372,7 +372,9 @@ def test_evt_argument_codegen(self):

         self.assertExpectedInline(
             _render_argument_type(
-                epilogue_functor, _create_mock_buffer_name_map(EXAMPLE_TENSORS)
+                epilogue_functor,
+                _create_mock_buffer_name_map(EXAMPLE_TENSORS),
+                lambda x: int(x),
             ),
             """\
 { /* thread */
@@ -427,7 +429,9 @@ def fn(accum, bias):

         self.assertExpectedInline(
             _render_argument_type(
-                epilogue_functor, _create_mock_buffer_name_map(example_tensors)
+                epilogue_functor,
+                _create_mock_buffer_name_map(example_tensors),
+                lambda x: int(x),
             ),
             """\
 { /* thread */
@@ -452,6 +456,7 @@ def test_evt_codegen(self):
             MockTileDescription(),
             EpilogueScheduleType.ScheduleAuto,
             _create_mock_buffer_name_map(EXAMPLE_TENSORS),
+            lambda x: x,  # static shapes
         )
         self.assertExpectedInline(
             code,
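Note on the test updates above: the mock tensors in these expecttest cases carry concrete sizes, so a plain lambda x: int(x) (or the identity lambda x: x in the static-shape codegen test) is enough to serve as the newly required size-hint argument; the callable used during real compilation is wired up in gemm_template.py below.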

torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py

Lines changed: 11 additions & 4 deletions
@@ -101,6 +101,7 @@ def trace(
     tile_description: TileDescription,
     epilogue_schedule: EpilogueScheduleType,
     name_to_buffer: dict[str, Buffer],
+    size_hint_fn: Callable[[Union[Expr, int]], int],
     **kwargs: dict[str, Any],
 ) -> tuple[str, str, str]:
     cuda_arch = int(cuda_env.get_cuda_arch())  # type: ignore[arg-type]
@@ -116,7 +117,7 @@ def trace(
         fusion_callbacks,
     )
     evt_name, evt_code = collective_epilogue.emit()
-    evt_args = _render_argument_type(epilogue_functor, name_to_buffer)
+    evt_args = _render_argument_type(epilogue_functor, name_to_buffer, size_hint_fn)
     return evt_name, evt_args, evt_code

 # Based off of
@@ -144,6 +145,7 @@ def parse(self, example_inputs: dict[str, CutlassTensor]) -> None:
 def _render_argument_type(
     epilogue_functor: EpilogueFunctor,
     name_to_buffer: dict[str, Buffer],
+    size_hint_fn: Callable[[Union[Expr, int]], int],
 ) -> str:
     epilogue_thread_type = epilogue_functor.epilogue_thread_type

@@ -162,7 +164,10 @@ def render_argument_type(name: str, t: CutlassArgType) -> None:
             buffer.writeline(f"{{}}, /* {name} */")
         else:
             fields = [
-                (fname, _get_arg_from_node(ty, name_to_buffer[name]))
+                (
+                    fname,
+                    _get_arg_from_node(ty, name_to_buffer[name], size_hint_fn),
+                )
                 for fname, ty in t._fields_
             ]
             field_strs = [
@@ -194,7 +199,9 @@ def render_thread_type(name: str, t: CutlassArgType) -> None:

     return buffer.getvalue()

-def _get_arg_from_node(arg_ty: type, node: Buffer) -> str:
+def _get_arg_from_node(
+    arg_ty: type, node: Buffer, size_hint_fn: Callable[[Union[Expr, int]], int]
+) -> str:
     from ..cuda_template import CUTLASSTemplate

     # Today, arguments are either a pointer to the
@@ -206,7 +213,7 @@ def _get_arg_from_node(arg_ty: type, node: Buffer) -> str:
     ):
         DEFAULT_STRIDE_LEN = 3
         assert len(node.get_layout().stride) <= DEFAULT_STRIDE_LEN
-        stride = [int(x) for x in node.get_layout().stride]
+        stride = [size_hint_fn(x) for x in node.get_layout().stride]
         for _ in range(DEFAULT_STRIDE_LEN - len(stride)):
             stride.append(0)
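Note on the change above: strides are no longer forced through int(...), which raises for symbolic strides under dynamic shapes; they now go through the injected size_hint_fn. A rough sketch of the contract that callable is expected to satisfy, using sympy directly and a made-up hint table (hypothetical helper, not the Inductor implementation; in production the callable is V.graph.sizevars.size_hint, see gemm_template.py below):

from typing import Union

import sympy

s0 = sympy.Symbol("s0")  # a symbolic dimension, e.g. a dynamic batch size
hints = {s0: 1024}       # hypothetical example value for the symbol

def size_hint_fn(expr: Union[sympy.Expr, int]) -> int:
    # Resolve a possibly-symbolic size/stride to a concrete int.
    if isinstance(expr, (int, sympy.Integer)):
        return int(expr)
    return int(sympy.sympify(expr).subs(hints))

# strides may mix symbolic and concrete values, as on a Buffer layout
print([size_hint_fn(x) for x in (s0 * 2, 1, 0)])  # -> [2048, 1, 0]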

torch/_inductor/codegen/cuda/gemm_template.py

Lines changed: 1 addition & 0 deletions
@@ -1400,6 +1400,7 @@ def _render_evt(
             op.tile_description,  # type: ignore[attr-defined]
             op.epilogue_schedule,  # type: ignore[attr-defined]
             {k: name_to_buffer[v] for k, v in var_name_to_buffer_name.items()},  # type: ignore[arg-type,misc]
+            V.graph.sizevars.size_hint,
         )

         return (
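Here V.graph.sizevars.size_hint is passed as the new size_hint_fn, so EVT argument rendering resolves symbolic strides to their size hints instead of coercing them with int(...). For static shapes the behavior is unchanged, since the hint of a concrete size is the size itself.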

torch/_inductor/sizevars.py

Lines changed: 1 addition & 1 deletion
@@ -574,7 +574,7 @@ def size_hint(

     def size_hints(
         self,
-        exprs: Iterable[Expr],
+        exprs: Iterable[Union[Expr, int]],
         *,
         fallback: Optional[int] = None,
     ) -> tuple[int, ...]:
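The plural size_hints helper is widened to accept plain ints alongside sympy Exprs, presumably to match the singular size_hint and the way layout sizes and strides mix concrete ints with symbolic expressions.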
