diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py
index 6b50ca8a11e3..98f216dd4ab9 100644
--- a/python/triton_kernels/tests/test_matmul.py
+++ b/python/triton_kernels/tests/test_matmul.py
@@ -173,6 +173,11 @@ def _build_test_op_cases():
         Case(*even_shape, "ragged", "bfloat16", "bfloat16", epilogue_subtile=val, swiglu_opts=(1.1, 1.4))
         for val in (1, 2, 4)
     ])
+    # swiglu together with the mxfp8 downcast epilogue
+    test_cases.extend([
+        Case(*shape, mode, "mxfloat8_e4m3fn", "mxfloat4_e2m1", hbm_swizzling=True, split_k=split_k, swiglu_opts=(1.1, 7))
+        for shape in [odd_shape2, even_shape] for mode in ["ragged", "batched"] for split_k in [1, 5]
+    ])
     return test_cases

diff --git a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py
index de2c3f2bd00c..3212102ef21d 100644
--- a/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py
+++ b/python/triton_kernels/triton_kernels/matmul_details/opt_flags.py
@@ -210,6 +210,9 @@ def make_default_opt_flags_nvidia(
             block_m = max(16, min(triton.next_power_of_2(8 * slice_size), 128))
         else:
             block_m = max(16, min(triton.next_power_of_2(2 * slice_size), 64))
+        if block_m == 64 and precision_config.c_mx_scale is not None and rhs_dtype == FP4 and torch.cuda.get_device_capability()[0] >= 10:
+            # with both the fused activation and the mxfp8 downcast in the epilogue, block_m=64 overflows shared memory
+            block_m = 128
     else:
         block_m = max(16, min(triton.next_power_of_2(slice_size), 128))
     # block n
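
A minimal, standalone sketch of the guard added to make_default_opt_flags_nvidia, for illustration only. The helper name maybe_bump_block_m and its boolean parameters (standing in for precision_config.c_mx_scale is not None, rhs_dtype == FP4, and torch.cuda.get_device_capability()[0] >= 10) are hypothetical, chosen so the snippet runs without triton or a GPU.

def maybe_bump_block_m(block_m: int, c_has_mx_scale: bool, rhs_is_fp4: bool,
                       capability_major: int) -> int:
    # When the epilogue both applies the fused activation (swiglu) and
    # downcasts the output to mxfp8 (c_mx_scale set), with an mxfp4 RHS on a
    # capability >= 10 GPU, block_m=64 overflows shared memory; fall back to 128.
    if block_m == 64 and c_has_mx_scale and rhs_is_fp4 and capability_major >= 10:
        return 128
    return block_m


# The mxfp8-downcast + mxfp4 case on a capability-10 GPU gets bumped;
# every other combination keeps its original block_m.
assert maybe_bump_block_m(64, True, True, 10) == 128
assert maybe_bump_block_m(64, True, True, 9) == 64
assert maybe_bump_block_m(32, True, True, 10) == 32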