
Commit af96539

use block 128
1 parent 1819e09

2 files changed: +8 −0 lines changed

python/triton_kernels/tests/test_matmul.py
5 additions & 0 deletions

```diff
@@ -173,6 +173,11 @@ def _build_test_op_cases():
         Case(*even_shape, "ragged", "bfloat16", "bfloat16", epilogue_subtile=val, swiglu_opts=(1.1, 1.4))
         for val in (1, 2, 4)
     ])
+    # swiglu together with mxfp8 downcast epilogue
+    test_cases.extend([
+        Case(*shape, mode, "mxfloat8_e4m3fn", "mxfloat4_e2m1", hbm_swizzling=True, split_k=split_k, swiglu_opts=(1.1, 7))
+        for shape in [odd_shape2, even_shape] for mode in ["ragged", "batched"] for split_k in [1, 5]
+    ])
 
     return test_cases
```
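For context, the new comprehension expands to a 2 × 2 × 2 grid of eight cases (two shapes, two modes, two split_k values), each pairing a swiglu activation with an mxfp8/mxfp4 precision configuration. Below is a minimal sketch of that expansion; the `Case` dataclass, `odd_shape2`, and `even_shape` here are hypothetical stand-ins for the definitions that live in test_matmul.py, not the library's actual code:

```python
from dataclasses import dataclass
from itertools import product

@dataclass
class Case:  # hypothetical stand-in; the real Case in test_matmul.py has more fields
    m: int
    n: int
    k: int
    mode: str
    act_dtype_str: str
    weight_dtype_str: str
    hbm_swizzling: bool = False
    split_k: int = 1
    swiglu_opts: tuple = ()

# placeholder shapes; the real odd_shape2 / even_shape are defined in the test file
odd_shape2 = (1000, 700, 320)
even_shape = (768, 768, 128)

new_cases = [
    Case(*shape, mode, "mxfloat8_e4m3fn", "mxfloat4_e2m1",
         hbm_swizzling=True, split_k=split_k, swiglu_opts=(1.1, 7))
    for shape, mode, split_k in product([odd_shape2, even_shape],
                                        ["ragged", "batched"], [1, 5])
]
assert len(new_cases) == 8  # 2 shapes x 2 modes x 2 split_k values
```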

python/triton_kernels/triton_kernels/matmul_details/opt_flags.py
3 additions & 0 deletions

```diff
@@ -210,6 +210,9 @@ def make_default_opt_flags_nvidia(
             block_m = max(16, min(triton.next_power_of_2(8 * slice_size), 128))
         else:
             block_m = max(16, min(triton.next_power_of_2(2 * slice_size), 64))
+            if block_m == 64 and precision_config.c_mx_scale is not None and rhs_dtype == FP4 and torch.cuda.get_device_capability()[0] >= 10:
+                # with both fused_activation and the mxfp8 downcast in the epilogue, block_m=64 causes shared memory overflow
+                block_m = 128
     else:
         block_m = max(16, min(triton.next_power_of_2(slice_size), 128))
     # block n
```
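To make the heuristic easier to follow, here is a self-contained sketch of the block_m selection that this hunk patches. The branch condition guarding the 64-wide path, `precision_config.c_mx_scale`, `rhs_dtype == FP4`, and the device-capability check are all replaced with plain parameters, so this is an illustration of the logic under those assumptions rather than the library's actual code:

```python
# Standalone sketch of the patched block_m heuristic; every parameter is a
# simplified stand-in for the real inputs used in opt_flags.py.
def next_power_of_2(x: int) -> int:
    # same behaviour as triton.next_power_of_2 for positive ints
    return 1 << (max(1, x) - 1).bit_length()

def pick_block_m(slice_size: int, narrow_path: bool,
                 mxfp8_downcast_epilogue: bool, rhs_is_fp4: bool, sm_major: int) -> int:
    if narrow_path:  # stand-in for the branch that previously capped block_m at 64
        block_m = max(16, min(next_power_of_2(2 * slice_size), 64))
        # New in this commit: with both a fused activation (swiglu in the new tests)
        # and an mxfp8 downcast in the epilogue on compute capability 10.x,
        # block_m=64 overflows shared memory, so bump to 128 ("use block 128").
        if block_m == 64 and mxfp8_downcast_epilogue and rhs_is_fp4 and sm_major >= 10:
            block_m = 128
    else:
        block_m = max(16, min(next_power_of_2(slice_size), 128))
    return block_m

print(pick_block_m(64, True, True, True, 10))   # 128: override applies
print(pick_block_m(64, True, False, True, 10))  # 64: no mxfp8 downcast epilogue
print(pick_block_m(64, True, True, True, 9))    # 64: pre-10.x device capability
```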

0 commit comments
