@@ -742,8 +742,10 @@ def ws_kernel(output, FAILURE: ttgl.constexpr):
742742 bar = ttgl .allocate_shared_memory (ttgl .int64 , [2 , 1 ], mbarrier .MBarrierLayout ())
743743 for i in range (2 ):
744744 mbarrier .init (bar .index (i ), count = 1 )
745- ttgl .warp_specialize ((smem , bar , FAILURE , blocked_layout ), ws_default , (smem , bar , FAILURE , blocked_layout ),
746- [ws_1 ], [4 ], [32 ])
745+ ttgl .warp_specialize ([
746+ (ws_default , (smem , bar , FAILURE , blocked_layout )),
747+ (ws_1 , (smem , bar , FAILURE , blocked_layout )),
748+ ], [4 ], [32 ])
747749 mbarrier .wait (bar .index (1 ), phase = 0 )
748750 val = smem .index (0 ).load (blocked_layout )
749751 output_ptrs = output + ttgl .arange (0 , XBLOCK , blocked_layout )
@@ -796,8 +798,10 @@ def ws_kernel(output, FAILURE: ttgl.constexpr):
796798 bar = ttgl .allocate_shared_memory (ttgl .int64 , [2 , 1 ], mbarrier .MBarrierLayout ())
797799 for i in range (2 ):
798800 mbarrier .init (bar .index (i ), count = 1 )
799- ttgl .warp_specialize ((smem , bar , FAILURE , blocked_layout ), ws_default , (smem , bar , FAILURE , blocked_layout ),
800- [ws_1 ], [4 ], [32 ])
801+ ttgl .warp_specialize ([
802+ (ws_default , (smem , bar , FAILURE , blocked_layout )),
803+ (ws_1 , (smem , bar , FAILURE , blocked_layout )),
804+ ], [4 ], [32 ])
801805 mbarrier .wait (bar .index (1 ), phase = 0 )
802806 val = smem .index (0 ).load (blocked_layout )
803807 output_ptrs = output + ttgl .arange (0 , XBLOCK , blocked_layout )
@@ -859,8 +863,11 @@ def kernel(output, MISSING_BAR: ttgl.constexpr):
859863 bar = ttgl .allocate_shared_memory (ttgl .int64 , [3 , 1 ], mbarrier .MBarrierLayout ())
860864 for i in range (3 ):
861865 mbarrier .init (bar .index (i ), count = 1 )
862- ttgl .warp_specialize ((smem , bar , MISSING_BAR , blocked_layout ), ws_default ,
863- (smem , bar , MISSING_BAR , blocked_layout ), [ws_1 , ws_2 ], [4 , 4 ], [32 , 32 ])
866+ ttgl .warp_specialize ([
867+ (ws_default , (smem , bar , MISSING_BAR , blocked_layout )),
868+ (ws_1 , (smem , bar , MISSING_BAR , blocked_layout )),
869+ (ws_2 , (smem , bar , MISSING_BAR , blocked_layout )),
870+ ], [4 , 4 ], [32 , 32 ])
864871 mbarrier .wait (bar .index (2 ), phase = 0 )
865872 val = smem .index (0 ).load (blocked_layout )
866873 output_ptrs = output + ttgl .arange (0 , XBLOCK , blocked_layout )
@@ -919,8 +926,11 @@ def kernel(output, FAILURE: ttgl.constexpr):
919926 bar = ttgl .allocate_shared_memory (ttgl .int64 , [2 , 1 ], mbarrier .MBarrierLayout ())
920927 mbarrier .init (bar .index (0 ), count = 2 )
921928 mbarrier .init (bar .index (1 ), count = 1 )
922- ttgl .warp_specialize ((smem , bar , FAILURE , blocked_layout ), ws_default , (smem , bar , FAILURE , blocked_layout ),
923- [ws_1 , ws_2 ], [4 , 4 ], [32 , 32 ])
929+ ttgl .warp_specialize ([
930+ (ws_default , (smem , bar , FAILURE , blocked_layout )),
931+ (ws_1 , (smem , bar , FAILURE , blocked_layout )),
932+ (ws_2 , (smem , bar , FAILURE , blocked_layout )),
933+ ], [4 , 4 ], [32 , 32 ])
924934 mbarrier .wait (bar .index (1 ), phase = 0 )
925935 val = smem .index (0 ).load (blocked_layout )
926936 output_ptrs = output + ttgl .arange (0 , XBLOCK , blocked_layout )
@@ -1007,8 +1017,11 @@ def kernel(output, MISSING_BAR: ttgl.constexpr):
10071017 mbarrier .arrive (bar .index (2 ), count = 1 )
10081018 mbarrier .arrive (bar .index (3 ), count = 1 )
10091019
1010- ttgl .warp_specialize ((smem , bar , MISSING_BAR , blocked_layout ), ws_default ,
1011- (smem , bar , MISSING_BAR , blocked_layout ), [ws_1 , ws_2 ], [4 , 4 ], [32 , 32 ])
1020+ ttgl .warp_specialize ([
1021+ (ws_default , (smem , bar , MISSING_BAR , blocked_layout )),
1022+ (ws_1 , (smem , bar , MISSING_BAR , blocked_layout )),
1023+ (ws_2 , (smem , bar , MISSING_BAR , blocked_layout )),
1024+ ], [4 , 4 ], [32 , 32 ])
10121025
10131026 output = torch .empty ((XBLOCK , ), device = device , dtype = torch .float16 )
10141027 kernel [(1 , )](output , MISSING_BAR = MISSING_BAR , num_warps = 4 )
@@ -1072,8 +1085,10 @@ def kernel(output, FAILURE: ttgl.constexpr):
10721085
10731086 mbarrier .arrive (bar .index (2 ), count = 1 )
10741087
1075- ttgl .warp_specialize ((smem , bar , FAILURE , blocked_layout ), ws_default , (smem , bar , FAILURE , blocked_layout ),
1076- [ws_1 ], [4 ], [32 ])
1088+ ttgl .warp_specialize ([
1089+ (ws_default , (smem , bar , FAILURE , blocked_layout )),
1090+ (ws_1 , (smem , bar , FAILURE , blocked_layout )),
1091+ ], [4 ], [32 ])
10771092
10781093 output = torch .empty ((XBLOCK , ), device = device , dtype = torch .float16 )
10791094 kernel [(1 , )](output , FAILURE = FAILURE , num_warps = 4 )
@@ -1160,8 +1175,12 @@ def kernel(output, MISSING_BAR: ttgl.constexpr):
11601175 mbarrier .arrive (bar .index (2 ), count = 2 )
11611176 mbarrier .arrive (bar .index (3 ), count = 2 )
11621177
1163- ttgl .warp_specialize ((smem , bar , MISSING_BAR , blocked_layout ), ws_default ,
1164- (smem , bar , MISSING_BAR , blocked_layout ), [ws_1 , ws_2 , ws_3 ], [4 , 4 , 4 ], [32 , 32 , 32 ])
1178+ ttgl .warp_specialize ([
1179+ (ws_default , (smem , bar , MISSING_BAR , blocked_layout )),
1180+ (ws_1 , (smem , bar , MISSING_BAR , blocked_layout )),
1181+ (ws_2 , (smem , bar , MISSING_BAR , blocked_layout )),
1182+ (ws_3 , (smem , bar , MISSING_BAR , blocked_layout )),
1183+ ], [4 , 4 , 4 ], [32 , 32 , 32 ])
11651184
11661185 output = torch .empty ((XBLOCK , ), device = device , dtype = torch .float16 )
11671186 kernel [(1 , )](output , MISSING_BAR = MISSING_BAR , num_warps = 4 )
@@ -1225,8 +1244,11 @@ def kernel(output, MISSING_BAR: ttgl.constexpr):
12251244 bar = ttgl .allocate_shared_memory (ttgl .int64 , [3 , 1 ], mbarrier .MBarrierLayout ())
12261245 for i in range (3 ):
12271246 mbarrier .init (bar .index (i ), count = 1 )
1228- ttgl .warp_specialize ((smem , bar , MISSING_BAR ), ws_default , (smem , bar , MISSING_BAR ), [ws_1 , ws_2 ], [2 , 8 ],
1229- [32 , 32 ])
1247+ ttgl .warp_specialize ([
1248+ (ws_default , (smem , bar , MISSING_BAR )),
1249+ (ws_1 , (smem , bar , MISSING_BAR )),
1250+ (ws_2 , (smem , bar , MISSING_BAR )),
1251+ ], [2 , 8 ], [32 , 32 ])
12301252 mbarrier .wait (bar .index (2 ), phase = 0 )
12311253 val = smem .index (0 ).load (blocked_layout )
12321254 output_ptrs = output + ttgl .arange (0 , XBLOCK , blocked_layout )
@@ -1291,8 +1313,10 @@ def kernel(input, FAILURE: ttgl.constexpr):
12911313 smem = ttgl .allocate_shared_memory (ttgl .float16 , [4 , XBLOCK ], smem_layout )
12921314 blocked_layout : ttgl .constexpr = ttgl .BlockedLayout (size_per_thread = [XBLOCK ], threads_per_warp = [32 ],
12931315 warps_per_cta = [4 ], order = [0 ])
1294- ttgl .warp_specialize ((input , smem , FAILURE , blocked_layout , 0 ), ws_prog ,
1295- (input , smem , FAILURE , blocked_layout , 2 ), [ws_prog ], [4 ], [32 ])
1316+ ttgl .warp_specialize ([
1317+ (ws_prog , (input , smem , FAILURE , blocked_layout , 0 )),
1318+ (ws_prog , (input , smem , FAILURE , blocked_layout , 2 )),
1319+ ], [4 ], [32 ])
12961320
12971321 input = torch .randn ((XBLOCK , ), device = device , dtype = torch .float16 )
12981322 kernel [(1 , )](input , FAILURE = FAILURE , num_warps = 4 )
@@ -1346,8 +1370,10 @@ def kernel(input, FAILURE: ttgl.constexpr):
13461370 smem = ttgl .allocate_shared_memory (ttgl .float16 , [2 , XBLOCK ], smem_layout )
13471371 bar = ttgl .allocate_shared_memory (ttgl .int64 , [1 , 1 ], mbarrier .MBarrierLayout ())
13481372 mbarrier .init (bar .index (0 ), count = 1 )
1349- ttgl .warp_specialize ((input , smem , bar , FAILURE , blocked_layout ), ws_default ,
1350- (input , smem , bar , FAILURE , blocked_layout ), [ws_1 ], [4 ], [32 ])
1373+ ttgl .warp_specialize ([
1374+ (ws_default , (input , smem , bar , FAILURE , blocked_layout )),
1375+ (ws_1 , (input , smem , bar , FAILURE , blocked_layout )),
1376+ ], [4 ], [32 ])
13511377
13521378 input = torch .randn ((XBLOCK , ), device = device , dtype = torch .float16 )
13531379 kernel [(1 , )](input , FAILURE = FAILURE , num_warps = 4 )
@@ -1402,8 +1428,10 @@ def kernel(FAILURE: ttgl.constexpr):
14021428 smem = ttgl .allocate_shared_memory (ttgl .float16 , [2 , XBLOCK , XBLOCK ], smem_layout )
14031429 bar = ttgl .allocate_shared_memory (ttgl .int64 , [1 , 1 ], mbarrier .MBarrierLayout ())
14041430 mbarrier .init (bar .index (0 ), count = 1 )
1405- ttgl .warp_specialize ((smem , bar , FAILURE , blocked_layout , mma_layout ), ws_default ,
1406- (smem , bar , FAILURE , blocked_layout ), [ws_1 ], [4 ], [32 ])
1431+ ttgl .warp_specialize ([
1432+ (ws_default , (smem , bar , FAILURE , blocked_layout , mma_layout )),
1433+ (ws_1 , (smem , bar , FAILURE , blocked_layout )),
1434+ ], [4 ], [32 ])
14071435
14081436 kernel [(1 , )](FAILURE = FAILURE , num_warps = 4 )
14091437
@@ -1438,7 +1466,10 @@ def kernel():
14381466 bar = ttgl .allocate_shared_memory (ttgl .int64 , [2 , 1 ], mbarrier .MBarrierLayout ())
14391467 mbarrier .init (bar .index (0 ), count = 1 )
14401468 mbarrier .init (bar .index (1 ), count = 1 )
1441- ttgl .warp_specialize ((bar , ), ws_default , (bar , ), [ws_1 ], [4 ], [32 ])
1469+ ttgl .warp_specialize ([
1470+ (ws_default , (bar , )),
1471+ (ws_1 , (bar , )),
1472+ ], [4 ], [32 ])
14421473
14431474 kernel [(1 , )](num_warps = 4 )
14441475
@@ -1505,7 +1536,10 @@ def kernel():
15051536 bar = ttgl .allocate_shared_memory (ttgl .int64 , [2 , 1 ], mbarrier .MBarrierLayout ())
15061537 mbarrier .init (bar .index (0 ), count = 2 )
15071538 mbarrier .init (bar .index (1 ), count = 2 )
1508- ttgl .warp_specialize ((bar , ), ws_default , (bar , ), [ws_1 ], [4 ], [32 ])
1539+ ttgl .warp_specialize ([
1540+ (ws_default , (bar , )),
1541+ (ws_1 , (bar , )),
1542+ ], [4 ], [32 ])
15091543
15101544 kernel [(1 , )](num_warps = 4 )
15111545
@@ -1541,7 +1575,10 @@ def kernel():
15411575 bar = ttgl .allocate_shared_memory (ttgl .int64 , [1 , 1 ], mbarrier .MBarrierLayout ())
15421576 mbarrier .init (bar .index (0 ), count = 1 )
15431577 mbarrier .arrive (bar .index (0 ), count = 1 )
1544- ttgl .warp_specialize ((bar , ), ws_default , (bar , ), [ws_1 ], [4 ], [32 ])
1578+ ttgl .warp_specialize ([
1579+ (ws_default , (bar , )),
1580+ (ws_1 , (bar , )),
1581+ ], [4 ], [32 ])
15451582
15461583 kernel [(1 , )](num_warps = 4 )
15471584
@@ -1582,7 +1619,10 @@ def kernel(input_desc):
15821619 bar = ttgl .allocate_shared_memory (ttgl .int64 , [2 , 1 ], mbarrier .MBarrierLayout ())
15831620 mbarrier .init (bar .index (0 ), count = 1 )
15841621 mbarrier .init (bar .index (1 ), count = 1 )
1585- ttgl .warp_specialize ((input_desc , smem , bar ), ws_default , (input_desc , smem , bar ), [ws_1 ], [4 ], [32 ])
1622+ ttgl .warp_specialize ([
1623+ (ws_default , (input_desc , smem , bar )),
1624+ (ws_1 , (input_desc , smem , bar )),
1625+ ], [4 ], [32 ])
15861626
15871627 input = torch .randn ((XBLOCK , XBLOCK ), device = device , dtype = torch .float16 )
15881628 shared_layout = ttgl .NVMMASharedLayout (swizzle_byte_width = 128 , element_bitwidth = 16 , rank = 2 )
@@ -1621,6 +1661,9 @@ def kernel():
16211661 bar = ttgl .allocate_shared_memory (ttgl .int64 , [2 , 1 ], mbarrier .MBarrierLayout ())
16221662 mbarrier .init (bar .index (0 ), count = 1 )
16231663 mbarrier .init (bar .index (1 ), count = 1 )
1624- ttgl .warp_specialize ((bar , ), ws_default , (bar , ), [ws_1 ], [4 ], [32 ])
1664+ ttgl .warp_specialize ([
1665+ (ws_default , (bar , )),
1666+ (ws_1 , (bar , )),
1667+ ], [4 ], [32 ])
16251668
16261669 kernel [(1 , )](num_warps = 4 )
0 commit comments