Commit cbab5f4
[Gluon] Change gl.warp_specialize API (#8527)
Partition functions and their individual arguments are now passed together as a single list of (function, args) entries. All the arguments are simply appended together in MLIR, but the `WarpSpecializeOp::canonicalize` method will clean up duplicate arguments.
1 parent 7bdcc6b commit cbab5f4
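
For orientation, a minimal before/after sketch of one call site under the new API. The function names and argument tuple are lifted from the test_consan changes below; the comments are interpretation of this commit, not part of it.

# Old API: the default partition, its args, the worker args, and the worker
# function list were passed as separate parameters.
ttgl.warp_specialize((smem, bar, FAILURE, blocked_layout), ws_default,
                     (smem, bar, FAILURE, blocked_layout), [ws_1], [4], [32])

# New API: every partition, default first, is a (function, args) tuple in a
# single list; worker_num_warps and worker_num_regs appear to keep one entry
# per non-default partition. Repeating the same args tuple is fine because
# duplicate captures are removed by WarpSpecializeOp::canonicalize.
ttgl.warp_specialize([
    (ws_default, (smem, bar, FAILURE, blocked_layout)),
    (ws_1, (smem, bar, FAILURE, blocked_layout)),
], [4], [32])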

6 files changed: +134 additions, −80 deletions
python/examples/gluon/01-attention-forward.py

Lines changed: 7 additions & 6 deletions
@@ -840,12 +840,13 @@ def attention_kernel( #
 
     chnls = (q_chnl, kv_chnl, o_chnl, epi_chnl, s0_chnl, s1_chnl, c0_chnl, c1_chnl, exp_turnstile)
     descs = (desc_q, desc_k, desc_v, desc_o)
-    gl.warp_specialize((config, chnls, descs, M, STAGE), _attn_fwd_correction, (config, chnls, descs, M, STAGE), [
-        _attn_fwd_softmax0,
-        _attn_fwd_softmax1,
-        _attn_fwd_mma,
-        _attn_fwd_load,
-        _attn_fwd_epilogue,
+    gl.warp_specialize([
+        (_attn_fwd_correction, (config, chnls, descs, M, STAGE)),
+        (_attn_fwd_softmax0, (config, chnls, descs, M, STAGE)),
+        (_attn_fwd_softmax1, (config, chnls, descs, M, STAGE)),
+        (_attn_fwd_mma, (config, chnls, descs, M, STAGE)),
+        (_attn_fwd_load, (config, chnls, descs, M, STAGE)),
+        (_attn_fwd_epilogue, (config, chnls, descs, M, STAGE)),
     ], [4, 4, 1, 1, 1], [192, 192, 24, 24, 24])
 
     q_chnl.release()

python/test/gluon/test_consan.py

Lines changed: 70 additions & 27 deletions
@@ -742,8 +742,10 @@ def ws_kernel(output, FAILURE: ttgl.constexpr):
     bar = ttgl.allocate_shared_memory(ttgl.int64, [2, 1], mbarrier.MBarrierLayout())
     for i in range(2):
         mbarrier.init(bar.index(i), count=1)
-    ttgl.warp_specialize((smem, bar, FAILURE, blocked_layout), ws_default, (smem, bar, FAILURE, blocked_layout),
-                         [ws_1], [4], [32])
+    ttgl.warp_specialize([
+        (ws_default, (smem, bar, FAILURE, blocked_layout)),
+        (ws_1, (smem, bar, FAILURE, blocked_layout)),
+    ], [4], [32])
     mbarrier.wait(bar.index(1), phase=0)
     val = smem.index(0).load(blocked_layout)
     output_ptrs = output + ttgl.arange(0, XBLOCK, blocked_layout)
@@ -796,8 +798,10 @@ def ws_kernel(output, FAILURE: ttgl.constexpr):
     bar = ttgl.allocate_shared_memory(ttgl.int64, [2, 1], mbarrier.MBarrierLayout())
     for i in range(2):
         mbarrier.init(bar.index(i), count=1)
-    ttgl.warp_specialize((smem, bar, FAILURE, blocked_layout), ws_default, (smem, bar, FAILURE, blocked_layout),
-                         [ws_1], [4], [32])
+    ttgl.warp_specialize([
+        (ws_default, (smem, bar, FAILURE, blocked_layout)),
+        (ws_1, (smem, bar, FAILURE, blocked_layout)),
+    ], [4], [32])
     mbarrier.wait(bar.index(1), phase=0)
     val = smem.index(0).load(blocked_layout)
     output_ptrs = output + ttgl.arange(0, XBLOCK, blocked_layout)
@@ -859,8 +863,11 @@ def kernel(output, MISSING_BAR: ttgl.constexpr):
     bar = ttgl.allocate_shared_memory(ttgl.int64, [3, 1], mbarrier.MBarrierLayout())
     for i in range(3):
         mbarrier.init(bar.index(i), count=1)
-    ttgl.warp_specialize((smem, bar, MISSING_BAR, blocked_layout), ws_default,
-                         (smem, bar, MISSING_BAR, blocked_layout), [ws_1, ws_2], [4, 4], [32, 32])
+    ttgl.warp_specialize([
+        (ws_default, (smem, bar, MISSING_BAR, blocked_layout)),
+        (ws_1, (smem, bar, MISSING_BAR, blocked_layout)),
+        (ws_2, (smem, bar, MISSING_BAR, blocked_layout)),
+    ], [4, 4], [32, 32])
     mbarrier.wait(bar.index(2), phase=0)
     val = smem.index(0).load(blocked_layout)
     output_ptrs = output + ttgl.arange(0, XBLOCK, blocked_layout)
@@ -919,8 +926,11 @@ def kernel(output, FAILURE: ttgl.constexpr):
     bar = ttgl.allocate_shared_memory(ttgl.int64, [2, 1], mbarrier.MBarrierLayout())
     mbarrier.init(bar.index(0), count=2)
     mbarrier.init(bar.index(1), count=1)
-    ttgl.warp_specialize((smem, bar, FAILURE, blocked_layout), ws_default, (smem, bar, FAILURE, blocked_layout),
-                         [ws_1, ws_2], [4, 4], [32, 32])
+    ttgl.warp_specialize([
+        (ws_default, (smem, bar, FAILURE, blocked_layout)),
+        (ws_1, (smem, bar, FAILURE, blocked_layout)),
+        (ws_2, (smem, bar, FAILURE, blocked_layout)),
+    ], [4, 4], [32, 32])
     mbarrier.wait(bar.index(1), phase=0)
     val = smem.index(0).load(blocked_layout)
     output_ptrs = output + ttgl.arange(0, XBLOCK, blocked_layout)
@@ -1007,8 +1017,11 @@ def kernel(output, MISSING_BAR: ttgl.constexpr):
     mbarrier.arrive(bar.index(2), count=1)
     mbarrier.arrive(bar.index(3), count=1)
 
-    ttgl.warp_specialize((smem, bar, MISSING_BAR, blocked_layout), ws_default,
-                         (smem, bar, MISSING_BAR, blocked_layout), [ws_1, ws_2], [4, 4], [32, 32])
+    ttgl.warp_specialize([
+        (ws_default, (smem, bar, MISSING_BAR, blocked_layout)),
+        (ws_1, (smem, bar, MISSING_BAR, blocked_layout)),
+        (ws_2, (smem, bar, MISSING_BAR, blocked_layout)),
+    ], [4, 4], [32, 32])
 
     output = torch.empty((XBLOCK, ), device=device, dtype=torch.float16)
     kernel[(1, )](output, MISSING_BAR=MISSING_BAR, num_warps=4)
@@ -1072,8 +1085,10 @@ def kernel(output, FAILURE: ttgl.constexpr):
 
     mbarrier.arrive(bar.index(2), count=1)
 
-    ttgl.warp_specialize((smem, bar, FAILURE, blocked_layout), ws_default, (smem, bar, FAILURE, blocked_layout),
-                         [ws_1], [4], [32])
+    ttgl.warp_specialize([
+        (ws_default, (smem, bar, FAILURE, blocked_layout)),
+        (ws_1, (smem, bar, FAILURE, blocked_layout)),
+    ], [4], [32])
 
     output = torch.empty((XBLOCK, ), device=device, dtype=torch.float16)
     kernel[(1, )](output, FAILURE=FAILURE, num_warps=4)
@@ -1160,8 +1175,12 @@ def kernel(output, MISSING_BAR: ttgl.constexpr):
     mbarrier.arrive(bar.index(2), count=2)
     mbarrier.arrive(bar.index(3), count=2)
 
-    ttgl.warp_specialize((smem, bar, MISSING_BAR, blocked_layout), ws_default,
-                         (smem, bar, MISSING_BAR, blocked_layout), [ws_1, ws_2, ws_3], [4, 4, 4], [32, 32, 32])
+    ttgl.warp_specialize([
+        (ws_default, (smem, bar, MISSING_BAR, blocked_layout)),
+        (ws_1, (smem, bar, MISSING_BAR, blocked_layout)),
+        (ws_2, (smem, bar, MISSING_BAR, blocked_layout)),
+        (ws_3, (smem, bar, MISSING_BAR, blocked_layout)),
+    ], [4, 4, 4], [32, 32, 32])
 
     output = torch.empty((XBLOCK, ), device=device, dtype=torch.float16)
     kernel[(1, )](output, MISSING_BAR=MISSING_BAR, num_warps=4)
@@ -1225,8 +1244,11 @@ def kernel(output, MISSING_BAR: ttgl.constexpr):
     bar = ttgl.allocate_shared_memory(ttgl.int64, [3, 1], mbarrier.MBarrierLayout())
    for i in range(3):
         mbarrier.init(bar.index(i), count=1)
-    ttgl.warp_specialize((smem, bar, MISSING_BAR), ws_default, (smem, bar, MISSING_BAR), [ws_1, ws_2], [2, 8],
-                         [32, 32])
+    ttgl.warp_specialize([
+        (ws_default, (smem, bar, MISSING_BAR)),
+        (ws_1, (smem, bar, MISSING_BAR)),
+        (ws_2, (smem, bar, MISSING_BAR)),
+    ], [2, 8], [32, 32])
     mbarrier.wait(bar.index(2), phase=0)
     val = smem.index(0).load(blocked_layout)
     output_ptrs = output + ttgl.arange(0, XBLOCK, blocked_layout)
@@ -1291,8 +1313,10 @@ def kernel(input, FAILURE: ttgl.constexpr):
     smem = ttgl.allocate_shared_memory(ttgl.float16, [4, XBLOCK], smem_layout)
     blocked_layout: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[XBLOCK], threads_per_warp=[32],
                                                         warps_per_cta=[4], order=[0])
-    ttgl.warp_specialize((input, smem, FAILURE, blocked_layout, 0), ws_prog,
-                         (input, smem, FAILURE, blocked_layout, 2), [ws_prog], [4], [32])
+    ttgl.warp_specialize([
+        (ws_prog, (input, smem, FAILURE, blocked_layout, 0)),
+        (ws_prog, (input, smem, FAILURE, blocked_layout, 2)),
+    ], [4], [32])
 
     input = torch.randn((XBLOCK, ), device=device, dtype=torch.float16)
     kernel[(1, )](input, FAILURE=FAILURE, num_warps=4)
@@ -1346,8 +1370,10 @@ def kernel(input, FAILURE: ttgl.constexpr):
     smem = ttgl.allocate_shared_memory(ttgl.float16, [2, XBLOCK], smem_layout)
     bar = ttgl.allocate_shared_memory(ttgl.int64, [1, 1], mbarrier.MBarrierLayout())
     mbarrier.init(bar.index(0), count=1)
-    ttgl.warp_specialize((input, smem, bar, FAILURE, blocked_layout), ws_default,
-                         (input, smem, bar, FAILURE, blocked_layout), [ws_1], [4], [32])
+    ttgl.warp_specialize([
+        (ws_default, (input, smem, bar, FAILURE, blocked_layout)),
+        (ws_1, (input, smem, bar, FAILURE, blocked_layout)),
+    ], [4], [32])
 
     input = torch.randn((XBLOCK, ), device=device, dtype=torch.float16)
     kernel[(1, )](input, FAILURE=FAILURE, num_warps=4)
@@ -1402,8 +1428,10 @@ def kernel(FAILURE: ttgl.constexpr):
     smem = ttgl.allocate_shared_memory(ttgl.float16, [2, XBLOCK, XBLOCK], smem_layout)
     bar = ttgl.allocate_shared_memory(ttgl.int64, [1, 1], mbarrier.MBarrierLayout())
     mbarrier.init(bar.index(0), count=1)
-    ttgl.warp_specialize((smem, bar, FAILURE, blocked_layout, mma_layout), ws_default,
-                         (smem, bar, FAILURE, blocked_layout), [ws_1], [4], [32])
+    ttgl.warp_specialize([
+        (ws_default, (smem, bar, FAILURE, blocked_layout, mma_layout)),
+        (ws_1, (smem, bar, FAILURE, blocked_layout)),
+    ], [4], [32])
 
     kernel[(1, )](FAILURE=FAILURE, num_warps=4)
 
@@ -1438,7 +1466,10 @@ def kernel():
     bar = ttgl.allocate_shared_memory(ttgl.int64, [2, 1], mbarrier.MBarrierLayout())
     mbarrier.init(bar.index(0), count=1)
     mbarrier.init(bar.index(1), count=1)
-    ttgl.warp_specialize((bar, ), ws_default, (bar, ), [ws_1], [4], [32])
+    ttgl.warp_specialize([
+        (ws_default, (bar, )),
+        (ws_1, (bar, )),
+    ], [4], [32])
 
     kernel[(1, )](num_warps=4)
 
@@ -1505,7 +1536,10 @@ def kernel():
     bar = ttgl.allocate_shared_memory(ttgl.int64, [2, 1], mbarrier.MBarrierLayout())
     mbarrier.init(bar.index(0), count=2)
     mbarrier.init(bar.index(1), count=2)
-    ttgl.warp_specialize((bar, ), ws_default, (bar, ), [ws_1], [4], [32])
+    ttgl.warp_specialize([
+        (ws_default, (bar, )),
+        (ws_1, (bar, )),
+    ], [4], [32])
 
     kernel[(1, )](num_warps=4)
 
@@ -1541,7 +1575,10 @@ def kernel():
     bar = ttgl.allocate_shared_memory(ttgl.int64, [1, 1], mbarrier.MBarrierLayout())
     mbarrier.init(bar.index(0), count=1)
     mbarrier.arrive(bar.index(0), count=1)
-    ttgl.warp_specialize((bar, ), ws_default, (bar, ), [ws_1], [4], [32])
+    ttgl.warp_specialize([
+        (ws_default, (bar, )),
+        (ws_1, (bar, )),
+    ], [4], [32])
 
     kernel[(1, )](num_warps=4)
 
@@ -1582,7 +1619,10 @@ def kernel(input_desc):
     bar = ttgl.allocate_shared_memory(ttgl.int64, [2, 1], mbarrier.MBarrierLayout())
     mbarrier.init(bar.index(0), count=1)
     mbarrier.init(bar.index(1), count=1)
-    ttgl.warp_specialize((input_desc, smem, bar), ws_default, (input_desc, smem, bar), [ws_1], [4], [32])
+    ttgl.warp_specialize([
+        (ws_default, (input_desc, smem, bar)),
+        (ws_1, (input_desc, smem, bar)),
+    ], [4], [32])
 
     input = torch.randn((XBLOCK, XBLOCK), device=device, dtype=torch.float16)
     shared_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
@@ -1621,6 +1661,9 @@ def kernel():
     bar = ttgl.allocate_shared_memory(ttgl.int64, [2, 1], mbarrier.MBarrierLayout())
     mbarrier.init(bar.index(0), count=1)
     mbarrier.init(bar.index(1), count=1)
-    ttgl.warp_specialize((bar, ), ws_default, (bar, ), [ws_1], [4], [32])
+    ttgl.warp_specialize([
+        (ws_default, (bar, )),
+        (ws_1, (bar, )),
+    ], [4], [32])
 
     kernel[(1, )](num_warps=4)

python/test/gluon/test_frontend.py

Lines changed: 24 additions & 10 deletions
@@ -466,17 +466,17 @@ def test_warp_specialize():
     # CHECK-NEXT: [[A:%.*]] = tt.make_range {end = 1 : i32, start = 0 : i32}
     # CHECK-NEXT: [[B:%.*]] = tt.make_range {end = 2 : i32, start = 0 : i32}
     # CHECK-NEXT: [[C:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
-    # CHECK-NEXT: [[OUTS:%.*]]:3 = ttg.warp_specialize([[A]], [[B]], [[C]]) {{.*}}requestedRegisters = array<i32: 24, 48>
+    # CHECK-NEXT: [[OUTS:%.*]]:3 = ttg.warp_specialize([[A]], [[B]], [[C]], [[A]], [[B]], [[C]]) {{.*}}requestedRegisters = array<i32: 24, 48>
     # CHECK-NEXT: default {
     # CHECK-NEXT: [[RESULTS:%.*]]:3 = tt.call @{{.*}}warp_specialize_default{{.*}}cconstexpr_42{{.*}}([[A]], [[B]], [[C]])
     # CHECK-NEXT: warp_yield [[RESULTS]]#0, [[RESULTS]]#1, [[RESULTS]]#2
     # CHECK-NEXT: }
-    # CHECK-NEXT: partition0(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>) num_warps(4) {
+    # CHECK-NEXT: partition0(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>, %arg3: tensor<1xi32, [[BLOCKED]]>, %arg4: tensor<2xi32, [[BLOCKED]]>, %arg5: tensor<4xi32, [[BLOCKED]]>) num_warps(4) {
     # CHECK-NEXT: call @{{.*}}warp_specialize_worker0{{.*}}cconstexpr_42{{.*}}(%arg0, %arg1, %arg2)
     # CHECK-NEXT: warp_return
     # CHECK-NEXT: }
-    # CHECK-NEXT: partition1(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>) num_warps(4) {
-    # CHECK-NEXT: call @{{.*}}warp_specialize_worker1{{.*}}cconstexpr_42{{.*}}(%arg0, %arg1, %arg2)
+    # CHECK-NEXT: partition1(%arg0: tensor<1xi32, [[BLOCKED]]>, %arg1: tensor<2xi32, [[BLOCKED]]>, %arg2: tensor<4xi32, [[BLOCKED]]>, %arg3: tensor<1xi32, [[BLOCKED]]>, %arg4: tensor<2xi32, [[BLOCKED]]>, %arg5: tensor<4xi32, [[BLOCKED]]>) num_warps(4) {
+    # CHECK-NEXT: call @{{.*}}warp_specialize_worker1{{.*}}cconstexpr_42{{.*}}(%arg3, %arg4, %arg5)
     # CHECK-NEXT: warp_return
     # CHECK-NEXT: }
     # CHECK-NEXT: call @{{.*}}anchor{{.*}}([[OUTS]]#0)
@@ -487,14 +487,20 @@ def test_warp_specialize():
     c = ttgl.arange(0, 4, layout=layout)
     pair = Pair(a, b)
     e: ttgl.constexpr = 42
-    a, b = ttgl.warp_specialize((pair, c, e), warp_specialize_default, (pair, c, e),
-                                [warp_specialize_worker0, warp_specialize_worker1], [4, 4], [24, 48])
+    a, b = ttgl.warp_specialize([
+        (warp_specialize_default, (pair, c, e)),
+        (warp_specialize_worker0, (pair, c, e)),
+        (warp_specialize_worker1, (pair, c, e)),
+    ], [4, 4], [24, 48])
     anchor(a)
     anchor(b)
 
     # CHECK: ttg.warp_specialize([[A]], [[B]], [[C]])
     # CHECK: (tensor<1xi32, [[BLOCKED]]>, tensor<2xi32, [[BLOCKED]]>, tensor<4xi32, [[BLOCKED]]>) -> ()
-    ttgl.warp_specialize((pair, c, e), warp_specialize_worker0, (pair, c, e), [warp_specialize_worker1], [4], [48])
+    ttgl.warp_specialize([
+        (warp_specialize_worker0, (pair, c, e)),
+        (warp_specialize_worker1, (pair, c, e)),
+    ], [4], [48])
 
 
 @gluon.jit
@@ -535,7 +541,11 @@ def test_num_warps_caller_context():
     # CHECK: func private @{{.*}}ws_test_worker1{{.*}}_NW1() attributes {noinline = false, "ttg.num-warps" = 1 : i32}
     # CHECK: func private @{{.*}}ws_body{{.*}}_NW1"() attributes {noinline = false, "ttg.num-warps" = 1 : i32}
     # CHECK: func private @{{.*}}anchor{{.*}}_NW1(%arg0: tensor<128xi32, [[BLOCKED_NW1]]>) attributes {noinline = false, "ttg.num-warps" = 1 : i32}
-    ttgl.warp_specialize((), ws_test_default, (), [ws_test_worker0, ws_test_worker1], [2, 1], [80, 80])
+    ttgl.warp_specialize([
+        (ws_test_default, ()),
+        (ws_test_worker0, ()),
+        (ws_test_worker1, ()),
+    ], [2, 1], [80, 80])
 
 
 @gluon.jit
@@ -2913,8 +2923,12 @@ def test_get_num_warps():
     # CHECK: tt.func private @{{.*}}print_num_warps{{.*}}NW8
     # CHECK-NEXT arith.constant 8 : i32
     print_num_warps()
-    ttgl.warp_specialize((), print_num_warps, (), [print_num_warps, print_num_warps, print_num_warps], [1, 2, 8],
-                         [24, 24, 24])
+    ttgl.warp_specialize([
+        (print_num_warps, ()),
+        (print_num_warps, ()),
+        (print_num_warps, ()),
+        (print_num_warps, ()),
+    ], [1, 2, 8], [24, 24, 24])
 
 
 def test_mismatch_shape_and_layout_rank():

python/triton/experimental/gluon/language/_core.py

Lines changed: 3 additions & 8 deletions
@@ -493,16 +493,12 @@ def set_auto_layout(value, layout, _semantic=None):
 
 
 @builtin
-def warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps, worker_num_regs,
-                    _semantic=None, _generator=None):
+def warp_specialize(functions_and_args, worker_num_warps, worker_num_regs, _semantic=None, _generator=None):
     """
     Create a warp-specialized execution region, partitioning work across warps.
 
     Args:
-        default_args (List[Any]): Arguments for the default region.
-        default_partition (callable): Function to build the default execution region.
-        worker_args (List[Any]): Arguments for each warp partition.
-        worker_partitions (List[callable]): Functions for each warp partition.
+        functions_and_args (List[Tuple[Callable, Any]]): List of functions and arguments for each partition.
         worker_num_warps (List[int]): Number of warps per partition.
         worker_num_regs (List[int]): Number of registers per partition.
@@ -511,8 +507,7 @@ def warp_specialize(default_args, default_partition, worker_args, worker_partiti
     """
    worker_num_warps = [_unwrap_if_constexpr(w) for w in worker_num_warps]
     worker_num_regs = [_unwrap_if_constexpr(r) for r in worker_num_regs]
-    return _semantic.warp_specialize(default_args, default_partition, worker_args, worker_partitions, worker_num_warps,
-                                     worker_num_regs, _generator)
+    return _semantic.warp_specialize(functions_and_args, worker_num_warps, worker_num_regs, _generator)
 
 
 @builtin

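A usage note on the updated signature above. Reading the new docstring together with the call sites in this diff, the first (function, args) entry takes the place of the old default partition, and worker_num_warps / worker_num_regs keep one value per remaining partition, so a list of N entries pairs with lists of length N - 1. The test_get_num_warps change, for example, now reads:

ttgl.warp_specialize([
    (print_num_warps, ()),   # default partition
    (print_num_warps, ()),   # worker 0: 1 warp, 24 registers
    (print_num_warps, ()),   # worker 1: 2 warps, 24 registers
    (print_num_warps, ()),   # worker 2: 8 warps, 24 registers
], [1, 2, 8], [24, 24, 24])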