Commit 5911f87

mlazos authored and pytorchmergebot committed
[Cutlass] fp8 dynamic shapes test (pytorch#154829)
Pull Request resolved: pytorch#154829
Approved by: https://github.com/henrylhtsang, https://github.com/eellison
1 parent 606d73b commit 5911f87

File tree

6 files changed (+143, -31 lines):

test/inductor/test_cutlass_backend.py
test/inductor/test_cutlass_evt.py
torch/_inductor/codegen/cuda/cuda_kernel.py
torch/_inductor/codegen/cuda/cuda_template.py
torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py
torch/_inductor/codegen/cuda/gemm_template.py


test/inductor/test_cutlass_backend.py

Lines changed: 107 additions & 13 deletions
@@ -446,21 +446,26 @@ def test_max_autotune_cutlass_backend_regular_mm(
         Main test for mm.
         """
 
-        class MyModel(torch.nn.Module):
-            def forward(self, a, b):
-                return a @ b
-
-        model = MyModel().cuda()
         # M, N, K
         shapes = [
             (128, 128, 16),
             (1024, 1024, 256),
         ]
-        shapes = shapes[0:1] if not dynamic else shapes
+
+        # M, N, K
+        shapes = shapes if dynamic else shapes[0:1]
+
+        class MyModel(torch.nn.Module):
+            def forward(self, a, b):
+                return a @ b
+
+        model = MyModel().cuda()
+
         inputs = [
             (torch.randn(M, K).cuda().to(dtype), torch.randn(K, N).cuda().to(dtype))
             for (M, N, K) in shapes
         ]
+
         dynamic_shapes = (
             {
                 "a": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC},
@@ -483,11 +488,100 @@ def forward(self, a, b):
                     model, inputs, dynamic_shapes=dynamic_shapes
                 )
             else:
-                compiled_model = torch.compile(model, dynamic=dynamic)
+                compiled_model = torch.compile(model, dynamic=True)
                 actual = [compiled_model(*input) for input in inputs]
 
         torch.testing.assert_close(actual, expected)
 
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
+    @parametrize("dynamic", (False, True))
+    @parametrize("use_aoti", (False,))
+    @parametrize("dtype", (torch.float8_e4m3fn,))
+    @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+    def test_max_autotune_cutlass_backend_fp8_scaled_mm(
+        self,
+        dynamic: bool,
+        max_autotune_gemm_backends: str = "CUTLASS",
+        use_aoti: bool = False,
+        dtype: torch.dtype = torch.float16,
+    ):
+        """
+        Main test for mm.
+        """
+
+        # M, N, K
+        shapes = [
+            (128, 128, 16),
+            (1024, 1024, 256),
+        ]
+
+        # M, N, K
+        shapes = shapes if dynamic else shapes[0:1]
+
+        inputs = []
+        for shape in shapes:
+            M, N, K = shape
+            output_dtype = torch.bfloat16
+            device = "cuda"
+
+            x = torch.randn(M, K, dtype=output_dtype, device=device)
+            w = torch.randn(N, K, dtype=output_dtype, device=device)
+
+            # quantize weight (prior to inference)
+            w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype)
+            w_t_fp8 = w_fp8.t()
+            w_inverse_scale = w_inverse_scale.t()  # scale_b should be (1, N)
+
+            # quantize input x
+            x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype)
+
+            inputs.append((x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale))
+
+        class MyModel(torch.nn.Module):
+            def forward(self, x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale):
+                y = torch._scaled_mm(
+                    x_fp8,
+                    w_t_fp8,
+                    x_inverse_scale,
+                    w_inverse_scale,
+                    None,
+                    out_dtype=torch.bfloat16,
+                    use_fast_accum=False,
+                )
+                return y
+
+        dynamic_shapes = (
+            {
+                "x_fp8": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC},
+                "x_inverse_scale": {0: Dim.DYNAMIC, 1: 1},
+                "w_t_fp8": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC},
+                "w_inverse_scale": {0: 1, 1: Dim.DYNAMIC},
+            }
+            if dynamic
+            else None
+        )
+        model = MyModel().cuda()
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "max_autotune_gemm_backends": max_autotune_gemm_backends,
+                "cuda.cutlass_max_profiling_configs": 2,
+                "benchmark_epilogue_fusion": False,  # EVT doesn't support benchmark fusion yet
+                "cuda.cutlass_tma_only": True,
+            }
+        ), dynamo_config.patch({"error_on_recompile": dynamic}):
+            expected = [model(*input) for input in inputs]
+            if use_aoti:
+                actual = AOTIRunnerUtil.run_multiple(
+                    model, inputs, dynamic_shapes=dynamic_shapes
+                )
+            else:
+                compiled_model = torch.compile(model, dynamic=True)
+                actual = [compiled_model(*input) for input in inputs]
+
+        torch.testing.assert_close(actual, expected, rtol=1e-2, atol=0.05)
+
     @unittest.skipIf(not SM90OrLater, "need sm_90")
     @parametrize("dynamic", (False, True))
     @parametrize("use_aoti", (False, True))
@@ -1648,9 +1742,9 @@ def test_gemm_operation_serialization(self, arch: str, cuda_version: str):
         "shape",
         (
             (
-                16,
-                16,
-                32,
+                512,
+                128,
+                64,
             ),
         ),
     )
@@ -1720,9 +1814,9 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
         "shape",
         (
             (
-                16,
-                16,
-                32,
+                512,
+                128,
+                64,
            ),
        ),
    )
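Note: the new test leans on a _quantize_rowwise helper that this diff does not define. Below is a minimal sketch of what such a helper plausibly does, assuming per-row absmax scaling into the fp8 range; the scaling scheme is inferred from the call sites, not confirmed by this diff, and running it requires an fp8-capable GPU.

import torch

def _quantize_rowwise(x: torch.Tensor, fp8_dtype: torch.dtype):
    # Scale each row so its absmax maps onto the fp8 dtype's max value,
    # and keep the inverse scale (shape (M, 1)) for dequantization.
    fp8_max = torch.finfo(fp8_dtype).max
    amax = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    scale = fp8_max / amax.float()
    x_fp8 = (x.float() * scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return x_fp8, scale.reciprocal()

# Usage mirroring the test: both operands are quantized rowwise over K;
# torch._scaled_mm takes the weight transposed, scale_a as (M, 1) and
# scale_b as (1, N).
x = torch.randn(128, 16, dtype=torch.bfloat16, device="cuda")
w = torch.randn(128, 16, dtype=torch.bfloat16, device="cuda")
x_fp8, x_inverse_scale = _quantize_rowwise(x, torch.float8_e4m3fn)
w_fp8, w_inverse_scale = _quantize_rowwise(w, torch.float8_e4m3fn)
y = torch._scaled_mm(
    x_fp8, w_fp8.t(), x_inverse_scale, w_inverse_scale.t(),
    None, out_dtype=torch.bfloat16, use_fast_accum=False,
)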

test/inductor/test_cutlass_evt.py

Lines changed: 3 additions & 1 deletion
@@ -347,7 +347,9 @@ def test_example_tensor_creation(self):
         )
         buffer_renames = {"buf0": "buf0", "buf1": "buf1", "acc": "buf0"}
         name_to_buffer = {"buf0": row_major_buf0, "buf1": col_major_buf1}
-        result = create_example_tensors(buffer_renames, name_to_buffer)
+        result = create_example_tensors(
+            buffer_renames, name_to_buffer, lambda x: int(x)
+        )
         self.assertEqual(result["acc"].shape, (3, 4, 1))
         self.assertEqual(result["acc"].stride, (4, 1, 0))
         self.assertEqual(

torch/_inductor/codegen/cuda/cuda_kernel.py

Lines changed: 25 additions & 7 deletions
@@ -12,7 +12,7 @@
 from torch import dtype as torch_dtype
 from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
 from torch._inductor.scheduler import BaseSchedulerNode
-from torch._inductor.utils import do_bench_using_profiling, Placeholder
+from torch._inductor.utils import do_bench_using_profiling, OrderedSet, Placeholder
 from torch.utils._sympy.value_ranges import ValueRanges
 
 from .cutlass_utils import DTYPE_TO_CUTLASS_TYPE
@@ -81,6 +81,7 @@ class CUDAKernel(Kernel):
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.layout_args: dict[str, list[LayoutArg]] = defaultdict(list)
+        self.size_args: list[Union[Expr, int]] = []
         # Mapping from arg name to IRNode.
         self.named_nodes: dict[str, IRNode] = {}
 
@@ -172,6 +173,9 @@ def get_ld(node) -> Union[Expr, int]:
         LDD = get_ld(Y)
         return (M, N, K, B, LDA, LDB, LDC, LDD)
 
+    def get_dynamic_shape_args(self) -> list[Union[Expr, int]]:
+        return [*self.get_layout_args(), *self.size_args]
+
     @staticmethod
     def find_ld_idx(node: IRNode) -> int:
         strides = node.get_stride()
@@ -257,6 +261,7 @@ def def_kernel(
             e.g. The template might have input argument defined as [X, W, Bias],
             and the actual input passed into this template could be [Bias, X, W].
             In this case, the `input_reorder` would be [2, 0, 1].
+            additional_size_args: Additional size arguments for epilogue inputs
         """
         names = [x.strip() for x in names_str.strip().split(",")]
         if len(inputs) + len(outputs) != len(names):
@@ -276,17 +281,30 @@ def def_kernel(
                 self.named_nodes[name] = node
                 self.args.input_buffers[node.get_name()] = name
 
+        free_symbols: OrderedSet[Expr] = OrderedSet()
         for name, node in zip(names[len(inputs) : len(inputs) + len(outputs)], outputs):
             if node is not None:
                 self.named_nodes[name] = node
                 self.args.output_buffers[node.get_name()] = name
 
+                if name not in (
+                    "X",
+                    "W",
+                    "Bias",
+                    "Y",
+                ):  # we handle these symbolic shapes explicitly
+                    for expr in itertools.chain(node.get_size(), node.get_stride()):
+                        if isinstance(expr, Expr):
+                            for s in expr.free_symbols:
+                                free_symbols.add(s)  # type: ignore[arg-type]
+
         arg_defs, *_ = self.args.cpp_argdefs(DTYPE_TO_CUTLASS_TYPE)
 
         self.init_layout_args()
-        size_args = [
-            f"const int {s}" for s in ("M", "N", "K", "B", "lda", "ldb", "ldc", "ldd")
-        ]
+        size_vars = ["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"]
+        size_vars.extend(str(s) for s in free_symbols)
+        self.size_args.extend(free_symbols)
+        size_args = [f"const int {s}" for s in size_vars]
 
         runtime_arg_decls = ",".join(
             [f"{arg.ty} {arg.name}" for arg in self.runtime_arg_info]
@@ -326,11 +344,11 @@ def call_kernel(
         else:
             _, call_args, _, arg_types = self.args.python_argdefs()
 
-        layout_args = self.get_layout_args()
-        call_args.extend(layout_args)  # type: ignore[arg-type]
+        dynamic_shape_args = self.get_dynamic_shape_args()
+        call_args.extend(dynamic_shape_args)  # type: ignore[arg-type]
         for arg in self.runtime_arg_values:
             call_args.append(arg)
-        arg_types.extend("int" for a in layout_args)
+        arg_types.extend("int" for _ in dynamic_shape_args)
         for arg in self.runtime_arg_info:
             arg_types.append(arg.ty)
         # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar
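The key addition in def_kernel is the free-symbol sweep over epilogue buffers. A self-contained sketch of the same logic, assuming toy buffers whose sizes and strides are sympy expressions; the buffer names and the plain-list stand-in for inductor's OrderedSet are illustrative only:

import itertools
import sympy

s0, s1 = sympy.symbols("s0 s1", integer=True)
# Anything besides X/W/Bias/Y: those are already covered by the explicit
# M/N/K/B and leading-dimension arguments.
epilogue_buffers = {"aux": {"size": (s0, s1), "stride": (s1, 1)}}

free_symbols = []  # stand-in for OrderedSet, preserves first-seen order
for name, buf in epilogue_buffers.items():
    for expr in itertools.chain(buf["size"], buf["stride"]):
        if isinstance(expr, sympy.Expr):
            for s in expr.free_symbols:
                if s not in free_symbols:
                    free_symbols.append(s)

# Each collected symbol becomes an extra `const int` kernel parameter and a
# matching runtime value via get_dynamic_shape_args().
size_vars = ["M", "N", "K", "B", "lda", "ldb", "ldc", "ldd"]
size_vars.extend(str(s) for s in free_symbols)
print([f"const int {s}" for s in size_vars][-2:])  # ['const int s0', 'const int s1']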

torch/_inductor/codegen/cuda/cuda_template.py

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ def generate(  # type: ignore[override]
             expected_args,
         )
         V.graph.sizevars.size_hints(map(sympy.expand, call_args[len(expected_args) :]))
-        size_args = V.graph.sizevars.size_hints(kernel.get_layout_args())
+        size_args = V.graph.sizevars.size_hints(kernel.get_dynamic_shape_args())
         extra_args = tuple(list(size_args) + self.get_runtime_arg_values(**kwargs))
 
         kernel_hash = hashlib.sha256(code.encode("utf-8")).hexdigest()[:8]
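For autotuning, size_hints (plural) resolves every symbolic entry in get_dynamic_shape_args() to a concrete value so the candidate kernel can be benchmarked. A rough sketch of those semantics under the assumption that hints amount to simple symbol substitution; the real SizeVarAllocator does more:

import sympy

s0 = sympy.Symbol("s0", integer=True)
hints = {s0: 1024}  # example value recorded when the graph was traced

def size_hints(exprs):
    # Substitute known hints and cast to int, one entry per argument.
    return tuple(int(sympy.sympify(e).subs(hints)) for e in exprs)

# Layout args plus collected epilogue symbols, e.g. (M, N, K, ..., s0):
extra_args = size_hints([s0, s0 * 2, 256])
print(extra_args)  # (1024, 2048, 256)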

torch/_inductor/codegen/cuda/cutlass_lib_extensions/evt_extensions.py

Lines changed: 6 additions & 9 deletions
@@ -1,4 +1,6 @@
-from typing import Any, Union
+from typing import Any, Callable, Union
+
+from sympy import Expr
 
 from torch._inductor.ir import (
     ComputedBuffer,
@@ -61,18 +63,13 @@
 def create_example_tensors(
     var_name_to_buffer_name: dict[str, str],
     name_to_buffer: dict[str, Buffer],
+    size_hint_fn: Callable[[Union[Expr, int]], int],
 ) -> dict[str, CutlassTensor]:
     def cutlass_tensor_from_buffer(buffer: Buffer) -> CutlassTensor:
         shape = buffer.get_layout().size
         stride = buffer.get_layout().stride
-        assert all(isinstance(x, int) or x.is_integer for x in shape), (
-            f"{buffer.get_name()}'s shape {shape} contains symints which aren't supported for cutlass EVT"
-        )
-        assert all(isinstance(x, int) or x.is_integer for x in stride), (
-            f"{buffer.get_name()}'s stride {stride} contains symints which aren't supported for cutlass EVT"
-        )
-        shape = tuple(int(x) for x in shape)
-        stride = tuple(int(x) for x in stride)
+        shape = tuple(size_hint_fn(x) for x in shape)
+        stride = tuple(size_hint_fn(x) for x in stride)
 
         is_row_major = is_contiguous_strides_for_shape(stride, shape)
         is_column_major = is_contiguous_strides_for_shape(stride[::-1], shape[::-1])
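With the asserts gone, a buffer whose layout carries a symint now flows through size_hint_fn instead of failing. A hypothetical end-to-end illustration, with a simplified stand-in for inductor's is_contiguous_strides_for_shape:

import sympy

s0 = sympy.Symbol("s0", integer=True)
shape, stride = (s0, 4), (4, 1)  # symbolic leading dim

size_hint_fn = lambda x: int(sympy.sympify(x).subs({s0: 3}))
resolved_shape = tuple(size_hint_fn(x) for x in shape)    # (3, 4)
resolved_stride = tuple(size_hint_fn(x) for x in stride)  # (4, 1)

def is_contiguous_strides_for_shape(stride, shape):
    # Simplified row-major check: each stride equals the product of the
    # extents that follow it.
    expected, ok = 1, True
    for extent, st in zip(reversed(shape), reversed(stride)):
        ok = ok and st == expected
        expected *= extent
    return ok

assert is_contiguous_strides_for_shape(resolved_stride, resolved_shape)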

torch/_inductor/codegen/cuda/gemm_template.py

Lines changed: 1 addition & 0 deletions
@@ -1390,6 +1390,7 @@ def _render_evt(
         examples = create_example_tensors(
             var_name_to_buffer_name,
             name_to_buffer,  # type: ignore[arg-type]
+            V.graph.sizevars.size_hint,
         )
         evt_name, evt_args, evt_code = trace(
             evt_py_code,
