
Commit e885225

aakhundov authored and pytorchmergebot committed
Add persistent+TMA version of Triton mm and addmm (pytorch#142101)
This PR adds persistent+TMA versions (Triton template + the corresponding infra) for the `tuned_mm` and `tuned_addmm` lowerings. The persistent+TMA choices are added to the GEMM autotuning if (checked by the `use_triton_tma_template` helper):

1. The minimum hardware and Triton version requirements for TMA support are met.
2. The GEMM inputs are compatible with the Triton TMA API (i.e., 16-byte aligned and contiguous).
3. The `config.triton.enable_persistent_tma_matmul` flag is set to `True`.

Additional notes:

1. As added in this PR, the TMA uses are not compatible with prologue / epilogue fusion. To this end, the new Triton template currently supports TMA-based loads of A/B but no prologue fusion, and epilogue fusion but no TMA-based stores of C. TMA + fusion compatibility can be added as a follow-up.
2. The current Triton TMA API (`experimental_device_tensormap_create2d`) does not support strides. Due to this, we limit the applicability of the new Triton template to cases where the inputs are contiguous.
3. Transposed layouts of A and / or B are supported by passing constexpr flags to the kernel and adjusting the ordering of the block sizes accordingly in the kernel code (this should have no effect on kernel perf, as it is decided at Triton compilation time).
4. After the next Triton pin update, we can switch to the tensor descriptor API (landed recently in triton-lang/triton#5290) in the new Triton template, which should allow lifting 2 and 3 above.
5. The configs for the new Triton template in `persistent_mm_kernel_configs` are preliminary. We should do more perf exploration and possibly augment the configs in a follow-up.
6. This PR is rebased onto and unifies with two related PRs landed previously: pytorch#142045 (some infra unification with the persistent+TMA template for _scaled_mm) and pytorch#134532 (adding the possibility to disable prologue fusion for selected choices).
7. The current Triton TMA API only supports 1D and 2D descriptors (even after triton-lang/triton#5290, see [here](https://github.com/triton-lang/triton/blob/9829ce87ccb333a2b264b3a80b39a534bfa865ac/python/triton/language/core.py#L1957)). For now, this blocks adding a persistent+TMA template for `torch.bmm`.

Pull Request resolved: pytorch#142101
Approved by: https://github.com/drisspg, https://github.com/eellison
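For context, a minimal sketch of how the new choice is exercised end to end, mirroring the tests added below; the shapes, dtype, and tolerances here are illustrative, not prescriptive:

import torch
from torch._inductor import config

def mm(a, b):
    return torch.mm(a, b)

# float16 inputs whose inner dims span a multiple of 16 bytes and are
# contiguous, so they satisfy the TMA compatibility requirements.
a = torch.randn(256, 128, dtype=torch.float16, device="cuda")
b = torch.randn(128, 512, dtype=torch.float16, device="cuda")

# Opt in to the persistent+TMA template; it then competes with the other
# GEMM choices during max-autotune.
with config.patch({"triton.enable_persistent_tma_matmul": True}):
    mm_compiled = torch.compile(mm, mode="max-autotune")
    out = mm_compiled(a, b)

torch.testing.assert_close(out, mm(a, b), atol=1e-2, rtol=1e-2)

With the flag left at its default, the persistent+TMA template is simply never added to the autotuning choice list, so existing behavior is unchanged.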
1 parent 17b71e5 commit e885225

File tree: 9 files changed, +453 -61 lines changed


test/inductor/test_fp8.py

Lines changed: 4 additions & 4 deletions
@@ -468,7 +468,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
             w_inverse_scale,
             bias,
         )
-        with config.patch({"triton.enable_persistent_tma_matmul": True}):
+        with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}):
             linear_compiled = torch.compile(
                 linear, backend="inductor", mode="max-autotune"
             )
@@ -538,7 +538,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
             w_inverse_scale,
             bias,
         )
-        with config.patch({"triton.enable_persistent_tma_matmul": True}):
+        with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}):
             linear_compiled = torch.compile(
                 linear, backend="inductor", mode="max-autotune"
             )
@@ -596,7 +596,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
             w_inverse_scale,
             bias,
         )
-        with config.patch({"triton.enable_persistent_tma_matmul": True}):
+        with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}):
            linear_compiled = torch.compile(
                 linear, backend="inductor", mode="max-autotune"
             )
@@ -655,7 +655,7 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
             w_inverse_scale,
             bias,
         )
-        with config.patch({"triton.enable_persistent_tma_matmul": True}):
+        with config.patch({"triton.enable_persistent_tma_matmul": persistent_matmul}):
             linear_compiled = torch.compile(
                 linear, backend="inductor", mode="max-autotune"
             )

test/inductor/test_max_autotune.py

Lines changed: 144 additions & 0 deletions
@@ -28,6 +28,7 @@
     parametrize,
     TEST_WITH_ROCM,
 )
+from torch.utils._triton import has_triton_tma_device
 
 
 aten = torch.ops.aten
@@ -212,6 +213,76 @@ def mm(a, b):
         with config.patch({"max_autotune": True, "autotune_in_subproc": True}):
             torch.compile(mm, dynamic=dynamic)(a, b)
 
+    @unittest.skipIf(
+        not has_triton_tma_device(), "Need device-side TMA support in Triton"
+    )
+    @parametrize("a_transposed", (False, True))
+    @parametrize("b_transposed", (False, True))
+    @parametrize("dynamic", (False, True))
+    def test_max_autotune_regular_mm_persistent_tma(
+        self,
+        a_transposed: bool,
+        b_transposed: bool,
+        dynamic: bool,
+    ):
+        def mm(a, b):
+            # TMA requires 16-byte alignment: here we repeat the dims
+            # by the factor of 8, as float16 is 2-byte. All dims are
+            # repeated due to the possible transpositions below.
+            a = a.repeat(8, 8)
+            b = b.repeat(8, 8)
+
+            if a_transposed:
+                a = a.T
+            if b_transposed:
+                b = b.T
+
+            return torch.mm(a, b)
+
+        M, N, K = 21, 31, 11
+        a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda()
+        b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda()
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "autotune_fallback_to_aten": False,
+                "triton.enable_persistent_tma_matmul": "1",
+                "test_configs.autotune_choice_name_regex": "mm_persistent_tma",
+            }
+        ):
+            c_actual = torch.compile(mm, dynamic=dynamic)(a, b)
+            c_expected = mm(a, b)
+
+        torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
+
+    @unittest.skipIf(
+        not has_triton_tma_device(), "Need device-side TMA support in Triton"
+    )
+    @parametrize("dynamic", (False, True))
+    def test_max_autotune_regular_mm_persistent_tma_illegal_alignment(self, dynamic):
+        def mm(a, b):
+            return torch.mm(a, b)
+
+        M, N, K = 21, 31, 11
+        a = torch.randn(M, K).to(torch.float16).cuda()
+        b = torch.randn(K, N).to(torch.float16).cuda()
+
+        with self.assertRaises(BackendCompilerFailed) as context, config.patch(
+            {
+                "max_autotune": True,
+                "autotune_fallback_to_aten": False,
+                "triton.enable_persistent_tma_matmul": "1",
+                "test_configs.autotune_choice_name_regex": "mm_persistent_tma",
+            }
+        ):
+            torch.compile(mm, dynamic=dynamic)(a, b)
+
+        # Lowering to the persistent+TMA Triton template should be skipped
+        # if any of the input inner dims are not 16-byte aligned. As a result,
+        # given the config flags above, we should have no choices left.
+        self.assertIn("NoValidChoicesError", str(context.exception))
+
     @parametrize("dynamic", (False, True))
     def test_max_autotune_regular_mm_zero_size_input(self, dynamic: bool):
         """
@@ -316,6 +387,79 @@ def addmm(x, a, b):
         Y = addmm(x, a, b)
         torch.testing.assert_close(Y_compiled, Y, atol=1e-2, rtol=1e-2)
 
+    @unittest.skipIf(
+        not has_triton_tma_device(), "Need device-side TMA support in Triton"
+    )
+    @parametrize("a_transposed", (False, True))
+    @parametrize("b_transposed", (False, True))
+    @parametrize("dynamic", (False, True))
+    def test_max_autotune_addmm_persistent_tma(
+        self,
+        a_transposed: bool,
+        b_transposed: bool,
+        dynamic: bool,
+    ):
+        def addmm(x, a, b):
+            # TMA requires 16-byte alignment: here we repeat the dims
+            # by the factor of 8, as float16 is 2-byte. All dims are
+            # repeated due to the possible transpositions below.
+            x = x.repeat(8)
+            a = a.repeat(8, 8)
+            b = b.repeat(8, 8)
+
+            if a_transposed:
+                a = a.T
+            if b_transposed:
+                b = b.T
+
+            return torch.addmm(x, a, b)
+
+        M, N, K = 21, 31, 11
+        a = torch.randn(*((K, M) if a_transposed else (M, K))).to(torch.float16).cuda()
+        b = torch.randn(*((N, K) if b_transposed else (K, N))).to(torch.float16).cuda()
+        x = torch.randn(N).to(torch.float16).cuda()
+
+        with config.patch(
+            {
+                "max_autotune": True,
+                "autotune_fallback_to_aten": False,
+                "triton.enable_persistent_tma_matmul": "1",
+                "test_configs.autotune_choice_name_regex": "mm_persistent_tma",
+            }
+        ):
+            c_actual = torch.compile(addmm, dynamic=dynamic)(x, a, b)
+            c_expected = addmm(x, a, b)
+
+        torch.testing.assert_close(c_actual, c_expected, atol=1e-2, rtol=1e-2)
+
+    @unittest.skipIf(
+        not has_triton_tma_device(), "Need device-side TMA support in Triton"
+    )
+    @parametrize("dynamic", (False, True))
+    def test_max_autotune_addmm_persistent_tma_illegal_alignment(self, dynamic):
+        def addmm(x, a, b):
+            return torch.addmm(x, a, b)
+
+        M, N, K = 21, 31, 11
+        a = torch.randn(M, K).to(torch.float16).cuda()
+        b = torch.randn(K, N).to(torch.float16).cuda()
+        x = torch.randn(N).to(torch.float16).cuda()
+
+        with self.assertRaises(BackendCompilerFailed) as context, config.patch(
+            {
+                "max_autotune": True,
+                "autotune_fallback_to_aten": False,
+                "triton.enable_persistent_tma_matmul": "1",
+                "test_configs.autotune_choice_name_regex": "mm_persistent_tma",
+            }
+        ):
+            torch.compile(addmm, dynamic=dynamic)(x, a, b)
+
+        # Lowering to the persistent+TMA Triton template should be skipped
+        # if any of the input inner dims are not 16-byte aligned. As a result,
+        # given the config flags above, we should have no choices left.
+        self.assertIn("NoValidChoicesError", str(context.exception))
+
     @parametrize("dynamic", (False, True))
     def test_max_autotune_addmm_zero_size_input(self, dynamic):
         """

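A note on the `repeat(8, 8)` trick in the tests above: the persistent+TMA template requires each input's inner (contiguous) dimension to span a multiple of 16 bytes, so the small float16 test inputs are inflated by a factor of 8 before the matmul. A minimal sketch of that alignment rule, using a hypothetical `is_tma_aligned` helper (this is for illustration only, not the actual `use_triton_tma_template` check):

import torch

def is_tma_aligned(t: torch.Tensor, alignment_bytes: int = 16) -> bool:
    """Check that the innermost dim of a contiguous 2D tensor spans a multiple of 16 bytes."""
    if not t.is_contiguous() or t.dim() != 2:
        return False
    inner_bytes = t.size(-1) * t.element_size()
    return inner_bytes % alignment_bytes == 0

a = torch.randn(21, 11, dtype=torch.float16)            # 11 * 2 = 22 bytes  -> not aligned
b = torch.randn(21 * 8, 11 * 8, dtype=torch.float16)    # 88 * 2 = 176 bytes -> aligned
print(is_tma_aligned(a), is_tma_aligned(b))  # False True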
torch/_inductor/config.py

Lines changed: 5 additions & 0 deletions
@@ -1368,6 +1368,11 @@ class test_configs:
 
     runtime_triton_dtype_assert = False
 
+    # regex to control the set of considered autotuning
+    # choices (aka configs) by name and / or description
+    autotune_choice_name_regex: Optional[str] = None
+    autotune_choice_desc_regex: Optional[str] = None
+
 
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
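The intended use of these knobs, as exercised by the tests above, is to narrow the autotuning pool by choice name so that only the persistent+TMA template competes. A rough sketch of the filtering idea, assuming a hypothetical list of choice objects with a `name` attribute; this is illustrative, not Inductor's actual implementation:

import re
from typing import Optional

def filter_choices_by_name(choices: list, name_regex: Optional[str]) -> list:
    """Keep only autotuning choices whose name matches the regex (if one is set)."""
    if not name_regex:
        return choices
    pattern = re.compile(name_regex)
    return [c for c in choices if pattern.search(c.name)]

# e.g. with test_configs.autotune_choice_name_regex = "mm_persistent_tma",
# only the persistent+TMA template choices remain in the pool; if none
# qualify (and the ATen fallback is disabled), autotuning has no valid
# choices left, which the alignment tests above assert on via
# NoValidChoicesError.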

torch/_inductor/kernel/mm.py

Lines changed: 143 additions & 0 deletions
@@ -31,12 +31,14 @@
 )
 from ..utils import (
     get_gpu_shared_memory,
+    get_tma_workspace_arg,
     use_aten_gemm_kernels,
     use_ck_gemm_template,
     use_cpp_gemm_template,
     use_cutlass_template,
     use_max_autotune,
     use_triton_template,
+    use_triton_tma_template,
 )
 from .mm_common import (
     _is_static_problem,
@@ -48,6 +50,9 @@
     mm_configs,
     mm_grid,
     mm_options,
+    persistent_mm_configs,
+    persistent_mm_grid,
+    persistent_mm_options,
     triton_config,
 )
 
@@ -128,6 +133,110 @@
 """,
 )
 
+persistent_tma_mm_template = TritonTemplate(
+    name="mm_persistent_tma",
+    grid=persistent_mm_grid,
+    source=r"""
+{{def_kernel("A", "B")}}
+    M = {{size("A", 0)}}
+    N = {{size("B", 1)}}
+    K = {{size("A", 1)}}
+    if M * N == 0:
+        # early exit due to zero-size input(s)
+        return
+
+    start_pid = tl.program_id(0)
+    grid_m = tl.cdiv(M, BLOCK_M)
+    grid_n = tl.cdiv(N, BLOCK_N)
+    k_tiles = tl.cdiv(K, BLOCK_K)
+    num_tiles = grid_m * grid_n
+    tiles_per_SM = num_tiles // NUM_SMS
+    if start_pid < num_tiles % NUM_SMS:
+        tiles_per_SM += 1
+
+    tile_id = start_pid - NUM_SMS
+    ki = -1
+
+    width = GROUP_M * grid_n
+    rk_for_mask = tl.arange(0, BLOCK_K)
+    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+
+    workspace_base = ws_ptr + start_pid * 2 * TMA_SIZE
+    a_desc_ptr = workspace_base
+    b_desc_ptr = workspace_base + TMA_SIZE
+
+    triton.language.extra.cuda.experimental_device_tensormap_create2d(
+        desc_ptr=a_desc_ptr,
+        global_address=A,
+        load_size=[BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
+        global_size=[M, K] if A_ROW_MAJOR else [K, M],
+        element_ty=A.dtype.element_ty,
+    )
+    triton.language.extra.cuda.experimental_device_tensormap_create2d(
+        desc_ptr=b_desc_ptr,
+        global_address=B,
+        load_size=[BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
+        global_size=[K, N] if B_ROW_MAJOR else [N, K],
+        element_ty=B.dtype.element_ty,
+    )
+
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
+
+    pid_m = 0
+    pid_n = 0
+    rm = 0
+    rn = 0
+
+    for _ in range(0, k_tiles * tiles_per_SM):
+        ki = tl.where(ki == k_tiles - 1, 0, ki + 1)
+        if ki == 0:
+            tile_id += NUM_SMS
+            # re-order program ID for better L2 performance
+            group_id = tile_id // width
+            group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+            pid_m = group_id * GROUP_M + (tile_id % group_size)
+            pid_n = (tile_id % width) // (group_size)
+
+            rm = pid_m * BLOCK_M
+            rn = pid_n * BLOCK_N
+
+        rk = ki * BLOCK_K
+
+        a = tl._experimental_descriptor_load(
+            a_desc_ptr,
+            [rm, rk] if A_ROW_MAJOR else [rk, rm],
+            [BLOCK_M, BLOCK_K] if A_ROW_MAJOR else [BLOCK_K, BLOCK_M],
+            A.dtype.element_ty,
+        )
+        b = tl._experimental_descriptor_load(
+            b_desc_ptr,
+            [rk, rn] if B_ROW_MAJOR else [rn, rk],
+            [BLOCK_K, BLOCK_N] if B_ROW_MAJOR else [BLOCK_N, BLOCK_K],
+            B.dtype.element_ty,
+        )
+        if B_PROLOGUE_CAST_TYPE is not None:
+            b = b.to(B_PROLOGUE_CAST_TYPE)
+        acc += tl.dot(
+            a if A_ROW_MAJOR else a.T,
+            b if B_ROW_MAJOR else b.T,
+            allow_tf32=ALLOW_TF32,
+        )
+
+        if ki == k_tiles - 1:
+            # rematerialize rm and rn to save registers
+            rcm = rm + tl.arange(0, BLOCK_M)
+            rcn = rn + tl.arange(0, BLOCK_N)
+            idx_m = rcm[:, None]
+            idx_n = rcn[None, :]
+            mask = (idx_m < M) & (idx_n < N)
+
+            # inductor generates a suffix
+            {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=12)}}
+            acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+""",
+)
+
 
 # prevent duplication registration of extern functions
 @functools.lru_cache(None)
@@ -206,6 +315,22 @@ def tuned_mm(mat1, mat2, *, layout=None):
                 layout=layout,
                 **mm_options(config, m, n, k, layout),
             )
+        if use_triton_tma_template(mat1, mat2):
+            for config in persistent_mm_configs(
+                m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
+            ):
+                persistent_tma_mm_template.maybe_append_choice(
+                    choices,
+                    input_nodes=(mat1, mat2),
+                    layout=layout,
+                    workspace_arg=get_tma_workspace_arg(
+                        num_tma_descriptors=2,
+                        device=mat1.get_device(),
+                    ),
+                    **mm_options(config, m, n, k, layout),
+                    **persistent_mm_options(mat1, mat2),
+                )
+
     if is_nonzero and use_cutlass_template(layout, m, n, k):
         CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2])
 
@@ -398,6 +523,24 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None):
                 epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
             )
 
+        if use_triton_tma_template(mat1, mat2):
+            for config in persistent_mm_configs(
+                m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
+            ):
+                persistent_tma_mm_template.maybe_append_choice(
+                    choices,
+                    input_nodes=(inp_expanded, mat1, mat2),
+                    layout=layout,
+                    workspace_arg=get_tma_workspace_arg(
+                        num_tma_descriptors=2,
+                        device=mat1.get_device(),
+                    ),
+                    **mm_options(config, m, n, k, layout),
+                    **persistent_mm_options(mat1, mat2),
+                    prefix_args=1,
+                    epilogue_fn=addmm_epilogue(layout.dtype, alpha, beta),
+                )
+
     if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k):
         # Filter out a known cause of CUDA illegal memory access errors
        # broadcasting on the last dim of the bias term seems not to be working
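A note on the template registered above: a persistent kernel launches roughly one program per SM and lets each program loop over several output tiles, which is why the kernel body computes `tiles_per_SM` instead of mapping one program per tile. A rough sketch of what a grid callable like `persistent_mm_grid` might compute (an assumption for illustration, not the code from `mm_common`):

import torch

def persistent_mm_grid_sketch(M: int, N: int, meta: dict):
    """Launch at most one program per SM; each program loops over several (BLOCK_M, BLOCK_N) tiles."""
    num_sms = torch.cuda.get_device_properties(0).multi_processor_count
    num_tiles = -(-M // meta["BLOCK_M"]) * -(-N // meta["BLOCK_N"])  # ceil-div product
    return (min(num_sms, num_tiles), 1, 1)

# e.g. for M = N = 4096 with 128x128 tiles on a GPU with 132 SMs:
# num_tiles = 32 * 32 = 1024, so the grid is (132, 1, 1) and each program
# processes roughly 1024 / 132 ≈ 8 tiles, matching the tiles_per_SM
# computation in the kernel body above.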
