@@ -2,6 +2,7 @@
 import pytest
 import torch
 from typing import Union
+import triton
 # routing utilities
 from triton_kernels.routing import routing
 # matmul utilities
@@ -243,6 +244,11 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         # Automatic padding not implemented for Hopper swizzle
         pytest.skip("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles).")
 
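+    # When set, a launch hook installed below verifies that the kernel's
+    # reported byte count matches the bytes it actually reads and writes.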
+    # launch metadata for batched / mx types may not work yet.
+    test_launch_metadata = (mode == "ragged") and ("mx" not in weight_dtype_str)
+
     torch.manual_seed(0)
 
     block_k = None
@@ -314,8 +320,57 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
 
     if w_tri.shape[0] == 1:
         # Test the case when weight has dim 2, i.e., shape (K, N).
-        w_tri = w_tri.squeeze(0).detach().requires_grad_()
-        w_ref = w_ref.squeeze(0).detach().requires_grad_()
+        w_tri = w_tri.squeeze(0).detach().requires_grad_(test_bwd)
+        w_ref = w_ref.squeeze(0).detach().requires_grad_(test_bwd)
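+        # Gradients are only needed when the backward pass is tested (test_bwd).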
+
+    if test_launch_metadata:
+
+        def _clobber(t, used_mask):
+            # Fill the unread part of the tensor with garbage, to be sure
+            # that we don't actually read from it.
+            if len(used_mask) == 1:
+                return
+            elif t.element_size() == 1:
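+                # 1-byte dtypes (e.g. float8) may not be able to represent inf,
+                # so write a garbage bit pattern through an int8 view instead.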
+                t.view(torch.int8)[~used_mask] = 127
+            else:
+                t[~used_mask] = torch.inf
+
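+        # Expected weight traffic: only experts that received at least one
+        # token have their weight slice read by the kernel.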
+        if rdata is not None:
+            n_tokens = rdata.expt_hist.sum().item()
+            used_expts = (rdata.expt_hist > 0)
+            _clobber(w_tri, used_expts)
+            n_w_bytes = used_expts.sum().item() * n * k * w_tri.element_size()
+        else:
+            n_tokens = m
+            n_w_bytes = w_tri.numel() * w_tri.element_size()
+
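+        # Expected activation traffic: rows whose gather indices are all -1
+        # are padding and are never read.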
+        if gindx is not None:
+            used_x_rows = (gindx.dst_indx.view(-1, n_expts_act) != -1).any(dim=1)
+            _clobber(x_tri, used_x_rows)
+            n_x_bytes = used_x_rows.sum().item() * k * x_tri.element_size()
+        elif rdata is not None:
+            n_x_bytes = n_tokens * k * x_tri.element_size()
+        else:
+            n_x_bytes = x_tri.numel() * x_tri.element_size()
+
+        nbytes = None
+
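+        # The hook fires on kernel launch; record the byte count reported by
+        # the matmul_ogs kernel's launch metadata.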
+        def _hook(launch_metadata):
+            nonlocal nbytes
+            metadata = launch_metadata.get()
+            if "matmul_ogs" in metadata["name"]:
+                nbytes = metadata["bytes"]
+
+        triton.knobs.runtime.launch_enter_hook = _hook
 
     if mode == "batched":
         rdata, gindx, sindx = None, None, None
@@ -327,6 +382,19 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
     sep_scatter = mode == "ragged" and do_scatter and n_expts_act > 1 and split_k == 1
     y_scale = flex.out_data.expected_scale if act_is_float8 else 1
 
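+    # The kernel's reported bytes must exactly match the traffic computed
+    # above: activations read, weights read, and output rows written.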
+    if test_launch_metadata:
+        if gindx is not None:
+            n_y_bytes = (gindx.src_indx != -1).sum().item() * n * tri_y.element_size()
+        elif rdata is not None:
+            n_y_bytes = n_tokens * n * tri_y.element_size()
+        else:
+            n_y_bytes = tri_y.numel() * tri_y.element_size()
+        assert nbytes == n_x_bytes + n_y_bytes + n_w_bytes
+        # Reset the global hook so subsequent kernel launches are unaffected.
+        triton.knobs.runtime.launch_enter_hook = None
+
     def round_x(x, idx):
         return x.to(act_dtype).to(torch.float32) if sep_gather else x
 