 import cutlass.utils.blockscaled_layout as blockscaled_utils
 from cutlass._mlir.dialects import math
 from cutlass.cute.nvgpu import cpasync, tcgen05
+from cutlass.cutlass_dsl import Int32

 from .custom_pipeline import PipelineCpAsyncUmma
 from .utils import (
 """


+# TODO: Remove this hook helper function after nvidia-cutlass-dsl 4.4 is released.
+def hooked_PersistentTileSchedulerParams_init(
+    self,
+    problem_shape_ntile_mnl: cute.Shape,
+    cluster_shape_mnk: cute.Shape,
+    swizzle_size: int = 1,
+    raster_along_m: bool = True,
+    *,
+    loc=None,
+    ip=None,
+):
+    if cluster_shape_mnk[2] != 1:
+        raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}")
+    if swizzle_size < 1:
+        raise ValueError(f"expected swizzle_size >= 1, but got {swizzle_size}")
+
+    self.problem_shape_ntile_mnl = problem_shape_ntile_mnl
+    # cluster_shape_mnk is kept for reconstruction
+    self._cluster_shape_mnk = cluster_shape_mnk
+    self.cluster_shape_mn = cluster_shape_mnk[:2]
+    self.swizzle_size = swizzle_size
+    self._raster_along_m = raster_along_m
+    self._loc = loc
+
+    # Apply swizzle if swizzle_size > 1
+    if swizzle_size > 1:
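+        # Illustrative visit order for the swizzled layout built below (assuming the
+        # scheduler decodes a linear work index through this layout's inverse): with
+        # swizzle_size = 2, raster_along_m = True and 3 clusters along m, consecutive work
+        # indices visit (m, n) = (0,0) (0,1) (1,0) (1,1) (2,0) (2,1) (0,2) (0,3) ...,
+        # i.e. the grid is walked along m within bands of swizzle_size n-columns, keeping
+        # neighbouring work units close together in the output tile space.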
+        problem_shape_ncluster_mnl = cute.round_up(
+            self.problem_layout_ncluster_mnl.shape,
+            (1, swizzle_size, 1) if raster_along_m else (swizzle_size, 1, 1),
+        )
+
+        if raster_along_m:
+            self.problem_layout_ncluster_mnl = cute.make_layout(
+                (
+                    problem_shape_ncluster_mnl[0],
+                    (swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size),
+                    problem_shape_ncluster_mnl[2],
+                ),
+                stride=(
+                    swizzle_size,
+                    (1, swizzle_size * problem_shape_ncluster_mnl[0]),
+                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
+                ),
+                loc=loc,
+                ip=ip,
+            )
+        else:
+            self.problem_layout_ncluster_mnl = cute.make_layout(
+                (
+                    (swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size),
+                    problem_shape_ncluster_mnl[1],
+                    problem_shape_ncluster_mnl[2],
+                ),
+                stride=(
+                    (1, swizzle_size * problem_shape_ncluster_mnl[1]),
+                    swizzle_size,
+                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
+                ),
+                loc=loc,
+                ip=ip,
+            )
+
+    # Create FastDivmod divisors (only when swizzle_size == 1 for correctness)
+    # FastDivmod assumes simple col-major/row-major layout, incompatible with swizzled layouts
+    if swizzle_size == 1:
+        problem_shape_ncluster_mnl = cute.ceil_div(
+            self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip
+        )
+        if raster_along_m:
+            self.problem_layout_ncluster_mnl = cute.make_layout(
+                problem_shape_ncluster_mnl,
+                stride=(
+                    1,
+                    problem_shape_ncluster_mnl[0],
+                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
+                ),
+                loc=loc,
+                ip=ip,
+            )
+        else:
+            self.problem_layout_ncluster_mnl = cute.make_layout(
+                problem_shape_ncluster_mnl,
+                stride=(
+                    problem_shape_ncluster_mnl[1],
+                    1,
+                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
+                ),
+                loc=loc,
+                ip=ip,
+            )
+        problem_layout_size = cute.size(self.problem_layout_ncluster_mnl, loc=loc, ip=ip)
+        cluster_count_m = self.problem_layout_ncluster_mnl.shape[0]
+        cluster_count_n = self.problem_layout_ncluster_mnl.shape[1]
+
+        # batch_fdd: Used to map linear_idx to work_unit_id (handles persistent scheduling)
+        self.batch_fdd = cute.fast_divmod_create_divisor(problem_layout_size, loc=loc, ip=ip)
+
+        # cluster_shape_m_fdd: Used to decode work_unit_id to cluster coordinates
+        self.cluster_shape_m_fdd = cute.fast_divmod_create_divisor(cluster_count_m, loc=loc, ip=ip)
+
+        # cluster_shape_n_fdd: Used for the second level decomposition
+        self.cluster_shape_n_fdd = cute.fast_divmod_create_divisor(cluster_count_n, loc=loc, ip=ip)
+    else:
+        # FastDivmod not applicable with swizzling, set to None
+        self.batch_fdd = None
+        self.cluster_shape_m_fdd = None
+        self.cluster_shape_n_fdd = None
+
+
+def hooked_get_cluster_work_idx_with_fastdivmod(
+    self, current_work_linear_idx: Int32, *, loc=None, ip=None
+) -> Tuple[Int32, Int32, Int32]:
+    work_iteration, work_unit_id = divmod(current_work_linear_idx, self.params.batch_fdd)
+
+    if self.params._raster_along_m:
+        # raster_along_m=True means column major (m is fastest)
+        # First, get cluster_m using cluster_shape_m_fdd
+        cluster_n_batch, cluster_m = divmod(work_unit_id, self.params.cluster_shape_m_fdd)
+
+        # Then decode cluster_n_batch to get cluster_n and batch_l using FastDivmod
+        batch_l, cluster_n = divmod(cluster_n_batch, self.params.cluster_shape_n_fdd)
+    else:
+        # raster_along_m=False means row major (n is fastest)
+        # First, get cluster_n using cluster_shape_n_fdd
+        cluster_m_batch, cluster_n = divmod(work_unit_id, self.params.cluster_shape_n_fdd)
+
+        # Then decode cluster_m_batch to get cluster_m and batch_l using FastDivmod
+        batch_l, cluster_m = divmod(cluster_m_batch, self.params.cluster_shape_m_fdd)
+
+    return (cluster_m, cluster_n, batch_l)
+
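+# Worked example of the FastDivmod decode above (illustrative numbers): for a cluster grid of
+# (cluster_m, cluster_n, l) = (4, 3, 2) with raster_along_m=True, problem_layout_size is 24,
+# so current_work_linear_idx = 30 decodes as
+#   divmod(30, 24) -> work_iteration = 1, work_unit_id = 6
+#   divmod(6, 4)   -> cluster_n_batch = 1, cluster_m = 2
+#   divmod(1, 3)   -> batch_l = 0, cluster_n = 1
+# i.e. the work unit maps to cluster coordinate (m, n, l) = (2, 1, 0).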
289+
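+# Install the hooks: override the upstream scheduler-params constructor and the FastDivmod
+# index decode so that the raster_along_m option is honoured (see the TODO above).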
+cutlass.utils.PersistentTileSchedulerParams.__init__ = hooked_PersistentTileSchedulerParams_init
+cutlass.utils.StaticPersistentTileScheduler._get_cluster_work_idx_with_fastdivmod = (
+    hooked_get_cluster_work_idx_with_fastdivmod
+)
+
+
 class BlockScaledContiguousGatherGroupedGemmKernel:
     """This class implements contiguous grouped matrix multiplication with gather operation and SwiGLU fusion
     for FC1 layer computation (C = up * silu(gate), where up/gate come from interleaved GEMM result).
@@ -245,6 +384,7 @@ def __init__(
         cluster_shape_mn: Tuple[int, int],
         vectorized_f32: bool,
         topk: cutlass.Int64,
+        raster_along_m: bool = False,
     ):
         """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with
         gather operation and SwiGLU fusion.
@@ -289,6 +429,7 @@ def __init__(
         self.cluster_shape_mn = cluster_shape_mn
         # K dimension is deferred in _setup_attributes
         self.mma_tiler = (*mma_tiler_mn, 1)
+        self.raster_along_m = raster_along_m

         self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE

@@ -743,7 +884,11 @@ def __call__(

         # Compute grid size
         self.tile_sched_params, grid = self._compute_grid(
-            c, self.cta_tile_shape_mnk_c, self.cluster_shape_mn, max_active_clusters
+            c,
+            self.cta_tile_shape_mnk_c,
+            self.cluster_shape_mn,
+            max_active_clusters,
+            self.raster_along_m,
         )

         self.buffer_align_bytes = 1024
@@ -1254,34 +1399,69 @@ def kernel(
                 pipeline.PipelineUserType.Producer, self.num_tile_stage
             )

-            while work_tile.is_valid_tile:
-                cur_tile_coord = work_tile.tile_idx
-                mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
-                if mma_tile_coord_m < num_non_exiting_tiles[0]:
-                    tile_info_pipeline.producer_acquire(tile_info_producer_state)
+            num_non_exiting_tiles_value = num_non_exiting_tiles[0]
+
+            if cutlass.const_expr(self.raster_along_m):
+                while work_tile.is_valid_tile:
                     cur_tile_coord = work_tile.tile_idx
-                    expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
-                    mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
-                    with cute.arch.elect_one():
-                        sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
-                        sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
-                        sInfo[(2, tile_info_producer_state.index)] = expert_idx
-                        sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
-                            work_tile.is_valid_tile
+                    mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
+                    if mma_tile_coord_m < num_non_exiting_tiles_value:
+                        tile_info_pipeline.producer_acquire(tile_info_producer_state)
+                        cur_tile_coord = work_tile.tile_idx
+                        expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
+                        mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
+                        with cute.arch.elect_one():
+                            sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
+                            sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
+                            sInfo[(2, tile_info_producer_state.index)] = expert_idx
+                            sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
+                                work_tile.is_valid_tile
+                            )
+                            sInfo[(4, tile_info_producer_state.index)] = mn_limit
+                        # fence view async shared
+                        cute.arch.fence_proxy(
+                            cute.arch.ProxyKind.async_shared,
+                            space=cute.arch.SharedSpace.shared_cta,
                         )
-                        sInfo[(4, tile_info_producer_state.index)] = mn_limit
-                    # fence view async shared
-                    cute.arch.fence_proxy(
-                        cute.arch.ProxyKind.async_shared,
-                        space=cute.arch.SharedSpace.shared_cta,
-                    )

-                    self.sched_sync_barrier.arrive_and_wait()
-                    tile_info_pipeline.producer_commit(tile_info_producer_state)
-                    tile_info_producer_state.advance()
+                        self.sched_sync_barrier.arrive_and_wait()
+                        tile_info_pipeline.producer_commit(tile_info_producer_state)
+                        tile_info_producer_state.advance()

-                tile_sched.advance_to_next_work()
-                work_tile = tile_sched.get_current_work()
+                    tile_sched.advance_to_next_work()
+                    work_tile = tile_sched.get_current_work()
+            else:
+                is_continue = cutlass.Boolean(1)
+                while work_tile.is_valid_tile and is_continue:
+                    cur_tile_coord = work_tile.tile_idx
+                    mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
+                    if mma_tile_coord_m < num_non_exiting_tiles_value:
+                        tile_info_pipeline.producer_acquire(tile_info_producer_state)
+                        cur_tile_coord = work_tile.tile_idx
+                        expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
+                        mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
+                        with cute.arch.elect_one():
+                            sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
+                            sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
+                            sInfo[(2, tile_info_producer_state.index)] = expert_idx
+                            sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
+                                work_tile.is_valid_tile
+                            )
+                            sInfo[(4, tile_info_producer_state.index)] = mn_limit
+                        # fence view async shared
+                        cute.arch.fence_proxy(
+                            cute.arch.ProxyKind.async_shared,
+                            space=cute.arch.SharedSpace.shared_cta,
+                        )
+
+                        self.sched_sync_barrier.arrive_and_wait()
+                        tile_info_pipeline.producer_commit(tile_info_producer_state)
+                        tile_info_producer_state.advance()
+                    else:
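+                        # Early exit: with raster_along_m=False the cluster layout above is
+                        # row major (n fastest), so the m tile coordinate of later work units
+                        # is presumably non-decreasing and no remaining tile can pass the
+                        # num_non_exiting_tiles limit (assumption behind stopping here).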
+                        is_continue = cutlass.Boolean(0)
+
+                    tile_sched.advance_to_next_work()
+                    work_tile = tile_sched.get_current_work()

             tile_info_pipeline.producer_acquire(tile_info_producer_state)
             with cute.arch.elect_one():
@@ -2781,6 +2961,7 @@ def _compute_grid(
         cta_tile_shape_mnk: Tuple[int, int, int],
         cluster_shape_mn: Tuple[int, int],
         max_active_clusters: cutlass.Constexpr,
+        raster_along_m: bool = False,
     ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
         """Use persistent tile scheduler to compute the grid size for the output tensor C.

@@ -2803,7 +2984,9 @@ def _compute_grid(
         num_ctas_mnl = gc[(0, (None, None, None))].shape
         cluster_shape_mnl = (*cluster_shape_mn, 1)

-        tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl)
+        tile_sched_params = utils.PersistentTileSchedulerParams(
+            num_ctas_mnl, cluster_shape_mnl, raster_along_m=raster_along_m
+        )
         grid = utils.StaticPersistentTileScheduler.get_grid_shape(
             tile_sched_params, max_active_clusters
         )
@@ -3209,3 +3392,33 @@ def wrapper(
         stream=stream,
         epilogue_op=epilogue_op,
     )
+
+
+@cute.jit
+def cvt_sf_MKL_to_M32x4xrm_K4xrk_L(
+    sf_ref_tensor: cute.Tensor,
+    sf_mma_tensor: cute.Tensor,
+):
+    """Convert a scale factor tensor from MKL layout to the MMA-specification M(32x4xrest_m)xK(4xrest_k)xL layout."""
+    # sf_mma_tensor has flattened shape (32, 4, rest_m, 4, rest_k, l);
+    # group it to ((32, 4, rest_m), (4, rest_k), l)
+    sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3)
+    sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3)
+    for i in cutlass.range(cute.size(sf_ref_tensor)):
+        mkl_coord = sf_ref_tensor.layout.get_hier_coord(i)
+        sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord]
+
+
+@cute.jit
+def cvt_sf_M32x4xrm_K4xrk_L_to_MKL(
+    sf_swizzled_tensor: cute.Tensor,
+    sf_unswizzled_tensor: cute.Tensor,
+):
+    """Convert a scale factor tensor from the MMA-specification M(32x4xrest_m)xK(4xrest_k)xL layout to MKL layout."""
+    # sf_swizzled_tensor has flattened shape (32, 4, rest_m, 4, rest_k, l);
+    # group it to ((32, 4, rest_m), (4, rest_k), l)
+    sf_swizzled_tensor = cute.group_modes(sf_swizzled_tensor, 0, 3)
+    sf_swizzled_tensor = cute.group_modes(sf_swizzled_tensor, 1, 3)
+    for i in cutlass.range(cute.size(sf_unswizzled_tensor)):
+        mkl_coord = sf_unswizzled_tensor.layout.get_hier_coord(i)
+        sf_unswizzled_tensor[mkl_coord] = sf_swizzled_tensor[mkl_coord]
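+
+
+# Shape relationship handled by the two converters above (illustrative sizes): an MKL
+# scale-factor tensor of shape (M, K, L) = (128, 8, 1) pairs with an MMA-layout tensor of
+# flattened shape (32, 4, rest_m, 4, rest_k, L) = (32, 4, 1, 4, 2, 1), since M = 32*4*rest_m
+# and K = 4*rest_k; after cute.group_modes both sides can be indexed with the same MKL
+# hierarchical coordinate inside the copy loops.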