 HAS_CUDA = HAS_CUDA and not torch.version.hip
 SM80OrLater = SM80OrLater and not torch.version.hip
 SM90OrLater = SM90OrLater and not torch.version.hip
-SM80 = SM80OrLater and torch.cuda.get_device_capability() == (8, 0)
 
 
 def _get_path_without_sccache() -> str:
@@ -737,22 +736,14 @@ def forward(self, x, w):
         torch.testing.assert_close(expected, actual, atol=0.01, rtol=0.01)
 
     # TODO: Enable dynamic test cases when dynamic support is added.
-    @unittest.skipIf(True, "disabled due to broken on A100")
-    # error: TypeError: can't multiply sequence by non-int of type 'str'
-    @unittest.skipIf(not SM80, "need sm_80 exactly")
+    @unittest.skipIf(not SM80OrLater or SM90OrLater, "need sm_8x exactly")
     @parametrize("dynamic", (False,))
-    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
-    def test_max_autotune_cutlass_backend_mixed_mm(
-        self, dynamic: bool, max_autotune_gemm_backends: str
-    ):
+    def test_max_autotune_cutlass_backend_mixed_mm(self, dynamic: bool):
         """
         Make sure autotuning mm in sub processes work without crashes.
         """
 
-        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
-            return
-
         def mm(a, b):
             return torch.mm(a, b.to(torch.half))
@@ -768,7 +759,7 @@ def mm(a, b):
             {
                 "max_autotune": True,
                 "autotune_in_subproc": True,
-                "max_autotune_gemm_backends": max_autotune_gemm_backends,
+                "max_autotune_gemm_backends": "CUTLASS",
                 "cuda.cutlass_dir": _CUTLASS_DIR,
                 "cuda.cutlass_max_profiling_configs": 2,
                 "use_mixed_mm": True,
@@ -792,22 +783,16 @@ def mm(a, b):
             assert cutlass_kernels_count > 0
 
     # TODO: Enable dynamic test cases when dynamic support is added.
-    @unittest.skipIf(True, "disabled due to broken on A100")
-    # error: TypeError: can't multiply sequence by non-int of type 'str'
-    @unittest.skipIf(not SM80, "need sm_80 exactly")
+    @unittest.skipIf(not SM80OrLater or SM90OrLater, "need sm_8x exactly")
     @parametrize("dynamic", (False,))
-    @parametrize("max_autotune_gemm_backends", ("CUTLASS", "CUTLASS,Triton,ATen"))
     @unittest.mock.patch.dict(os.environ, {"PATH": _get_path_without_sccache()})
     def test_max_autotune_cutlass_backend_sparse_semi_structured_mm(
-        self, dynamic: bool, max_autotune_gemm_backends: str
+        self, dynamic: bool
     ):
         """
         Make sure autotuning mm in sub processes work without crashes.
         """
 
-        if max_autotune_gemm_backends == "CUTLASS" and torch.version.hip:
-            return
-
         SparseSemiStructuredTensor._FORCE_CUTLASS = True
 
         def mm(a, b):
@@ -823,7 +808,7 @@ def mm(a, b):
             {
                 "max_autotune": True,
                 "autotune_in_subproc": True,
-                "max_autotune_gemm_backends": max_autotune_gemm_backends,
+                "max_autotune_gemm_backends": "CUTLASS",
                 "cuda.cutlass_dir": _CUTLASS_DIR,
                 "cuda.cutlass_max_profiling_configs": 2,
                 "autotune_local_cache": True,