Commit 5a1c7c4

alexsamardzic authored and pytorchmergebot committed
Fix standalone runner for CUTLASS auto-tuning backend (pytorch#146764)
Pull Request resolved: pytorch#146764
Approved by: https://github.com/henrylhtsang
ghstack dependencies: pytorch#146755
1 parent eb655a2 commit 5a1c7c4
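For context, the standalone runner is opt-in: it is only emitted when `cuda.generate_test_runner` is enabled together with the CUTLASS autotuning backend. A minimal sketch of such a setup, not part of this commit — the CUTLASS checkout path is a placeholder and the config keys mirror the new test added below:

```python
# Hypothetical usage sketch: compile a matmul with the CUTLASS backend and ask
# Inductor to embed the standalone run_standalone() entry point in the
# generated .cu source. Requires a CUDA GPU and a CUTLASS checkout.
import torch
import torch._inductor.config as config

def mm(a, b):
    return torch.mm(a, b)

a = torch.randn(128, 64, device="cuda", dtype=torch.half)
b = torch.randn(64, 32, device="cuda", dtype=torch.half)

with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",
        "cuda.cutlass_dir": "/path/to/cutlass",  # placeholder path
        "cuda.generate_test_runner": True,  # emit run_standalone() in the generated code
    }
):
    torch.compile(mm)(a, b)
```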

3 files changed (+144, -11 lines)

test/inductor/test_cutlass_backend.py

Lines changed: 77 additions & 1 deletion
@@ -762,9 +762,10 @@ def mm(a, b):
                 "max_autotune_gemm_backends": "CUTLASS",
                 "cuda.cutlass_dir": _CUTLASS_DIR,
                 "cuda.cutlass_max_profiling_configs": 2,
-                "use_mixed_mm": True,
                 "autotune_local_cache": True,
                 "autotune_fallback_to_aten": False,
+                "use_mixed_mm": True,
+                "mixed_mm_choice": "aten",  # to disable Triton
             }
         ):
             Y_compiled = torch.compile(mm, dynamic=dynamic)(a, b)
@@ -972,6 +973,81 @@ def test_get_max_alignment(self):
             m4, 4, "Wrong max alignment. Should have been 4 (due to float32 dtype )."
         )
 
+    @unittest.skipIf(not SM80OrLater, "need sm_80")
+    @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
+    def test_standalone_runner(self):
+        max_autotune_gemm_backends = "CUTLASS"
+        torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
+
+        def mm(a, b):
+            return torch.mm(a, b.to(torch.half))
+
+        m, n, k = 128, 16, 128
+        a = torch.randn(m, k).cuda().half()
+        b = torch.randint(0, 5, (n, k), dtype=torch.int8).cuda().T
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "autotune_in_subproc": True,
+                "max_autotune_gemm_backends": max_autotune_gemm_backends,
+                "cuda.cutlass_dir": _CUTLASS_DIR,
+                "cuda.cutlass_max_profiling_configs": 1,
+                "autotune_local_cache": True,
+                "autotune_fallback_to_aten": False,
+                "cuda.generate_test_runner": True,  # put standalone runner in the generated code
+                "use_mixed_mm": True,
+                "mixed_mm_choice": "aten",
+            }
+        ):
+            import os
+            from tempfile import NamedTemporaryFile
+
+            from torch._inductor.codegen.cuda.cutlass_utils import (
+                cuda_standalone_runner_compile_command,
+                CUDACompileSourceCapturingContext,
+            )
+
+            # Run compilation, check results just in case, and save
+            # CUTLASS-based generated code.
+            with CUDACompileSourceCapturingContext() as ctx:
+                compiled = torch.compile(mm, dynamic=False)
+
+                expected = mm(a, b)
+                actual = compiled(a, b)
+
+                torch.testing.assert_close(actual, expected)
+
+            sources = ctx.sources
+
+            assert len(sources) >= 1
+
+            # Get names for temporary source and executable files.
+            cu_file = NamedTemporaryFile("w", suffix=".cu", delete=False)
+            cu_file.close()
+            exe_file = NamedTemporaryFile("w", suffix="", delete=False)
+            exe_file.close()
+
+            # Save the generated code into the .cu file.
+            with open(cu_file.name, "w") as file:
+                file.write(sources[0])
+
+            # Get command to compile .cu file, and run the
+            # compilation.
+            command = cuda_standalone_runner_compile_command(
+                cu_file.name, exe_file.name
+            )
+            retcode = os.system(command)
+            assert retcode == 0
+
+            # Run the executable generated.
+            retcode = os.system(exe_file.name)
+            assert retcode == 0
+
+            # Remove temporary files.
+            os.remove(cu_file.name)
+            os.remove(exe_file.name)
+
 
 if __name__ == "__main__":
     from torch._inductor.utils import is_big_gpu

torch/_inductor/codegen/cuda/cutlass_utils.py

Lines changed: 3 additions & 2 deletions
@@ -383,10 +383,11 @@ def my_compile(source_code, dst_file_ext):
         self._compile_patch = mock.patch(
             "torch._inductor.codecache.CUDACodeCache.compile", my_compile
         )
-        return self._compile_patch.__enter__(*args, **kwargs)  # type: ignore[union-attr]
+        self._compile_patch.__enter__(*args, **kwargs)  # type: ignore[union-attr]
+        return self
 
     def __exit__(self, *args, **kwargs):
-        return self._compile_patch.__exit__(*args, **kwargs)  # type: ignore[union-attr]
+        self._compile_patch.__exit__(*args, **kwargs)  # type: ignore[union-attr]
 
 
 def cuda_standalone_runner_compile_command(srcpath: Path, exepath: Path):
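The point of this hunk: `__enter__` previously returned the value of the wrapped `mock.patch.__enter__`, so `with CUDACompileSourceCapturingContext() as ctx:` bound `ctx` to the mock object rather than to the capturing context, and `ctx.sources` was unreachable. Returning `self` restores the expected behavior. A minimal illustration of the pattern with a stand-in class (not the real one):

```python
# Toy sketch: a context manager that wraps another context manager must return
# *self* from __enter__ if callers are expected to read attributes off the
# "as" target, e.g. ctx.sources.
from unittest import mock

class CapturingContext:
    def __init__(self):
        self.sources = []
        self._patch = None

    def __enter__(self):
        # Replace print() with a mock whose side effect records each argument.
        self._patch = mock.patch("builtins.print", side_effect=self.sources.append)
        self._patch.__enter__()
        return self  # returning the patch's value instead would hide .sources

    def __exit__(self, *args):
        self._patch.__exit__(*args)

with CapturingContext() as ctx:
    print("captured")  # recorded in ctx.sources instead of being printed

assert ctx.sources == ["captured"]
```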

torch/_inductor/codegen/cuda/gemm_template.py

Lines changed: 64 additions & 8 deletions
@@ -37,7 +37,7 @@
 extern "C" {
 PT_EXPORT {{kernel_call_signature}} {
   try {
-  int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}};
+  int B = {{kernel.size(Y, 0, -3, default_value=1)}};
   using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
   using coord_t = cutlass::gemm::GemmCoord::Index;
   static cutlass::KernelHardwareInfo hw_info;
@@ -154,7 +154,7 @@
 extern "C" {
 PT_EXPORT {{kernel_call_signature}} {
   try {
-  int64_t B = {{kernel.size(Y, 0, -3, default_value=1)}};
+  int B = {{kernel.size(Y, 0, -3, default_value=1)}};
   using ElementComputeEpilogue = {{instance_type}}::ElementAccumulator;
   using coord_t = cutlass::gemm::GemmCoord::Index;
   static cutlass::KernelHardwareInfo hw_info;
@@ -266,8 +266,8 @@
   // Initialize GemmSparse arguments.
   arguments = {
     {
-      static_cast<coord_t>({{M}}),
-      static_cast<coord_t>({{N}}),
+      static_cast<coord_t>(M),
+      static_cast<coord_t>(N),
       static_cast<coord_t>(2 * K),
     }, // GemmCoord problem_size
     X_ref, // TensorRef<ElementA const, LayoutA> ref_A
@@ -304,20 +304,43 @@
   if (block.size()<=0) return false;
   Element scope_max(static_cast<Element>(max)), scope_min(static_cast<Element>(min));
   cutlass::reference::device::BlockFillRandomUniform(
-    block.get(), block.size(), seed, scope_max, scope_min, 0);
+    (Element*)block.get(), block.size(), seed, scope_max, scope_min, 0);
 
   return true;
 }
 
+{% if Meta is defined and Meta is not none %}
+template <class Element>
+bool initialize_block_meta(
+  cutlass::DeviceAllocation<Element>& block,
+  uint64_t seed) {
+  if (block.size()<=0) return false;
+  cutlass::reference::device::BlockFillRandomSparseMeta(
+    (Element*)block.get(), block.size(), seed, {{instance_type}}::kMetaSizeInBits);
+  return true;
+}
+{% endif %}
+
 extern "C" int run_standalone(uint64_t seed, int repetitions) {
   std::cout << "Starting GEMM Standalone test run with seed " << seed << std::endl;
   size_t workspace_size = 0;
   size_t* workspace_size_ptr = &workspace_size;
 
+  int M = {{kernel.get_layout_args()[0]}};
+  int N = {{kernel.get_layout_args()[1]}};
+  int K = {{kernel.get_layout_args()[2]}};
+  int lda = {{kernel.get_layout_args()[3]}};
+  int ldb = {{kernel.get_layout_args()[4]}};
+  int ldc = {{kernel.get_layout_args()[5]}};
+  int ldd = {{kernel.get_layout_args()[6]}};
+
   using ElementA = {{kernel.cutlass_dtype(X)}};
   using ElementB = {{kernel.cutlass_dtype(W)}};
   using ElementC = {{kernel.cutlass_dtype(Bias, default_dtype='uint8_t')}}; // may not be void
   using ElementD = {{kernel.cutlass_dtype(Y)}};
+{% if Meta is defined and Meta is not none %}
+  using ElementE = {{kernel.cutlass_dtype(Meta)}};
+{% endif %}
 
   cutlass::DeviceAllocation<ElementA> X_data({{kernel.max_valid_index(X)+1}});
   initialize_block(X_data, seed++);
@@ -326,6 +349,10 @@
   cutlass::DeviceAllocation<ElementC> Bias_data({{kernel.max_valid_index(Bias)+1}});
   initialize_block(Bias_data, seed++);
   cutlass::DeviceAllocation<ElementD> Y_data({{kernel.max_valid_index(Y)+1}});
+{% if Meta is defined and Meta is not none %}
+  cutlass::DeviceAllocation<ElementE> Meta_data({{kernel.max_valid_index(Meta)+1}});
+  initialize_block_meta(Meta_data, seed++);
+{% endif %}
 
   cutlass::DeviceAllocation<uint8_t> workspace_data;
   // Call once with workspace_size_ptr set to get workspace size
@@ -466,6 +493,14 @@ def _get_extra_inputs_and_names(
     ) -> tuple[Optional[Buffer], list[Optional[Buffer]], list[str]]:
         raise NotImplementedError
 
+    @abstractmethod
+    def _update_arg_names_for_test_call_statement(
+        self,
+        arg_names: list[str],
+        input_nodes: list[Buffer],
+    ) -> list[str]:
+        raise NotImplementedError
+
     def _add_cutlass_gemm_choices(
         self,
         choices: list[ChoiceCaller],
@@ -980,13 +1015,14 @@ def test_call_statement(
         """
         _, __, arg_types = kernel.args.cpp_argdefs()
         arg_names = [name.strip() for name in names_str.strip().split(",")]
-        if input_nodes[2] is None:
-            del arg_names[2]
+        arg_names = self._update_arg_names_for_test_call_statement(
+            arg_names, input_nodes
+        )
         arguments = [
             f"(({arg_type}){arg_name}_data.get())"
             for arg_type, arg_name in zip(arg_types, arg_names)
         ]
-        return f"{kernel.kernel_name}({', '.join(arguments)}, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);"
+        return f"{kernel.kernel_name}({', '.join(arguments)}, M, N, K, lda, ldb, ldc, ldd, workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);"  # noqa: B950
 
 
 class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
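Taken together with the subclass overrides in the following hunks, this refactor turns the argument pruning into a per-template hook, while the rendered call now always forwards the runtime problem size and leading dimensions. A hedged sketch of the pattern with simplified stand-in classes (not the real templates):

```python
# Simplified sketch of the hook pattern introduced above; class and argument
# names here are illustrative only.
from abc import ABC, abstractmethod
from typing import Optional

class GemmTemplateBase(ABC):
    kernel_name = "cutlass_gemm_kernel"  # placeholder name

    @abstractmethod
    def _update_arg_names_for_test_call_statement(
        self, arg_names: list[str], input_nodes: list[Optional[object]]
    ) -> list[str]:
        raise NotImplementedError

    def test_call_statement(self, arg_names, input_nodes) -> str:
        # Let each template decide which optional inputs (Bias, Meta) to drop.
        arg_names = self._update_arg_names_for_test_call_statement(arg_names, input_nodes)
        args = [f"{name}_data.get()" for name in arg_names]
        # Runtime sizes and strides are now passed explicitly to the kernel call.
        return (
            f"{self.kernel_name}({', '.join(args)}, M, N, K, lda, ldb, ldc, ldd, "
            "workspace_size_ptr, (uint8_t*)workspace_data.get(), 0);"
        )

class DenseGemmTemplate(GemmTemplateBase):
    def _update_arg_names_for_test_call_statement(self, arg_names, input_nodes):
        if input_nodes[2] is None:  # no Bias
            del arg_names[2]
        return arg_names

class SparseGemmTemplate(GemmTemplateBase):
    def _update_arg_names_for_test_call_statement(self, arg_names, input_nodes):
        if input_nodes[3] is None:  # no Meta
            del arg_names[3]
        if input_nodes[2] is None:  # no Bias
            del arg_names[2]
        return arg_names

print(DenseGemmTemplate().test_call_statement(["X", "W", "Bias", "Y"], [1, 1, None, 1]))
```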
@@ -1206,6 +1242,15 @@ def _get_extra_inputs_and_names(
         names: list[str] = []
         return (Bias, inputs, names)
 
+    def _update_arg_names_for_test_call_statement(
+        self,
+        arg_names: list[str],
+        input_nodes: list[Buffer],
+    ) -> list[str]:
+        if input_nodes[2] is None:
+            del arg_names[2]
+        return arg_names
+
     def render_gemm_arguments(
         self,
         argument_template: str,
@@ -1482,6 +1527,17 @@ def _get_extra_inputs_and_names(
         names = ["Meta"]
         return (Bias, inputs, names)
 
+    def _update_arg_names_for_test_call_statement(
+        self,
+        arg_names: list[str],
+        input_nodes: list[Buffer],
+    ) -> list[str]:
+        if input_nodes[3] is None:
+            del arg_names[3]
+        if input_nodes[2] is None:
+            del arg_names[2]
+        return arg_names
+
     def render_gemm_arguments(
         self,
         instance_type: str,
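A side note on the deletion order in the sparse override above: index 3 (Meta) is removed before index 2 (Bias) so that the second deletion still points at the intended element. A toy illustration:

```python
# Deleting the higher index first keeps the lower index valid.
arg_names = ["X", "W", "Bias", "Meta"]
del arg_names[3]  # drop Meta first
del arg_names[2]  # index 2 still refers to Bias
assert arg_names == ["X", "W"]
```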

0 commit comments
