
Commit 5a836a4

zongfeijing and videodanchik authored and committed
[None] [feat] Add test script and raster M for gather fc1 kernel (NVIDIA#10429)
Signed-off-by: Zongfei Jing <[email protected]>
Signed-off-by: Daniil Kulko <[email protected]>
1 parent ec75bce commit 5a836a4

File tree

3 files changed: +1555 -31 lines changed

tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py

Lines changed: 11 additions & 5 deletions
@@ -1864,10 +1864,13 @@ def get_valid_tactics(
         mma_tiler_mn_candidates = [(self.tile_size, 128),
                                    (self.tile_size, 256)]
         cluster_shape_mn_candidates = [(self.tile_size // 128, 1)]
+        # TODO: Add raster_along_m=True if we find it more performant in some cases.
+        raster_along_m_candidates = [False]

         valid_tactics = []
-        for mma_tiler_mn, cluster_shape_mn in itertools.product(
-                mma_tiler_mn_candidates, cluster_shape_mn_candidates):
+        for mma_tiler_mn, cluster_shape_mn, raster_along_m in itertools.product(
+                mma_tiler_mn_candidates, cluster_shape_mn_candidates,
+                raster_along_m_candidates):
             if self.__class__.kernel_class.can_implement(
                     ab_dtype=cutlass.Float4E2M1FN,
                     sf_dtype=cutlass.Float8E4M3FN,
@@ -1883,7 +1886,8 @@ def get_valid_tactics(
                     b_major="k",
                     c_major="n",
             ):
-                valid_tactics.append((mma_tiler_mn, cluster_shape_mn))
+                valid_tactics.append(
+                    (mma_tiler_mn, cluster_shape_mn, raster_along_m))

         return valid_tactics

@@ -2013,22 +2017,24 @@ def forward(self, inputs: List[torch.Tensor],
         stream = cuda.CUstream(torch_stream.cuda_stream)

         if isinstance(tactic, tuple):
-            mma_tiler_mn, cluster_shape_mn = tactic
+            mma_tiler_mn, cluster_shape_mn, raster_along_m = tactic
         else:
             mma_tiler_mn = (self.tile_size, 128)
             cluster_shape_mn = (self.tile_size // 128, 1)
+            raster_along_m = False
         assert mma_tiler_mn[
             0] == self.tile_size, f"Tactic ({tactic}) is incompatible with tile size ({self.tile_size})"

         cache_key = (self.scaling_vector_size, self.tile_size, self.top_k,
-                     mma_tiler_mn, cluster_shape_mn)
+                     mma_tiler_mn, cluster_shape_mn, raster_along_m)
         if cache_key not in self.__class__.kernel_cache:
             gemm = self.__class__.kernel_class(
                 sf_vec_size=self.scaling_vector_size,
                 mma_tiler_mn=mma_tiler_mn,
                 cluster_shape_mn=cluster_shape_mn,
                 vectorized_f32=True,
                 topk=self.top_k,
+                raster_along_m=raster_along_m,
             )
             # Compute max active clusters on current device
             hardware_info = cutlass.utils.HardwareInfo()
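
Editorial note: the tactic tuple grows from (mma_tiler_mn, cluster_shape_mn) to (mma_tiler_mn, cluster_shape_mn, raster_along_m), and the same 3-tuple now feeds the kernel cache key, so kernels compiled with different raster directions cannot collide in kernel_cache. A minimal plain-Python sketch of the enumeration; tile_size = 128 and the resulting list are illustrative, not taken from the commit:

import itertools

tile_size = 128  # hypothetical; the op derives this from its config
mma_tiler_mn_candidates = [(tile_size, 128), (tile_size, 256)]
cluster_shape_mn_candidates = [(tile_size // 128, 1)]
raster_along_m_candidates = [False]  # only False until raster-along-M proves faster

tactics = list(
    itertools.product(mma_tiler_mn_candidates, cluster_shape_mn_candidates,
                      raster_along_m_candidates))
# -> [((128, 128), (1, 1), False), ((128, 256), (1, 1), False)]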

tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockscaled_contiguous_gather_grouped_gemm_swiglu_fusion.py

Lines changed: 239 additions & 26 deletions
@@ -37,6 +37,7 @@
 import cutlass.utils.blockscaled_layout as blockscaled_utils
 from cutlass._mlir.dialects import math
 from cutlass.cute.nvgpu import cpasync, tcgen05
+from cutlass.cutlass_dsl import Int32

 from .custom_pipeline import PipelineCpAsyncUmma
 from .utils import (
@@ -154,6 +155,144 @@
 """


+# TODO: Remove this hook helper function after nvidia-cutlass-dsl 4.4 is released.
+def hooked_PersistentTileSchedulerParams_init(
+    self,
+    problem_shape_ntile_mnl: cute.Shape,
+    cluster_shape_mnk: cute.Shape,
+    swizzle_size: int = 1,
+    raster_along_m: bool = True,
+    *,
+    loc=None,
+    ip=None,
+):
+    if cluster_shape_mnk[2] != 1:
+        raise ValueError(f"unsupported cluster_shape_k {cluster_shape_mnk[2]}")
+    if swizzle_size < 1:
+        raise ValueError(f"expect swizzle_size >= 1, but get {swizzle_size}")
+
+    self.problem_shape_ntile_mnl = problem_shape_ntile_mnl
+    # cluster_shape_mnk is kept for reconstruction
+    self._cluster_shape_mnk = cluster_shape_mnk
+    self.cluster_shape_mn = cluster_shape_mnk[:2]
+    self.swizzle_size = swizzle_size
+    self._raster_along_m = raster_along_m
+    self._loc = loc
+
+    # Apply swizzle if swizzle_size > 1
+    if swizzle_size > 1:
+        problem_shape_ncluster_mnl = cute.round_up(
+            self.problem_layout_ncluster_mnl.shape,
+            (1, swizzle_size, 1) if raster_along_m else (swizzle_size, 1, 1),
+        )
+
+        if raster_along_m:
+            self.problem_layout_ncluster_mnl = cute.make_layout(
+                (
+                    problem_shape_ncluster_mnl[0],
+                    (swizzle_size, problem_shape_ncluster_mnl[1] // swizzle_size),
+                    problem_shape_ncluster_mnl[2],
+                ),
+                stride=(
+                    swizzle_size,
+                    (1, swizzle_size * problem_shape_ncluster_mnl[0]),
+                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
+                ),
+                loc=loc,
+                ip=ip,
+            )
+        else:
+            self.problem_layout_ncluster_mnl = cute.make_layout(
+                (
+                    (swizzle_size, problem_shape_ncluster_mnl[0] // swizzle_size),
+                    problem_shape_ncluster_mnl[1],
+                    problem_shape_ncluster_mnl[2],
+                ),
+                stride=(
+                    (1, swizzle_size * problem_shape_ncluster_mnl[1]),
+                    swizzle_size,
+                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
+                ),
+                loc=loc,
+                ip=ip,
+            )
+
+    # Create FastDivmod divisors (only when swizzle_size == 1 for correctness)
+    # FastDivmod assumes simple col-major/row-major layout, incompatible with swizzled layouts
+    if swizzle_size == 1:
+        problem_shape_ncluster_mnl = cute.ceil_div(
+            self.problem_shape_ntile_mnl, cluster_shape_mnk[:2], loc=loc, ip=ip
+        )
+        if raster_along_m:
+            self.problem_layout_ncluster_mnl = cute.make_layout(
+                problem_shape_ncluster_mnl,
+                stride=(
+                    1,
+                    problem_shape_ncluster_mnl[0],
+                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
+                ),
+                loc=loc,
+                ip=ip,
+            )
+        else:
+            self.problem_layout_ncluster_mnl = cute.make_layout(
+                problem_shape_ncluster_mnl,
+                stride=(
+                    problem_shape_ncluster_mnl[1],
+                    1,
+                    problem_shape_ncluster_mnl[0] * problem_shape_ncluster_mnl[1],
+                ),
+                loc=loc,
+                ip=ip,
+            )
+        problem_layout_size = cute.size(self.problem_layout_ncluster_mnl, loc=loc, ip=ip)
+        cluster_count_m = self.problem_layout_ncluster_mnl.shape[0]
+        cluster_count_n = self.problem_layout_ncluster_mnl.shape[1]
+
+        # batch_fdd: Used to map linear_idx to work_unit_id (handles persistent scheduling)
+        self.batch_fdd = cute.fast_divmod_create_divisor(problem_layout_size, loc=loc, ip=ip)
+
+        # cluster_shape_m_fdd: Used to decode work_unit_id to cluster coordinates
+        self.cluster_shape_m_fdd = cute.fast_divmod_create_divisor(cluster_count_m, loc=loc, ip=ip)
+
+        # cluster_shape_n_fdd: Used for the second level decomposition
+        self.cluster_shape_n_fdd = cute.fast_divmod_create_divisor(cluster_count_n, loc=loc, ip=ip)
+    else:
+        # FastDivmod not applicable with swizzling, set to None
+        self.batch_fdd = None
+        self.cluster_shape_m_fdd = None
+        self.cluster_shape_n_fdd = None
+
+
+def hooked_get_cluster_work_idx_with_fastdivmod(
+    self, current_work_linear_idx: Int32, *, loc=None, ip=None
+) -> Tuple[Int32, Int32, Int32]:
+    work_iteration, work_unit_id = divmod(current_work_linear_idx, self.params.batch_fdd)
+
+    if self.params._raster_along_m:
+        # raster_along_m=True means column major (m is fastest)
+        # First, get cluster_m using cluster_shape_m_fdd
+        cluster_n_batch, cluster_m = divmod(work_unit_id, self.params.cluster_shape_m_fdd)
+
+        # Then decode cluster_n_batch to get cluster_n and batch_l using FastDivmod
+        batch_l, cluster_n = divmod(cluster_n_batch, self.params.cluster_shape_n_fdd)
+    else:
+        # raster_along_m=False means row major (n is fastest)
+        # First, get cluster_n using cluster_shape_n_fdd
+        cluster_m_batch, cluster_n = divmod(work_unit_id, self.params.cluster_shape_n_fdd)
+
+        # Then decode cluster_m_batch to get cluster_m and batch_l using FastDivmod
+        batch_l, cluster_m = divmod(cluster_m_batch, self.params.cluster_shape_m_fdd)
+
+    return (cluster_m, cluster_n, batch_l)
+
+
+cutlass.utils.PersistentTileSchedulerParams.__init__ = hooked_PersistentTileSchedulerParams_init
+cutlass.utils.StaticPersistentTileScheduler._get_cluster_work_idx_with_fastdivmod = (
+    hooked_get_cluster_work_idx_with_fastdivmod
+)
+
+
 class BlockScaledContiguousGatherGroupedGemmKernel:
     """This class implements contiguous grouped matrix multiplication with gather operation and SwiGLU fusion
     for FC1 layer computation (C = up * silu(gate), where up/gate come from interleaved GEMM result).
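
Editorial note: the hooked scheduler decodes a linear work index into (cluster_m, cluster_n, batch_l) in one of two raster orders. A minimal sketch using plain integer divmod in place of the FastDivmod divisor objects (hypothetical helper, assuming swizzle_size == 1):

def decode_work_unit(work_unit_id, cluster_count_m, cluster_count_n, raster_along_m):
    """Mirrors the hooked decode above with plain integer divmod."""
    if raster_along_m:
        # column major: m is the fastest-varying coordinate
        cluster_n_batch, cluster_m = divmod(work_unit_id, cluster_count_m)
        batch_l, cluster_n = divmod(cluster_n_batch, cluster_count_n)
    else:
        # row major: n is the fastest-varying coordinate
        cluster_m_batch, cluster_n = divmod(work_unit_id, cluster_count_n)
        batch_l, cluster_m = divmod(cluster_m_batch, cluster_count_m)
    return cluster_m, cluster_n, batch_l

# For a 2 (M) x 3 (N) cluster grid with one batch:
# raster_along_m=True  visits (0,0), (1,0), (0,1), (1,1), (0,2), (1,2)
# raster_along_m=False visits (0,0), (0,1), (0,2), (1,0), (1,1), (1,2)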
@@ -245,6 +384,7 @@ def __init__(
         cluster_shape_mn: Tuple[int, int],
         vectorized_f32: bool,
         topk: cutlass.Int64,
+        raster_along_m: bool = False,
     ):
         """Initializes the configuration for a Blackwell blockscaled dense GEMM kernel with
         gather operation and SwiGLU fusion.
@@ -289,6 +429,7 @@ def __init__(
         self.cluster_shape_mn = cluster_shape_mn
         # K dimension is deferred in _setup_attributes
         self.mma_tiler = (*mma_tiler_mn, 1)
+        self.raster_along_m = raster_along_m

         self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE

@@ -743,7 +884,11 @@ def __call__(

         # Compute grid size
         self.tile_sched_params, grid = self._compute_grid(
-            c, self.cta_tile_shape_mnk_c, self.cluster_shape_mn, max_active_clusters
+            c,
+            self.cta_tile_shape_mnk_c,
+            self.cluster_shape_mn,
+            max_active_clusters,
+            self.raster_along_m,
         )

         self.buffer_align_bytes = 1024
@@ -1254,34 +1399,69 @@ def kernel(
             pipeline.PipelineUserType.Producer, self.num_tile_stage
         )

-        while work_tile.is_valid_tile:
-            cur_tile_coord = work_tile.tile_idx
-            mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
-            if mma_tile_coord_m < num_non_exiting_tiles[0]:
-                tile_info_pipeline.producer_acquire(tile_info_producer_state)
+        num_non_exiting_tiles_value = num_non_exiting_tiles[0]
+
+        if cutlass.const_expr(self.raster_along_m):
+            while work_tile.is_valid_tile:
                 cur_tile_coord = work_tile.tile_idx
-                expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
-                mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
-                with cute.arch.elect_one():
-                    sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
-                    sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
-                    sInfo[(2, tile_info_producer_state.index)] = expert_idx
-                    sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
-                        work_tile.is_valid_tile
+                mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
+                if mma_tile_coord_m < num_non_exiting_tiles_value:
+                    tile_info_pipeline.producer_acquire(tile_info_producer_state)
+                    cur_tile_coord = work_tile.tile_idx
+                    expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
+                    mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
+                    with cute.arch.elect_one():
+                        sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
+                        sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
+                        sInfo[(2, tile_info_producer_state.index)] = expert_idx
+                        sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
+                            work_tile.is_valid_tile
+                        )
+                        sInfo[(4, tile_info_producer_state.index)] = mn_limit
+                    # fence view async shared
+                    cute.arch.fence_proxy(
+                        cute.arch.ProxyKind.async_shared,
+                        space=cute.arch.SharedSpace.shared_cta,
                     )
-                    sInfo[(4, tile_info_producer_state.index)] = mn_limit
-                # fence view async shared
-                cute.arch.fence_proxy(
-                    cute.arch.ProxyKind.async_shared,
-                    space=cute.arch.SharedSpace.shared_cta,
-                )

-                self.sched_sync_barrier.arrive_and_wait()
-                tile_info_pipeline.producer_commit(tile_info_producer_state)
-                tile_info_producer_state.advance()
+                    self.sched_sync_barrier.arrive_and_wait()
+                    tile_info_pipeline.producer_commit(tile_info_producer_state)
+                    tile_info_producer_state.advance()

-            tile_sched.advance_to_next_work()
-            work_tile = tile_sched.get_current_work()
+                tile_sched.advance_to_next_work()
+                work_tile = tile_sched.get_current_work()
+        else:
+            is_continue = cutlass.Boolean(1)
+            while work_tile.is_valid_tile and is_continue:
+                cur_tile_coord = work_tile.tile_idx
+                mma_tile_coord_m = cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape)
+                if mma_tile_coord_m < num_non_exiting_tiles_value:
+                    tile_info_pipeline.producer_acquire(tile_info_producer_state)
+                    cur_tile_coord = work_tile.tile_idx
+                    expert_idx = tile_idx_to_expert_idx[mma_tile_coord_m]
+                    mn_limit = tile_idx_to_mn_limit[mma_tile_coord_m]
+                    with cute.arch.elect_one():
+                        sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0]
+                        sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1]
+                        sInfo[(2, tile_info_producer_state.index)] = expert_idx
+                        sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32(
+                            work_tile.is_valid_tile
+                        )
+                        sInfo[(4, tile_info_producer_state.index)] = mn_limit
+                    # fence view async shared
+                    cute.arch.fence_proxy(
+                        cute.arch.ProxyKind.async_shared,
+                        space=cute.arch.SharedSpace.shared_cta,
+                    )
+
+                    self.sched_sync_barrier.arrive_and_wait()
+                    tile_info_pipeline.producer_commit(tile_info_producer_state)
+                    tile_info_producer_state.advance()
+                else:
+                    is_continue = cutlass.Boolean(0)
+
+                tile_sched.advance_to_next_work()
+                work_tile = tile_sched.get_current_work()

         tile_info_pipeline.producer_acquire(tile_info_producer_state)
         with cute.arch.elect_one():
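
Editorial note: the loop is duplicated because the exit condition is only monotone for one raster order. With raster_along_m=False the scheduler hands out tiles N-fastest, so a persistent producer sees nondecreasing M tile coordinates and may stop at the first tile whose M coordinate reaches num_non_exiting_tiles (the is_continue flag); with raster_along_m=True, M varies fastest and later work items can still be in range, so every tile must be examined. A simplified plain-Python model of that difference (hypothetical helper, not from the commit):

def producer_visits(tile_m_coords, num_non_exiting_tiles, raster_along_m):
    """tile_m_coords: M tile coords in the order the scheduler hands them out."""
    visited = []
    for m in tile_m_coords:
        if m < num_non_exiting_tiles:
            visited.append(m)
        elif not raster_along_m:
            break  # N-fastest raster: M only grows from here on
    return visited

# producer_visits([0, 0, 1, 1, 2, 2], 2, raster_along_m=False) -> [0, 0, 1, 1], then stop
# producer_visits([0, 1, 2, 0, 1, 2], 2, raster_along_m=True)  -> [0, 1, 0, 1], full scan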
@@ -2781,6 +2961,7 @@ def _compute_grid(
         cta_tile_shape_mnk: Tuple[int, int, int],
         cluster_shape_mn: Tuple[int, int],
         max_active_clusters: cutlass.Constexpr,
+        raster_along_m: bool = False,
     ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]:
         """Use persistent tile scheduler to compute the grid size for the output tensor C.

@@ -2803,7 +2984,9 @@
         num_ctas_mnl = gc[(0, (None, None, None))].shape
         cluster_shape_mnl = (*cluster_shape_mn, 1)

-        tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl)
+        tile_sched_params = utils.PersistentTileSchedulerParams(
+            num_ctas_mnl, cluster_shape_mnl, raster_along_m=raster_along_m
+        )
         grid = utils.StaticPersistentTileScheduler.get_grid_shape(
             tile_sched_params, max_active_clusters
         )
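
Editorial note: _compute_grid only threads raster_along_m through to the scheduler params; the launched grid is still clamped by occupancy for the persistent kernel. A rough sketch of what that computation amounts to (simplified, not the library's exact logic):

import math

def persistent_grid_clusters(tiles_m, tiles_n, cluster_m, cluster_n, max_active_clusters):
    """Clamp the cluster count so persistent CTAs loop over leftover work items."""
    total_clusters = math.ceil(tiles_m / cluster_m) * math.ceil(tiles_n / cluster_n)
    return min(total_clusters, max_active_clusters)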
@@ -3209,3 +3392,33 @@ def wrapper(
         stream=stream,
         epilogue_op=epilogue_op,
     )
+
+
+@cute.jit
+def cvt_sf_MKL_to_M32x4xrm_K4xrk_L(
+    sf_ref_tensor: cute.Tensor,
+    sf_mma_tensor: cute.Tensor,
+):
+    """Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout"""
+    # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l)
+    # group to ((32, 4, rest_m), (4, rest_k), l)
+    sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3)
+    sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3)
+    for i in cutlass.range(cute.size(sf_ref_tensor)):
+        mkl_coord = sf_ref_tensor.layout.get_hier_coord(i)
+        sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord]
+
+
+@cute.jit
+def cvt_sf_M32x4xrm_K4xrk_L_to_MKL(
+    sf_swizzled_tensor: cute.Tensor,
+    sf_unswizzled_tensor: cute.Tensor,
+):
+    """Convert scale factor tensor from mma specification M(32x4xrest_m)xK(4xrest_k)xL layout to MKL layout"""
+    # sf_swizzled_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l)
+    # group to ((32, 4, rest_m), (4, rest_k), l)
+    sf_swizzled_tensor = cute.group_modes(sf_swizzled_tensor, 0, 3)
+    sf_swizzled_tensor = cute.group_modes(sf_swizzled_tensor, 1, 3)
+    for i in cutlass.range(cute.size(sf_unswizzled_tensor)):
+        mkl_coord = sf_unswizzled_tensor.layout.get_hier_coord(i)
+        sf_unswizzled_tensor[mkl_coord] = sf_swizzled_tensor[mkl_coord]
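
Editorial note: the two converters move scale factors between a flat MKL view and the grouped ((32, 4, rest_m), (4, rest_k), l) MMA view. A numpy illustration of one plausible index reshuffle, assuming m decomposes as m = 128*rest_m + 32*quad + lane and k as k = 4*rest_k + kk; the authoritative layout (including strides) comes from blockscaled_utils, so treat this as a shape-level sketch only:

import numpy as np

M, K, L = 128, 8, 1  # hypothetical sizes with M % 128 == 0 and K % 4 == 0
rest_m, rest_k = M // 128, K // 4
sf_mkl = np.arange(M * K * L).reshape(M, K, L)

# Split M into (rest_m, quad, lane) and K into (rest_k, kk), then reorder the axes
# to the (32, 4, rest_m, 4, rest_k, l) shape that the kernels group hierarchically.
sf_mma = sf_mkl.reshape(rest_m, 4, 32, rest_k, 4, L).transpose(2, 1, 0, 4, 3, 5)
assert sf_mma.shape == (32, 4, rest_m, 4, rest_k, L)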
