@@ -34,7 +34,9 @@
 from torch.testing import FileCheck
 from torch.testing._internal.common_cuda import SM80OrLater, SM90OrLater
 from torch.testing._internal.common_utils import (
+    IN_RE_WORKER,
     instantiate_parametrized_tests,
+    IS_FBCODE,
     parametrize,
 )
 from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA
@@ -1030,29 +1032,21 @@ def test_get_max_alignment(self):
             m4, 4, "Wrong max alignment. Should have been 4 (due to float32 dtype)."
         )
 
-    @unittest.skipIf(not SM80OrLater, "need sm_80")
+    @unittest.skipIf(not SM90OrLater, "need sm_90")
     @mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_standalone_runner(self):
         max_autotune_gemm_backends = "CUTLASS"
 
-        def mm(a, b):
-            return torch.mm(a, b.to(torch.half))
-
-        m, n, k = 128, 16, 128
-        a = torch.randn(m, k).cuda().half()
-        b = torch.randint(0, 5, (n, k), dtype=torch.int8).cuda().T
+        a = torch.randn(128, 16).cuda().half()
+        b = torch.randn(16, 128).cuda().half()
 
         with config.patch(
             {
                 "max_autotune": True,
-                "autotune_in_subproc": True,
                 "max_autotune_gemm_backends": max_autotune_gemm_backends,
-                "cuda.cutlass_max_profiling_configs": 1,
-                "autotune_local_cache": True,
+                "cuda.cutlass_max_profiling_configs": 2,
                 "autotune_fallback_to_aten": False,
                 "cuda.generate_test_runner": True,  # put standalone runner in the generated code
-                "use_mixed_mm": True,
-                "mixed_mm_choice": "aten",
             }
         ):
             from tempfile import NamedTemporaryFile
@@ -1065,9 +1059,9 @@ def mm(a, b):
             # Run compilation, check results just in case, and save
             # CUTLASS-based generated code.
             with CUDACompileSourceCapturingContext() as ctx:
-                compiled = torch.compile(mm, dynamic=False)
+                compiled = torch.compile(torch.mm, dynamic=False)
 
-                expected = mm(a, b)
+                expected = torch.mm(a, b)
                 actual = compiled(a, b)
 
                 torch.testing.assert_close(actual, expected)
@@ -1092,14 +1086,12 @@ def mm(a, b):
                 Path(cu_file.name), Path(exe_file.name)
             )
 
-            if config.is_fbcode():
+            if IS_FBCODE:
                 # hack to bypass the following error:
                 # error while loading shared libraries: IX}: invalid mode for dlopen(): Invalid argument
                 platform_path = sysconfig.get_config_var("LIBDIR")
-                link_str = " ".join(
-                    [f"-L{platform_path}", "-Xlinker", f"-rpath={platform_path}"]
-                )
-                command = command.replace(link_str, " ")
+                cuda_path = os.path.realpath(os.path.join(platform_path, "libcuda.so"))
+                command = command.replace("-lcuda ", f"-L{cuda_path} ")
 
             repro_message = (
                 f"Reproduce with: {command}\n"
@@ -1108,11 +1100,12 @@ def mm(a, b):
             )
 
             retcode = os.system(command)
-            assert retcode == 0, repro_message
+            self.assertEqual(retcode, 0, repro_message)
 
             # Run the executable generated.
-            retcode = os.system(exe_file.name)
-            assert retcode == 0, repro_message
+            if not IS_FBCODE or not IN_RE_WORKER:
+                retcode = os.system(exe_file.name)
+                self.assertEqual(retcode, 0, repro_message)
 
             # Remove temporary files.
             os.remove(cu_file.name)
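
Note: for anyone who wants to try the patched configuration outside the test harness, here is a minimal sketch of the autotune flow this diff exercises. It uses only the config keys visible in the hunks above, assumes a CUDA build with the CUTLASS backend available, and omits the standalone-runner compile-and-execute steps that the test performs afterwards.

# Minimal sketch of the CUTLASS autotune path exercised above
# (assumes CUDA + CUTLASS; config keys taken from the diff).
import torch
from torch._inductor import config

a = torch.randn(128, 16).cuda().half()
b = torch.randn(16, 128).cuda().half()

with config.patch(
    {
        "max_autotune": True,
        "max_autotune_gemm_backends": "CUTLASS",
        "cuda.cutlass_max_profiling_configs": 2,
        "autotune_fallback_to_aten": False,
        # Ask Inductor to emit a standalone test runner alongside the kernel.
        "cuda.generate_test_runner": True,
    }
):
    compiled = torch.compile(torch.mm, dynamic=False)
    torch.testing.assert_close(compiled(a, b), torch.mm(a, b))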