@@ -39,10 +39,10 @@ def _should_use_flashinfer_mxfp4_bf16():
         return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
 
-    # Enable by default on SM100 if MXFP8 is not explicitly enabled
-    if (current_platform.is_device_capability(100) and has_flashinfer()
+    # Enable by default on SM100/SM90 if MXFP8 is not explicitly enabled
+    if ((current_platform.is_device_capability(100) or current_platform.is_device_capability(90)) and has_flashinfer()
             and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
         logger.info_once(
-            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
+            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell and Hopper. "
             "For faster performance, consider setting "
             "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
            "though this may impact accuracy.")
@@ -172,14 +172,14 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             layer.hidden_size = hidden_size
             layer.intermediate_size_per_partition = \
                 intermediate_size_per_partition_after_pad
-        elif should_use_flashinfer_mxfp4():
+        elif should_use_flashinfer_mxfp4() and current_platform.is_device_capability(100):
             # pad the intermediate size to be a multiple of 2 * mxfp4_block
             # for to hold non-uniform sharded tensor as well as swizzling
             # other padding to increase performance
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 256)
             hidden_size = round_up(hidden_size, 256)
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16) and current_platform.is_device_capability(90):
+        elif _should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 128)
         elif current_platform.is_rocm():
@@ -388,7 +388,7 @@ def swap_every_two_rows(x, axis=-1):
             layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
                 self.num_experts, -1),
                                       requires_grad=False)
-        elif envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 and current_platform.is_device_capability(90):
+        elif _should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90):
             assert layer.w13_weight.dtype == torch.uint8, f"layer.w13_weight.dtype: {layer.w13_weight.dtype}, expected: {torch.uint8}"
             assert layer.w2_weight.dtype == torch.uint8, f"layer.w2_weight.dtype: {layer.w2_weight.dtype}, expected: {torch.uint8}"
             assert layer.w13_weight_scale.dtype == torch.uint8, f"layer.w13_weight_scale.dtype: {layer.w13_weight_scale.dtype}, expected: {torch.uint8}"
@@ -604,7 +604,7 @@ def apply(
                 logical_replica_count), (
                     "MXFP4 are not supported with this configuration.")
 
-        if should_use_flashinfer_mxfp4():
+        if should_use_flashinfer_mxfp4() and current_platform.is_device_capability(100):
             from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
             assert not self.moe.use_ep, (
                 "EP is not supported for flashinfer mxfp4 moe backend yet.")
@@ -645,7 +645,7 @@ def apply(
                 True,  # do finalize
             )[0]
             return trtllm_gen_output
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16) and current_platform.is_device_capability(90):
+        elif _should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90):
 
             assert x.dtype == torch.bfloat16