 logger = init_logger(__name__)
 
 
-def _should_use_flashinfer_mxfp4_bf16():
+def should_use_flashinfer_mxfp4_bf16():
     """Determine if FlashInfer MXFP4 BF16 should be used."""
     # If explicitly set, respect the setting
     if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
@@ -60,7 +60,7 @@ def _should_use_flashinfer_mxfp4_mxfp8():
 
 def should_use_flashinfer_mxfp4():
     return (_should_use_flashinfer_mxfp4_mxfp8()
-            or _should_use_flashinfer_mxfp4_bf16())
+            or should_use_flashinfer_mxfp4_bf16())
 
 
 class Mxfp4Config(QuantizationConfig):
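Dropping the leading underscore turns the BF16 check into part of the module's importable surface rather than a file-private helper. A minimal sketch of how a caller outside this file might use it; the import path is an assumption inferred from the class names in this file, not something stated by the diff:

# Illustrative sketch only: the module path below is assumed.
from vllm.model_executor.layers.quantization.mxfp4 import (
    should_use_flashinfer_mxfp4_bf16)

if should_use_flashinfer_mxfp4_bf16():
    # In this mode activations stay in torch.bfloat16 (see the `x_quant = x`
    # branch later in this diff) instead of being MXFP8-quantized first.
    print("FlashInfer MXFP4 MoE will run with bfloat16 activations")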
@@ -182,11 +182,12 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 256)
             hidden_size = round_up(hidden_size, 256)
-        elif _should_use_flashinfer_mxfp4_bf16(
+        elif should_use_flashinfer_mxfp4_bf16(
         ) and current_platform.is_device_capability(
                 90) or current_platform.is_rocm():
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 128)
+            hidden_size = round_up(hidden_size, 128)
         else:
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 64)
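The SM90 / ROCm branch now pads hidden_size to a multiple of 128 as well, mirroring the 256-alignment in the branch above it. A small sketch of the ceiling-to-multiple arithmetic that round_up is assumed to perform (the helper comes from vLLM's utilities and is not defined in this diff; the example sizes are illustrative, not from a real model config):

def round_up(x: int, multiple: int) -> int:
    # Assumed semantics: smallest multiple of `multiple` that is >= x.
    return ((x + multiple - 1) // multiple) * multiple

assert round_up(2880, 128) == 2944  # padded for the 128-aligned path
assert round_up(2880, 256) == 3072  # same size padded for the 256-aligned path
assert round_up(2944, 128) == 2944  # already-aligned sizes are unchanged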
@@ -391,7 +392,7 @@ def swap_every_two_rows(x, axis=-1):
             layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
                 self.num_experts, -1),
                                       requires_grad=False)
-        elif _should_use_flashinfer_mxfp4_bf16(
+        elif should_use_flashinfer_mxfp4_bf16(
         ) and current_platform.is_device_capability(90):
             assert layer.w13_weight.dtype == torch.uint8, (
                 f"layer.w13_weight.dtype: {layer.w13_weight.dtype}, "
@@ -501,7 +502,6 @@ def swap_every_two_rows(x, axis=-1):
 
             layer.w2_weight_scale = torch.nn.Parameter(
                 w2_scales_interleaved.cuda(), requires_grad=False)
-
         else:
             from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
 
@@ -633,7 +633,7 @@ def apply(
             from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
             assert not self.moe.use_ep, (
                 "EP is not supported for flashinfer mxfp4 moe backend yet.")
-            if _should_use_flashinfer_mxfp4_bf16():
+            if should_use_flashinfer_mxfp4_bf16():
                 assert x.dtype == torch.bfloat16
                 x_quant = x
                 x_scale = None
@@ -670,7 +670,7 @@ def apply(
                 True,  # do finalize
             )[0]
             return trtllm_gen_output
-        elif _should_use_flashinfer_mxfp4_bf16(
+        elif should_use_flashinfer_mxfp4_bf16(
         ) and current_platform.is_device_capability(90):
             from vllm.utils.flashinfer import (autotune,
                                                flashinfer_cutlass_fused_moe)
@@ -695,14 +695,16 @@ def apply(
                 e_score_correction_bias=e_score_correction_bias,
             )
 
-            with torch.inference_mode(), autotune(self.flashinfer_autotune):
-                output = flashinfer_cutlass_fused_moe(
+            output = torch.empty_like(x, dtype=torch.bfloat16)
+            with autotune(self.flashinfer_autotune):
+                _ = flashinfer_cutlass_fused_moe(
                     input=x,
-                    token_selected_experts=topk_ids,
+                    token_selected_experts=topk_ids.to(torch.int).contiguous(),
                     token_final_scales=topk_weights,
                     fc1_expert_weights=layer.w13_weight,
                     fc2_expert_weights=layer.w2_weight,
                     output_dtype=torch.bfloat16,
+                    output=output,
                     quant_scales=quant_scales,
                     fc1_expert_biases=layer.w13_bias,
                     fc2_expert_biases=layer.w2_bias,
@@ -714,8 +716,9 @@ def apply(
                     ep_size=self.moe.ep_size,
                     ep_rank=self.moe.ep_rank,
                     use_w4_group_scaling=True,
-                )[0]
-                self.flashinfer_autotune = False
+                )
+
+            self.flashinfer_autotune = False
             return output
         else:
             return triton_kernel_moe_forward(
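The SM90 cutlass path now preallocates the result tensor and hands it to the kernel via `output=`, passes `topk_ids` as a contiguous int32 tensor, ignores the call's return value, and resets the autotune flag outside the `with autotune(...)` block. A minimal sketch of that caller-allocated-output pattern, using a stand-in function in place of `flashinfer_cutlass_fused_moe` (the stub below is hypothetical and only illustrates the buffer handling, not the MoE math):

import torch

def fused_moe_stub(input: torch.Tensor, token_selected_experts: torch.Tensor,
                   output: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: writes its result into the caller-provided
    # `output` buffer in place and returns it (the return value is ignored).
    output.copy_(input.to(output.dtype))
    return output

x = torch.randn(4, 8, dtype=torch.bfloat16)
topk_ids = torch.tensor([[1], [0], [2], [1]], dtype=torch.int64)

# Same shape of call as in the diff: preallocate the bf16 result, pass
# int32 contiguous expert ids, read the result from `output`.
output = torch.empty_like(x, dtype=torch.bfloat16)
_ = fused_moe_stub(input=x,
                   token_selected_experts=topk_ids.to(torch.int).contiguous(),
                   output=output)
print(output.shape)

As in the diff, the caller reads the result from the preallocated `output` buffer rather than from the function's return value.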