Commit 09328eb

alexsamardzic authored and pytorchmergebot committed
Update auto-tuning support for _scaled_grouped_mm (pytorch#150944)

1. Enable strided inputs
2. Implement "2d/2d", "3d/2d" and "3d/3d" combinations of inputs
3. Fix non-TMA load variant
4. Replace experimental_device_tensormap_create2d with _experimental_make_tensor_descriptor
5. Fix cases when the group size along the K dimension is not a multiple of the block size along K
6. Update meta registration
7. Update synthetic offsets creation

Pull Request resolved: pytorch#150944
Approved by: https://github.com/ngimel
1 parent 1339e88 commit 09328eb
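
For orientation, here is a minimal sketch of the 2d/2d case covered by the updated auto-tuning path, adapted from the tests changed below. The shapes and scale values are illustrative, and the max_autotune options are only needed to route compilation to the Triton grouped-MM template:

import torch

m, n, k, n_groups = 16, 32, 64, 4  # all sizes divisible by 16
a = torch.randn(m, k * n_groups, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(n, k * n_groups, device="cuda").to(torch.float8_e4m3fn)
scale_a = torch.rand(m * n_groups, device="cuda", dtype=torch.float32)
scale_b = torch.rand(n * n_groups, device="cuda", dtype=torch.float32)
# Offsets mark the end of each group along the concatenated K dimension.
offs = torch.arange(k, n_groups * k + 1, k, device="cuda", dtype=torch.int32)

f = torch.compile(
    torch._scaled_grouped_mm,
    options={"max_autotune": True, "max_autotune_gemm_backends": "TRITON"})
out = f(a, b.t(), scale_a, scale_b, offs=offs,
        out_dtype=torch.bfloat16, use_fast_accum=True)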

8 files changed: +624 -333 lines changed

aten/src/ATen/native/cuda/Blas.cpp

Lines changed: 4 additions & 3 deletions
@@ -1532,7 +1532,7 @@ namespace {
           "D, arg ",
           arg_idx);
       TORCH_CHECK(
-          scale.is_contiguous(), "scale_a must be contiguous for arg ", arg_idx);
+          scale.is_contiguous(), "scale must be contiguous for arg ", arg_idx);
       TORCH_CHECK(
           scale.size(0) == mat.size(dim) * scale_multiplier,
           "scale must have the same length as mat for arg ",
@@ -1545,8 +1545,8 @@ namespace {
           "D for arg ",
           arg_idx);
       TORCH_CHECK(
-          scale.stride(1),
-          "scale_a must be contiguous in the last dimension for arg ",
+          scale.stride(1) == 1,
+          "scale must be contiguous in the last dimension for arg ",
           arg_idx);
       TORCH_CHECK(
           scale.size(0) == mat.size(0),
@@ -1610,6 +1610,7 @@ bool use_fast_accum) {
 
 
   TORCH_CHECK(!bias.has_value(), "Bias not supported yet");
+  TORCH_CHECK(!scale_result.has_value(), "Scale result not supported yet");
   TORCH_CHECK(offs.has_value() == (a_is_2d || b_is_2d), "Have to provide offsets if there is a 2d matrix");
 
   if (offs.has_value()) {
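
The second hunk fixes a truthiness bug: the old TORCH_CHECK passed whenever scale.stride(1) was non-zero, not only when it was 1. A small plain-PyTorch sketch of the difference (the tensor shape is made up for illustration):

import torch

# Slicing every other column makes the last dimension non-contiguous.
scale = torch.rand(4, 32, dtype=torch.float32)[:, ::2]
print(scale.stride(1))        # 2
print(bool(scale.stride(1)))  # True  -- the old check would accept this scale
print(scale.stride(1) == 1)   # False -- the new check correctly rejects it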

test/test_matmul_cuda.py

Lines changed: 39 additions & 19 deletions
@@ -1616,7 +1616,7 @@ def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist
         for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
             out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
                                        out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
-            self.assertEqual(out, out_ref)
+            self.assertEqual(out, out_ref, atol=1e-1, rtol=1e-2)
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
     @xfailIfSM100OrLater
@@ -1626,14 +1626,19 @@ def scaled_grouped_mm_helper(self, alist, blist, ascalelist, bscalelist, outlist
     @parametrize("use_torch_compile", [False, True])
     def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
-        m, n, k, n_groups = 16, 16, 16, 4  # all sizes have to be divisible by 16
+        m, n, k, n_groups = 16, 32, 64, 4  # all sizes have to be divisible by 16
         a = torch.randn(m, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
         b = torch.randn(n, k * n_groups + k * int(strided), device=device).to(torch.float8_e4m3fn)[:, :k * n_groups]
-        scale_a = torch.arange(m * n_groups, device=device, dtype=torch.float32) / 4
-        scale_b = torch.arange(n * n_groups, device=device, dtype=torch.float32) / 4
+        scale_a = torch.rand(m * n_groups, device=device, dtype=torch.float32)
+        scale_b = torch.rand(n * n_groups, device=device, dtype=torch.float32)
         offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
         f = torch._scaled_grouped_mm
-        f = torch.compile(f) if use_torch_compile else f
+        f = torch.compile(
+            f,
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            }) if use_torch_compile else f
         out = f(a, b.t(), scale_a, scale_b, offs=offs,
                 out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
         offs_cpu = offs.cpu()
@@ -1657,7 +1662,7 @@ def test_scaled_grouped_gemm_2d_2d(self, fast_accum, strided, use_torch_compile)
     def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
-        m, n, k, n_groups = 16, 32, 16, 4
+        m, n, k, n_groups = 16, 32, 64, 4
         a = torch.randn(m * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
         b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         self.assertTrue(a.is_contiguous() is not strided)
@@ -1666,11 +1671,16 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
             offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
             if check_zero_size:
                 offs[0] = offs[1]
-            scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32)
-            scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
+            scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32)
+            scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
 
             f = torch._scaled_grouped_mm
-            f = torch.compile(f, dynamic=False) if use_torch_compile else f
+            f = torch.compile(
+                f,
+                options={
+                    "max_autotune": True,
+                    "max_autotune_gemm_backends": "TRITON",
+                }) if use_torch_compile else f
             out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
                     out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
 
@@ -1682,7 +1692,7 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
                 ascalelist.append(scale_a[start:offs_cpu[i]])
                 outlist.append(out[start:offs_cpu[i]])
                 start = offs_cpu[i]
-        self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
+            self.scaled_grouped_mm_helper(alist, b, ascalelist, scale_b, outlist, fast_accum)
 
 
     @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@@ -1694,16 +1704,21 @@ def test_scaled_grouped_gemm_2d_3d(self, fast_accum, strided, use_torch_compile)
     def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
-        m, n, k, n_groups = 16, 32, 16, 4
+        m, n, k, n_groups = 16, 32, 64, 4
         a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         self.assertTrue(a.is_contiguous() is not strided)
         self.assertTrue(b.is_contiguous() is not strided)
-        scale_a = torch.ones(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
-        scale_b = torch.ones(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
+        scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
+        scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32).view(n_groups, n)
 
         f = torch._scaled_grouped_mm
-        f = torch.compile(f) if use_torch_compile else f
+        f = torch.compile(
+            f,
+            options={
+                "max_autotune": True,
+                "max_autotune_gemm_backends": "TRITON",
+            }) if use_torch_compile else f
         out = f(a, b.transpose(-2, -1), scale_a, scale_b,
                 out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
 
@@ -1719,20 +1734,25 @@ def test_scaled_grouped_gemm_3d_3d(self, fast_accum, strided, use_torch_compile)
     def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile):
         device = "cuda"
         s_int = int(strided)
-        m, n, k, n_groups = 16, 32, 16, 4
+        m, n, k, n_groups = 16, 32, 64, 4
         a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[::(1 + s_int), :, :k]
         b = torch.randn(n * n_groups, k * (1 + s_int), device=device).to(torch.float8_e4m3fn)[:, :k]
         self.assertTrue(a.is_contiguous() is not strided)
         self.assertTrue(b.is_contiguous() is not strided)
-        scale_a = torch.arange(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
-        scale_b = torch.arange(n_groups * n, device="cuda", dtype=torch.float32)
+        scale_a = torch.rand(n_groups * m, device="cuda", dtype=torch.float32).view(n_groups, m)
+        scale_b = torch.rand(n_groups * n, device="cuda", dtype=torch.float32)
         for check_zero_size in (True, False):
             offs = torch.arange(n, n_groups * n + 1, n, device="cuda", dtype=torch.int32)
             if check_zero_size:
                 offs[0] = offs[1]
 
             f = torch._scaled_grouped_mm
-            f = torch.compile(f) if use_torch_compile else f
+            f = torch.compile(
+                f,
+                options={
+                    "max_autotune": True,
+                    "max_autotune_gemm_backends": "TRITON",
+                }) if use_torch_compile else f
             out = f(a, b.transpose(-2, -1), scale_a, scale_b, offs=offs,
                     out_dtype=torch.bfloat16, use_fast_accum=fast_accum)
             offs_cpu = offs.cpu()
@@ -1743,7 +1763,7 @@ def test_scaled_grouped_gemm_3d_2d(self, fast_accum, strided, use_torch_compile)
                 bscalelist.append(scale_b[start:offs_cpu[i]])
                 outlist.append(out[:, start:offs_cpu[i]])
                 start = offs_cpu[i]
-        self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
+            self.scaled_grouped_mm_helper(a, blist, scale_a, bscalelist, outlist, fast_accum)
 
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
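
The helper these tests call is shown in the first hunk above; pulled out of the test class, it boils down to the following sketch (self.assertEqual is replaced with torch.testing.assert_close so it runs standalone; the loose tolerances match the change above):

import torch

def scaled_grouped_mm_helper(alist, blist, ascalelist, bscalelist, outlist, use_fast_accum):
    # Compare each group against a plain per-group torch._scaled_mm reference.
    for a, b, ascale, bscale, out in zip(alist, blist, ascalelist, bscalelist, outlist):
        out_ref = torch._scaled_mm(a, b.t(), ascale.view(-1, 1), bscale.view(1, -1),
                                   out_dtype=torch.bfloat16, use_fast_accum=use_fast_accum)
        torch.testing.assert_close(out, out_ref, atol=1e-1, rtol=1e-2)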

torch/_inductor/graph.py

Lines changed: 0 additions & 1 deletion
@@ -217,7 +217,6 @@ def mark_nodes_dislike_padding(
             aten.convolution,
             aten.convolution_backward,
             aten._scaled_mm,
-            aten._scaled_grouped_mm,
         ]
     )
     # what's a better way to collect the reduction ops?

torch/_inductor/kernel/mm_common.py

Lines changed: 2 additions & 1 deletion
@@ -38,7 +38,8 @@ def persistent_mm_grid(M: int, N: int, meta: dict[str, Any], *, cdiv, min):
 
 
 @SymbolicGridFn
-def persistent_grouped_mm_grid(m, n, meta):
+def persistent_grouped_mm_grid(*args):
+    meta = args[-1]
     return (meta["NUM_SMS"], 1, 1)
 
 
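
A self-contained view of the new grid callback follows; the @SymbolicGridFn decorator is dropped, and the reason in the comment (that the different input-dimension combinations pass different numbers of leading size arguments) is an inference from the commit message rather than something stated in the diff:

def persistent_grouped_mm_grid(*args):
    # Accept however many leading size arguments the caller passes; only the
    # trailing meta dict is read. The launch grid is one program per SM.
    meta = args[-1]
    return (meta["NUM_SMS"], 1, 1)

print(persistent_grouped_mm_grid(64, 32, {"NUM_SMS": 132}))     # (132, 1, 1)
print(persistent_grouped_mm_grid(4, 64, 32, {"NUM_SMS": 132}))  # (132, 1, 1)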
