Commit a37606b

use block 128
1 parent 9fbf44f commit a37606b

python/triton_kernels/triton_kernels/matmul_ogs_details/opt_flags.py

Lines changed: 4 additions & 1 deletion
@@ -201,11 +201,14 @@ def make_default_opt_flags_nvidia(
         block_m = 128
     else:
         if tokens_per_expt <= 64 and routing_data is not None and routing_data.expt_hist is not None:
-            # Ragged and likely memory bound; set the block size higher to minimize loading weights more than once.
             if lhs_dtype == torch.bfloat16 and rhs_dtype == FP4 and tokens_per_expt >= 16 and torch.cuda.get_device_capability()[0] >= 10:
+                # Ragged and likely memory bound; set the block size higher to minimize loading weights more than once.
                 block_m = max(16, min(triton.next_power_of_2(8 * tokens_per_expt), 128))
             else:
                 block_m = max(16, min(triton.next_power_of_2(2 * tokens_per_expt), 64))
+            if block_m == 64 and precision_config.out_scale is not None and rhs_dtype == FP4 and torch.cuda.get_device_capability()[0] >= 10:
+                # when having both fused_activation and mxfp8 downcast in the epilogue, block_m=64 causes shared memory overflow
+                block_m = 128
         else:
             block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
     # block n
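
To make the effect of this commit easier to see, below is a minimal standalone sketch of the block_m heuristic after the change. The helper pick_block_m, its boolean parameters (is_ragged, lhs_is_bf16, rhs_is_fp4, has_out_scale, sm_major), and the local next_power_of_2 are simplified stand-ins introduced for illustration only; the real logic lives inside make_default_opt_flags_nvidia and reads routing_data, the torch dtypes, precision_config.out_scale, and torch.cuda.get_device_capability().

    def next_power_of_2(n: int) -> int:
        # Mirrors triton.next_power_of_2: smallest power of two >= n.
        return 1 if n <= 1 else 1 << (n - 1).bit_length()

    def pick_block_m(tokens_per_expt, is_ragged, lhs_is_bf16, rhs_is_fp4,
                     has_out_scale, sm_major):
        # Simplified restatement of the block_m selection above
        # (hypothetical helper, not the library API).
        if tokens_per_expt <= 64 and is_ragged:
            if lhs_is_bf16 and rhs_is_fp4 and tokens_per_expt >= 16 and sm_major >= 10:
                # Ragged and likely memory bound; use a larger block so weights
                # are loaded as few times as possible.
                block_m = max(16, min(next_power_of_2(8 * tokens_per_expt), 128))
            else:
                block_m = max(16, min(next_power_of_2(2 * tokens_per_expt), 64))
            if block_m == 64 and has_out_scale and rhs_is_fp4 and sm_major >= 10:
                # Guard added by this commit: fused activation plus mxfp8 downcast
                # in the epilogue overflows shared memory at block_m=64.
                block_m = 128
        else:
            block_m = max(16, min(next_power_of_2(tokens_per_expt), 128))
        return block_m

    # Example: 32 ragged tokens per expert, non-bf16 LHS, FP4 RHS, an output
    # scale, SM100: the 2*tokens path yields 64, which the new guard bumps to 128.
    print(pick_block_m(32, True, False, True, True, 10))   # 128
    # Without the output scale, the previous choice of 64 is kept.
    print(pick_block_m(32, True, False, True, False, 10))  # 64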
