Skip to content

Commit 0633f63

Browse files
henrylhtsang authored and pytorchmergebot committed
[cutlass backend] try fix standalone runner test (pytorch#147811)
Differential Revision: [D70147859](https://our.internmc.facebook.com/intern/diff/D70147859/)

Trying to fix this test one last time, especially now that mixed mm is getting removed.

Pull Request resolved: pytorch#147811
Approved by: https://github.com/chenyang78
1 parent 05bc8fe commit 0633f63

File tree

2 files changed

+16
-22
lines changed

2 files changed

+16
-22
lines changed

test/inductor/test_cutlass_backend.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
from torch.testing import FileCheck
3535
from torch.testing._internal.common_cuda import SM80OrLater, SM90OrLater
3636
from torch.testing._internal.common_utils import (
37+
IN_RE_WORKER,
3738
instantiate_parametrized_tests,
39+
IS_FBCODE,
3840
parametrize,
3941
)
4042
from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
@@ -1030,29 +1032,21 @@ def test_get_max_alignment(self):
10301032
m4, 4, "Wrong max alignment. Should have been 4 (due to float32 dtype )."
10311033
)
10321034

1033-
@unittest.skipIf(not SM80OrLater, "need sm_80")
1035+
@unittest.skipIf(not SM90OrLater, "need sm_90")
10341036
@mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
10351037
def test_standalone_runner(self):
10361038
max_autotune_gemm_backends = "CUTLASS"
10371039

1038-
def mm(a, b):
1039-
return torch.mm(a, b.to(torch.half))
1040-
1041-
m, n, k = 128, 16, 128
1042-
a = torch.randn(m, k).cuda().half()
1043-
b = torch.randint(0, 5, (n, k), dtype=torch.int8).cuda().T
1040+
a = torch.randn(128, 16).cuda().half()
1041+
b = torch.randn(16, 128).cuda().half()
10441042

10451043
with config.patch(
10461044
{
10471045
"max_autotune": True,
1048-
"autotune_in_subproc": True,
10491046
"max_autotune_gemm_backends": max_autotune_gemm_backends,
1050-
"cuda.cutlass_max_profiling_configs": 1,
1051-
"autotune_local_cache": True,
1047+
"cuda.cutlass_max_profiling_configs": 2,
10521048
"autotune_fallback_to_aten": False,
10531049
"cuda.generate_test_runner": True, # put standalone runner in the generated code
1054-
"use_mixed_mm": True,
1055-
"mixed_mm_choice": "aten",
10561050
}
10571051
):
10581052
from tempfile import NamedTemporaryFile
@@ -1065,9 +1059,9 @@ def mm(a, b):
10651059
# Run compilation, check results just in case, and save
10661060
# CUTLASS-based generated code.
10671061
with CUDACompileSourceCapturingContext() as ctx:
1068-
compiled = torch.compile(mm, dynamic=False)
1062+
compiled = torch.compile(torch.mm, dynamic=False)
10691063

1070-
expected = mm(a, b)
1064+
expected = torch.mm(a, b)
10711065
actual = compiled(a, b)
10721066

10731067
torch.testing.assert_close(actual, expected)
@@ -1092,14 +1086,12 @@ def mm(a, b):
10921086
Path(cu_file.name), Path(exe_file.name)
10931087
)
10941088

1095-
if config.is_fbcode():
1089+
if IS_FBCODE:
10961090
# hack to bypass the following error:
10971091
# error while loading shared libraries: IX}: invalid mode for dlopen(): Invalid argument
10981092
platform_path = sysconfig.get_config_var("LIBDIR")
1099-
link_str = " ".join(
1100-
[f"-L{platform_path}", "-Xlinker", f"-rpath={platform_path}"]
1101-
)
1102-
command = command.replace(link_str, " ")
1093+
cuda_path = os.path.realpath(os.path.join(platform_path, "libcuda.so"))
1094+
command = command.replace("-lcuda ", f"-L{cuda_path} ")
11031095

11041096
repro_message = (
11051097
f"Reproduce with: {command}\n"
@@ -1108,11 +1100,12 @@ def mm(a, b):
11081100
)
11091101

11101102
retcode = os.system(command)
1111-
assert retcode == 0, repro_message
1103+
self.assertEqual(retcode, 0, repro_message)
11121104

11131105
# Run the executable generated.
1114-
retcode = os.system(exe_file.name)
1115-
assert retcode == 0, repro_message
1106+
if not IS_FBCODE or not IN_RE_WORKER:
1107+
retcode = os.system(exe_file.name)
1108+
self.assertEqual(retcode, 0, repro_message)
11161109

11171110
# Remove temporary files.
11181111
os.remove(cu_file.name)

torch/testing/_internal/common_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ def repro_env_var_prefix() -> str:
235235
implied_by_fn=lambda: os.getenv("TW_JOB_USER") == "sandcastle",
236236
include_in_repro=False,
237237
)
238+
IN_RE_WORKER: bool = os.environ.get("INSIDE_RE_WORKER") is not None
238239

239240
_is_fbcode_default = (
240241
hasattr(torch._utils_internal, "IS_FBSOURCE") and

0 commit comments

Comments (0)