Commit 40ca782

final updates
1 parent 62c2c0a commit 40ca782

2 files changed: +19 -14 lines changed

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 4 additions & 2 deletions
@@ -791,9 +791,11 @@ def __init__(
         # we padding globally so EP buffer allocation works
         if quant_config and quant_config.get_name() == "mxfp4":
             from vllm.model_executor.layers.quantization.mxfp4 import (  # noqa: E501
-                should_use_flashinfer_mxfp4)
-            if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
+                should_use_flashinfer_mxfp4, should_use_flashinfer_mxfp4_bf16)
+            if current_platform.is_rocm() or (should_use_flashinfer_mxfp4() and current_platform.is_device_capability(100)):
                 hidden_size = round_up(hidden_size, 256)
+            elif should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90):
+                hidden_size = round_up(hidden_size, 128)
 
         # For smuggling this layer into the fused moe custom op
         compilation_config = vllm_config.compilation_config

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 15 additions & 12 deletions
@@ -32,7 +32,7 @@
 logger = init_logger(__name__)
 
 
-def _should_use_flashinfer_mxfp4_bf16():
+def should_use_flashinfer_mxfp4_bf16():
     """Determine if FlashInfer MXFP4 BF16 should be used."""
     # If explicitly set, respect the setting
     if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
@@ -60,7 +60,7 @@ def _should_use_flashinfer_mxfp4_mxfp8():
 
 
 def should_use_flashinfer_mxfp4():
     return (_should_use_flashinfer_mxfp4_mxfp8()
-            or _should_use_flashinfer_mxfp4_bf16())
+            or should_use_flashinfer_mxfp4_bf16())
 
 
 class Mxfp4Config(QuantizationConfig):
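
The rename above (dropping the leading underscore) is what lets layer.py import should_use_flashinfer_mxfp4_bf16 directly for its padding decision. The two helpers compose as shown in the hunk; a hedged, self-contained sketch, where the fallback returns are placeholders for the real capability and availability checks in mxfp4.py:

import os

# Hedged stand-ins for the helpers renamed above. Only the env-var override for
# VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 and the OR-composition come from the diff;
# the `return False` fallbacks replace the real capability-based defaults.
def should_use_flashinfer_mxfp4_bf16() -> bool:
    # If explicitly set, respect the setting (mirrors envs.is_set(...) above).
    if "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16" in os.environ:
        return os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"] == "1"
    return False  # placeholder for the capability-based default

def _should_use_flashinfer_mxfp4_mxfp8() -> bool:
    return False  # placeholder; the real check is unchanged by this commit

def should_use_flashinfer_mxfp4() -> bool:
    return (_should_use_flashinfer_mxfp4_mxfp8()
            or should_use_flashinfer_mxfp4_bf16())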
@@ -182,11 +182,12 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 256)
             hidden_size = round_up(hidden_size, 256)
-        elif _should_use_flashinfer_mxfp4_bf16(
+        elif should_use_flashinfer_mxfp4_bf16(
         ) and current_platform.is_device_capability(
                 90) or current_platform.is_rocm():
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 128)
+            hidden_size = round_up(hidden_size, 128)
         else:
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 64)
@@ -391,7 +392,7 @@ def swap_every_two_rows(x, axis=-1):
             layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
                 self.num_experts, -1),
                                       requires_grad=False)
-        elif _should_use_flashinfer_mxfp4_bf16(
+        elif should_use_flashinfer_mxfp4_bf16(
         ) and current_platform.is_device_capability(90):
             assert layer.w13_weight.dtype == torch.uint8, (
                 f"layer.w13_weight.dtype: {layer.w13_weight.dtype}, "
@@ -501,7 +502,6 @@ def swap_every_two_rows(x, axis=-1):
 
             layer.w2_weight_scale = torch.nn.Parameter(
                 w2_scales_interleaved.cuda(), requires_grad=False)
-
         else:
             from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
 
@@ -633,7 +633,7 @@ def apply(
         from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
         assert not self.moe.use_ep, (
             "EP is not supported for flashinfer mxfp4 moe backend yet.")
-        if _should_use_flashinfer_mxfp4_bf16():
+        if should_use_flashinfer_mxfp4_bf16():
             assert x.dtype == torch.bfloat16
             x_quant = x
             x_scale = None
@@ -670,7 +670,7 @@ def apply(
                 True,  # do finalize
             )[0]
             return trtllm_gen_output
-        elif _should_use_flashinfer_mxfp4_bf16(
+        elif should_use_flashinfer_mxfp4_bf16(
         ) and current_platform.is_device_capability(90):
             from vllm.utils.flashinfer import (autotune,
                                                flashinfer_cutlass_fused_moe)
@@ -695,14 +695,16 @@ def apply(
                 e_score_correction_bias=e_score_correction_bias,
             )
 
-            with torch.inference_mode(), autotune(self.flashinfer_autotune):
-                output = flashinfer_cutlass_fused_moe(
+            output = torch.empty_like(x, dtype=torch.bfloat16)
+            with autotune(self.flashinfer_autotune):
+                _ = flashinfer_cutlass_fused_moe(
                     input=x,
-                    token_selected_experts=topk_ids,
+                    token_selected_experts=topk_ids.to(torch.int).contiguous(),
                     token_final_scales=topk_weights,
                     fc1_expert_weights=layer.w13_weight,
                     fc2_expert_weights=layer.w2_weight,
                     output_dtype=torch.bfloat16,
+                    output=output,
                     quant_scales=quant_scales,
                     fc1_expert_biases=layer.w13_bias,
                     fc2_expert_biases=layer.w2_bias,
@@ -714,8 +716,9 @@ def apply(
                     ep_size=self.moe.ep_size,
                     ep_rank=self.moe.ep_rank,
                     use_w4_group_scaling=True,
-                )[0]
-            self.flashinfer_autotune = False
+                )
+
+            self.flashinfer_autotune = False
             return output
         else:
             return triton_kernel_moe_forward(
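
The last two hunks switch the SM90 path from consuming the value returned by flashinfer_cutlass_fused_moe to preallocating the result tensor and letting the kernel write into it (the return value is ignored); topk_ids is also cast to a contiguous torch.int tensor, and the torch.inference_mode() wrapper is dropped. A minimal sketch of the preallocated-output calling convention, with fused_moe_kernel as a toy stand-in for the FlashInfer call:

import torch

# Hedged sketch: `fused_moe_kernel` is a toy stand-in that just does a matmul;
# the point is the calling convention adopted above, where the caller owns the
# output buffer and the kernel fills it in place.
def fused_moe_kernel(x: torch.Tensor, weight: torch.Tensor,
                     output: torch.Tensor) -> torch.Tensor:
    output.copy_(x @ weight)   # real kernel: grouped expert GEMMs into `output`
    return output

x = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.randn(8, 8, dtype=torch.bfloat16)

output = torch.empty_like(x, dtype=torch.bfloat16)  # preallocated, as in the diff
fused_moe_kernel(x, w, output=output)               # result lands in `output`
assert output.shape == x.shape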
