python/triton_kernels/triton_kernels/matmul_ogs_details (1 file changed, +4 -1)

@@ -201,11 +201,14 @@ def make_default_opt_flags_nvidia(
         block_m = 128
     else:
         if tokens_per_expt <= 64 and routing_data is not None and routing_data.expt_hist is not None:
-            # Ragged and likely memory bound; set the block size higher to minimize loading weights more than once.
             if lhs_dtype == torch.bfloat16 and rhs_dtype == FP4 and tokens_per_expt >= 16 and torch.cuda.get_device_capability()[0] >= 10:
+                # Ragged and likely memory bound; set the block size higher to minimize loading weights more than once.
                 block_m = max(16, min(triton.next_power_of_2(8 * tokens_per_expt), 128))
             else:
                 block_m = max(16, min(triton.next_power_of_2(2 * tokens_per_expt), 64))
+                if block_m == 64 and precision_config.out_scale is not None and rhs_dtype == FP4 and torch.cuda.get_device_capability()[0] >= 10:
+                    # With both fused_activation and an mxfp8 downcast in the epilogue, block_m=64 overflows shared memory.
+                    block_m = 128
         else:
             block_m = max(16, min(triton.next_power_of_2(tokens_per_expt), 128))
     # block n