Skip to content

Commit 0766464

Browse files
authored
[KERNELS] Fix launch metadata computations for matmul_ogs. (#8429)
(Previous code was causing CUDA or python asserts for some cases.)
1 parent 7ec17cd commit 0766464

File tree

4 files changed: 10 additions and 12 deletions.

4 files changed: 10 additions and 12 deletions.

python/triton_kernels/triton_kernels/matmul_ogs.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -467,7 +467,10 @@ def matmul_ogs(x, w, bias,
467467
assert routing_data is None
468468
assert gather_indx is None
469469
assert scatter_indx is None
470-
routing_data = RoutingData(None, None, inner_routing_data.base.n_expts_tot, 1)
470+
routing_data = RoutingData(
471+
None, None, inner_routing_data.base.n_expts_tot, 1,
472+
expected_tokens_per_expt=inner_routing_data.base.expected_tokens_per_expt,
473+
)
471474
# canonicalize inputs
472475
if precision_config is None:
473476
precision_config = PrecisionConfig()
@@ -684,6 +687,7 @@ def matmul_ogs(x, w, bias,
684687
N, K, K_W,
685688
betas, gammas,
686689
None if gather_indx is None else gather_indx.src_indx,
690+
None if gather_indx is None else gather_indx.dst_indx, # Only for launch_metadata
687691
None if scatter_indx is None else scatter_indx.src_indx,
688692
num_indx,
689693
None if not opt_flags.fused_scatter else scatter_indx.dst_indx,

python/triton_kernels/triton_kernels/matmul_ogs_details/_common.py

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,3 @@
1-
import torch
21
import triton
32
import triton.language as tl
43

@@ -221,7 +220,7 @@ def matmul_launch_metadata(grid, kernel, args):
221220
n_tokens = None
222221
n_w_bytes = W.numel() * W.element_size()
223222
if expt_is_inner:
224-
K = int(n_tokens)
223+
K = None if n_tokens is None else int(n_tokens)
225224
repr = lambda s, x: f"{s} = {x}" if x is not None else f"E_{len(hist)}({s}) = {n_rows}"
226225
nbits = X.dtype.itemsize * 8
227226
batch_repr = ""
@@ -238,20 +237,15 @@ def matmul_launch_metadata(grid, kernel, args):
238237
fM = M if M is not None else n_tokens
239238
ret[f"flops{nbits}"] = 2.0 * fM * N * K * (1 if expt_is_inner else batch_size)
240239

241-
gindx = args.get("GatherIndx", None)
240+
dst = args.get("GatherDstIndx", None)
242241
# sindx = args.get("WriteBackIndx", None)
243242
n_x_bytes = X.numel() * X.element_size()
244243
n_y_bytes = Y.numel() * Y.element_size()
245244
if hist is not None:
246245
assert n_tokens is not None
247246
n_expts_act = args["N_EXPTS_ACT"]
248247

249-
if (gindx is not None) and launch_metadata_allow_sync():
250-
# recreate inverse GatherIndx.
251-
dst = torch.full_like(gindx, -1)
252-
idx = torch.arange(len(gindx), device=gindx.device, dtype=torch.int32)
253-
mask = gindx != -1
254-
dst[gindx[mask]] = idx[mask]
248+
if (dst is not None) and launch_metadata_allow_sync():
255249
n_read_rows = (dst.view((-1, n_expts_act)) != -1).any(dim=1).sum()
256250
else:
257251
n_read_rows = n_tokens

python/triton_kernels/triton_kernels/matmul_ogs_details/_matmul_ogs.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ def _matmul_ogs(
5454
M, N, K, K_W, # shapes
5555
# expt data
5656
Betas, Gammas,
57-
GatherIndx,
57+
GatherIndx, GatherDstIndx, # GatherDstIndx is only used for launch metadata.
5858
ScatterSrcIndx, num_idxs,
5959
WriteBackIndx, writeback_size,
6060
ExptHist, ExptOffs, ExptTileOffs, ExptData,

python/triton_kernels/triton_kernels/matmul_ogs_details/_p_matmul_ogs.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ def _p_matmul_ogs(
6363
M, N, K, K_W, # shapes
6464
# expt data
6565
Betas, Gammas,
66-
GatherIndx,
66+
GatherIndx, GatherDstIndx, # GatherDstIndx is only used for launch metadata.
6767
ScatterSrcIndx, num_idxs,
6868
WriteBackIndx, writeback_size,
6969
ExptHist, ExptOffs, ExptTileOffs, ExptData,

0 commit comments

Comments
 (0)