@@ -1616,7 +1616,7 @@ def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist
         for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
             out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
                                        out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
-            self.assertEqual(out, out_ref)
+            self.assertEqual(out, out_ref, atol=1e-1, rtol=1e-2)
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM100OrLater
@@ -1626,14 +1626,19 @@ def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist
     @parametrize("use_torch_compile", [False, True])
     def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
-        m, n, k, n_groups = 16, 16, 16, 4  # all sizes have to be divisible by 16
+        m, n, k, n_groups = 16, 32, 64, 4  # all sizes have to be divisible by 16
         a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
         b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
-        scale_a = torch.arange(m * n_groups, device=device, dtype=torch.float32) / 4
-        scale_b = torch.arange(n * n_groups, device=device, dtype=torch.float32) / 4
+        scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
+        scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32)
         offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
         f = torch._scaled_grouped_mm
-        f = torch.compile(f) if use_torch_compile else f
+        f = torch.compile(
+            f,
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            }) if use_torch_compile else f
         out = f(a, b.t(), scale_a, scale_b, offs=offs,
                 out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
         offs_cpu = offs.cpu()
@@ -1657,7 +1662,7 @@ def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile)
     def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
-        m, n, k, n_groups = 16, 32, 16, 4
+        m, n, k, n_groups = 16, 32, 64, 4
         a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
         b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         self.assertTrue(a.is_contiguous() is not strided)
@@ -1666,11 +1671,16 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
             offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
             if check_zero_size:
                 offs[0] = offs[1]
-            scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
-            scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
+            scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32)
+            scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
 
             f = torch._scaled_grouped_mm
-            f = torch.compile(f, dynamic=False) if use_torch_compile else f
+            f = torch.compile(
+                f,
+                options={
+                    "max_autotune": True,
+                    "max_autotune_gemm_backends": "TRITON",
+                }) if use_torch_compile else f
             out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
                     out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
 
@@ -1682,7 +1692,7 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
                 ascalelist.append(scale_a[start:offs_cpu[i]])
                 outlist.append(out[start:offs_cpu[i]])
                 start = offs_cpu[i]
-            self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
+            self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
 
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@@ -1694,16 +1704,21 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
     def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
-        m, n, k, n_groups = 16, 32, 16, 4
+        m, n, k, n_groups = 16, 32, 64, 4
         a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         self.assertTrue(a.is_contiguous() is not strided)
         self.assertTrue(b.is_contiguous() is not strided)
-        scale_a = torch.ones(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
-        scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
+        scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
+        scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
 
         f = torch._scaled_grouped_mm
-        f = torch.compile(f) if use_torch_compile else f
+        f = torch.compile(
+            f,
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            }) if use_torch_compile else f
         out = f(a, b.transpose(-2, -1), scale_a, scale_b,
                 out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
 
@@ -1719,20 +1734,25 @@ def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile)
     def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
-        m, n, k, n_groups = 16, 32, 16, 4
+        m, n, k, n_groups = 16, 32, 64, 4
         a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
         self.assertTrue(a.is_contiguous() is not strided)
         self.assertTrue(b.is_contiguous() is not strided)
-        scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
-        scale_b = torch.arange(n_groups * n, device="cuda", dtype=torch.float32)
+        scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
+        scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32)
         for check_zero_size in (True, False):
             offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
             if check_zero_size:
                 offs[0] = offs[1]
 
             f = torch._scaled_grouped_mm
-            f = torch.compile(f) if use_torch_compile else f
+            f = torch.compile(
+                f,
+                options={
+                    "max_autotune": True,
+                    "max_autotune_gemm_backends": "TRITON",
+                }) if use_torch_compile else f
             out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
                     out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
             offs_cpu = offs.cpu()
@@ -1743,7 +1763,7 @@ def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile)
                 bscalelist.append(scale_b[start:offs_cpu[i]])
                 outlist.append(out[:, start:offs_cpu[i]])
                 start = offs_cpu[i]
-            self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
+            self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
 
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
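
For context, the sketch below calls the private torch._scaled_grouped_mm API directly in the 2D-2D layout exercised above and cross-checks each group against torch._scaled_mm, mirroring scaled_grouped_mm_helper. This is a minimal illustration rather than part of the change: it assumes the internal signatures shown in the diff (which may change between releases) and a CUDA GPU that supports fp8 rowwise-scaled matmul (e.g. H100).

import torch

device = "cuda"
m, n, k, n_groups = 16, 32, 64, 4  # all sizes divisible by 16

# Row-major fp8 operands: a is (m, k * n_groups), b is (n, k * n_groups).
a = torch.randn(m, k * n_groups, device=device).to(torch.float8_e4m3fn)
b = torch.randn(n, k * n_groups, device=device).to(torch.float8_e4m3fn)

# Per-row scales, one block of m (resp. n) entries per group.
scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32)

# Group boundaries along the shared k dimension: [k, 2k, 3k, 4k].
offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)

# Grouped call; as in the 2D-2D test above, the result holds one (m, n)
# output matrix per group along the leading dimension.
out = torch._scaled_grouped_mm(a, b.t(), scale_a, scale_b, offs=offs,
                               out_dtype=torch.bfloat16, use_fast_accum=True)

# Cross-check every group against a plain scaled matmul, using the same
# loose tolerances the helper uses for fp8 inputs.
start = 0
for i, end in enumerate(offs.cpu().tolist()):
    ref = torch._scaled_mm(a[:, start:end], b[:, start:end].t(),
                           scale_a[i * m:(i + 1) * m].view(-1, 1),
                           scale_b[i * n:(i + 1) * n].view(1, -1),
                           out_dtype=torch.bfloat16, use_fast_accum=True)
    torch.testing.assert_close(out[i], ref, atol=1e-1, rtol=1e-2)
    start = end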