@@ -246,7 +246,8 @@ class Case:
         Case(300, 400, 400, "ragged", "bfloat16", "mxfloat8_e4m3fn", 8, 4, hbm_swizzling=True),
         Case(300, 400, 400, "batched", "bfloat16", "mxfloat8_e5m2", 32, 4),
         Case(1000, 700, 2, "batched", "bfloat16", "mxfloat4_e2m1", 8, 2),
-        Case(1, 2880, 2880, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
+        # Cover (N or K) % 128 == 64 (https://github.com/triton-lang/triton/pull/7203)
+        Case(1, 1472, 1472, "ragged", "bfloat16", "mxfloat4_e2m1", 128, 4),
         Case(16, 256, 256, "ragged", "float8_e5m2", "mxfloat4_e2m1", 128, 4, hbm_swizzling=True),
         Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
         Case(1000, 704, 832, "batched", "float8_e5m2", "mxfloat4_e2m1", 3, 1, hbm_swizzling=True),
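
Both the removed 2880-sized case and its 1472-sized replacement hit the (N or K) % 128 == 64 corner named in the new comment; the smaller shape presumably also fits on an A100, which is why the memory-based skip is dropped in a later hunk. A quick check of the arithmetic (illustrative only, not part of the test file):

    assert 2880 % 128 == 64  # old case already covered the corner, but OOMs on A100
    assert 1472 % 128 == 64  # new case keeps the same remainder with a smaller shape
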
@@ -318,6 +319,24 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
             n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
             x_transpose, w_transpose, y_transpose,
             device, opt_flags_scope):
+    # We catch and re-invoke pytest.skip(), because otherwise pytest may hold a reference to
+    # the frame that called pytest.skip, including all the tensors, leading to OOM.
+    skip_message = None
+    try:
+        _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
+                 n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
+                 x_transpose, w_transpose, y_transpose,
+                 device, opt_flags_scope)
+    except pytest.skip.Exception as e:
+        skip_message = str(e)
+
+    if skip_message is not None:
+        pytest.skip(skip_message)
+
+def _test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_opt, has_y_gammas, is_persistent, n_expts_tot,
+             n_expts_act, mode, act_dtype_str, weight_dtype_str, block_m, hbm_swizzling, colmajor_mxfp_weight, epilogue_subtile,
+             x_transpose, w_transpose, y_transpose,
+             device, opt_flags_scope):
     # TODO: remove when Triton FP8 supports proper RTNE
     if is_cuda():
         if "float8" in weight_dtype_str and torch.cuda.get_device_capability()[0] < 9:
@@ -327,8 +346,6 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         if weight_dtype_str.startswith("mx"):
             if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10:
                 pytest.skip("float8 x mx not supported with cuda capability < 10")
-            if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9:
-                pytest.skip("Not enough memory on A100")
 
     elif is_hip():
         if "float8" in act_dtype_str and "mx" in weight_dtype_str and not is_hip_cdna4():
@@ -366,8 +383,21 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
366383 pytest .skip ("Hopper swizzling acts on a 64x64 tile (4x1 mma tiles)." )
367384
368385 expt_is_inner = (inner_expt_opt is not None )
369- if expt_is_inner and (mode != "ragged" or "mx" in act_dtype_str or "mx" in weight_dtype_str ):
370- pytest .skip ("Not supported yet" )
386+ if expt_is_inner :
387+ if mode != "ragged" :
388+ pytest .skip ("inner_expt_opt only meaningful with ragged" )
389+ if "mx" in act_dtype_str and inner_expt_opt != "pad_x" :
390+ pytest .skip ("inner_expt_opt and act mx only supported with pad_x" )
391+ if "mx" in weight_dtype_str :
392+ if inner_expt_opt != "pad_w" :
393+ pytest .skip ("inner_expt_opt and weight mx only supported with pad_w" )
394+ if is_persistent and not hbm_swizzling :
395+ pytest .skip ("FIXME: Fatal Python error: Aborted" )
396+ if is_hip ():
397+ if act_dtype_str == "bfloat16" :
398+ pytest .skip ("FIXME: failed to translate module to LLVM IR" )
399+ if hbm_swizzling :
400+ pytest .skip ("NYI: nner_expt_opt and HBM swizzling" )
371401
372402 # launch metadata for batched / mx types may not work yet.
373403 torch .manual_seed (0 )
@@ -399,6 +429,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
     opt_flags.update_opt_flags_constraints(constraints)
 
     weight_mxfp = weight_dtype_str.startswith("mx")
+    weight_mxfp4 = weight_mxfp and "float4" in weight_dtype_str
     if weight_mxfp:
         weight_dtype_str = weight_dtype_str[2:]
     act_mxfp8 = act_dtype_str.startswith("mx")
@@ -422,6 +453,13 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         rdata = gindx = sindx = None
 
     padding_block_k = 32
+    if hbm_swizzling:
+        if torch.cuda.get_device_capability()[0] >= 10:
+            # Blackwell scale swizzling constraint
+            # https://github.com/triton-lang/triton/blob/814b862166c756d9f33238844f4ac047e0243388/python/triton_kernels/triton_kernels/tensor_details/layout_details/blackwell_scale.py#L45
+            padding_block_k = 128
+        elif not is_persistent:
+            padding_block_k = 64
     x_tri, w_tri, bias_tri, gs0_tri, gs1_tri = init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act,
                                                                  mode, torch.bfloat16 if act_mxfp8 else act_dtype,  #
                                                                  torch.bfloat16 if weight_mxfp else weight_dtype,
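
For intuition, rounding K up to a multiple of `padding_block_k` works out as follows; the `pad_to_multiple` helper below is illustrative, not part of the test:

    def pad_to_multiple(k: int, padding_block_k: int) -> int:
        # Round k up to the next multiple of padding_block_k.
        return ((k + padding_block_k - 1) // padding_block_k) * padding_block_k

    assert pad_to_multiple(1472, 128) == 1536  # Blackwell scale-swizzling padding
    assert pad_to_multiple(1472, 64) == 1472   # 1472 is already a multiple of 64
    assert pad_to_multiple(1472, 32) == 1472   # default padding_block_k
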
@@ -457,7 +495,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         # compute layouts
         w_layout, w_layout_opts = layout.StridedLayout, dict()
         w_scale_layout, w_scale_layout_opts = layout.StridedLayout, dict()
-        if hbm_swizzling and "float4" in weight_dtype_str:
+        if hbm_swizzling and weight_mxfp4:
             w_layout, w_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=mx_axis)
             w_scale_layout, w_scale_layout_opts = layout.make_default_matmul_mxfp4_w_scale_layout(
                 mx_axis=mx_axis, num_warps=8)
@@ -466,7 +504,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         if colmajor_mxfp_weight:
             w_tri, w_scale_tri = downcast_to_mxfp(w_tri, weight_dtype, axis=mx_axis)
             w_ref = upcast_from_mxfp(w_tri, w_scale_tri, torch.bfloat16, axis=mx_axis)
-            w_tri_dtype = FP4 if "float4" in weight_dtype_str else weight_dtype
+            w_tri_dtype = FP4 if weight_mxfp4 else weight_dtype
             w_tri = wrap_torch_tensor(w_tri, w_tri_dtype)
             w_scale_tri = wrap_torch_tensor(w_scale_tri)
             # convert layouts
@@ -568,8 +606,8 @@ def _pad_and_block(x: torch.Tensor) -> torch.Tensor:
         tri_y = matmul_ogs(x_tri, w_tri, bias_tri, rdata, gindx, sindx, precision_opt,
                            gammas=gs1_ref, epilogue=epilogue, y=y_tri_in,
                            inner_routing_data=inner_routing_data)
-    except (opt_flags.InapplicableConstraint, NotImplementedError):
-        pytest.skip("inapplicable opt_flags constraint")
+    except (opt_flags.InapplicableConstraint, NotImplementedError) as e:
+        pytest.skip(f"inapplicable opt_flags constraint {e}")
     if y_tri_in is not None:
         assert tri_y.data_ptr() == y_tri_in.data_ptr()
         assert tri_y.shape == y_tri_in.shape
@@ -602,7 +640,7 @@ def scale(val, scal):
         ref_y = upcast_from_mxfp_torch(ref_y_quant, ref_y_scale, target_dtype=ref_y.dtype, axis=-1)
         maxtol = 4e-1
         rmstol = 4e-2
-    elif weight_mxfp and "float4_e2m1" in weight_dtype_str:
+    elif weight_mxfp4:
         if act_is_float8:
             maxtol = 8e-2
         else: