Commit 6d25e3f

Use Blackwell FlashInfer MXFP4 MoE by default if available (#23008)
Signed-off-by: mgoin <[email protected]>
1 parent ac6eb49 commit 6d25e3f
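
With this change, the FlashInfer MXFP4 BF16 backend is picked automatically on Blackwell (SM100) when FlashInfer is installed; the existing environment variables still override the default. A minimal sketch of opting into the MXFP8 path, or opting out entirely (set before vLLM is imported; the variable names are the ones touched in this commit):

```python
import os

# Prefer the MXFP8-activation FlashInfer path over the new BF16 default.
# The log message added in this commit notes it is faster but may impact
# accuracy.
os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"] = "1"

# Or explicitly disable the FlashInfer BF16 default; an explicit setting
# always wins over the new auto-enable logic:
# os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"] = "0"
```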

File tree

2 files changed (+51, -19 lines)

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 5 additions & 5 deletions
```diff
@@ -762,11 +762,11 @@ def __init__(
         self.global_num_experts = num_experts + num_redundant_experts
 
         # we padding globally so EP buffer allocation works
-        if (quant_config and quant_config.get_name() == "mxfp4"
-                and (current_platform.is_rocm()
-                     or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-                     or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16)):
-            hidden_size = round_up(hidden_size, 256)
+        if quant_config and quant_config.get_name() == "mxfp4":
+            from vllm.model_executor.layers.quantization.mxfp4 import (  # noqa: E501
+                should_use_flashinfer_mxfp4)
+            if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
+                hidden_size = round_up(hidden_size, 256)
 
         # For smuggling this layer into the fused moe custom op
         compilation_config = vllm_config.compilation_config
```
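
The padding itself is unchanged; only the gating condition moved behind the new should_use_flashinfer_mxfp4() helper. As a reminder of what the padding does, here is a minimal sketch of round_up (the real helper is vllm.utils.round_up; reimplemented inline only for illustration):

```python
def round_up(x: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= x (mirrors vllm.utils.round_up)."""
    return ((x + multiple - 1) // multiple) * multiple

# With MXFP4 on ROCm or the FlashInfer backend, hidden_size is padded globally
# so EP (expert-parallel) buffer allocation stays aligned:
assert round_up(2880, 256) == 3072  # padded up to the next 256 boundary
assert round_up(4096, 256) == 4096  # already aligned, left unchanged
```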

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 46 additions & 14 deletions
```diff
@@ -6,6 +6,7 @@
 from torch.nn.parameter import Parameter
 
 from vllm import envs
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                   FusedMoEMethodBase)
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
```
```diff
@@ -26,12 +27,38 @@
 from vllm.scalar_type import scalar_types
 from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
                         next_power_of_2, round_up)
+from vllm.utils.flashinfer import has_flashinfer
 
-if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
-    # from flashinfer.fused_moe import cutlass_fused_moe
-    from flashinfer import (mxfp8_quantize, shuffle_matrix_a,
-                            shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
+logger = init_logger(__name__)
+
+
+def _should_use_flashinfer_mxfp4_bf16():
+    """Determine if FlashInfer MXFP4 BF16 should be used."""
+    # If explicitly set, respect the setting
+    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
+        return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
+
+    # Enable by default on SM100 if MXFP8 is not explicitly enabled
+    if (current_platform.is_device_capability(100) and has_flashinfer()
+            and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
+        logger.info_once(
+            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
+            "For faster performance, consider setting "
+            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
+            "though this may impact accuracy.")
+        return True
+
+    return False
+
+
+def _should_use_flashinfer_mxfp4_mxfp8():
+    """Determine if FlashInfer MXFP4 MXFP8 should be used."""
+    return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+
+
+def should_use_flashinfer_mxfp4():
+    return (_should_use_flashinfer_mxfp4_mxfp8()
+            or _should_use_flashinfer_mxfp4_bf16())
 
 
 class Mxfp4Config(QuantizationConfig):
```
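
The new helpers resolve the backend choice in a fixed order: an explicitly set VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 is always respected; otherwise the BF16 path is enabled by default on SM100 (Blackwell) when FlashInfer is available and VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 has not been explicitly set. A standalone sketch of that precedence (a plain dict stands in for vLLM's envs module; the boolean arguments are hypothetical inputs for illustration):

```python
def bf16_default_enabled(env: dict, is_sm100: bool, flashinfer_ok: bool) -> bool:
    """Mirrors the decision order of _should_use_flashinfer_mxfp4_bf16."""
    if "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16" in env:
        # An explicit user setting always wins.
        return bool(env["VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"])
    # Otherwise default to BF16 on Blackwell with FlashInfer installed,
    # unless the MXFP8 path was explicitly requested instead.
    return (is_sm100 and flashinfer_ok
            and "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8" not in env)

# Blackwell + FlashInfer, nothing set -> BF16 backend enabled by default.
assert bf16_default_enabled({}, True, True) is True
# Explicit opt-out is respected.
assert bf16_default_enabled({"VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": 0}, True, True) is False
# Explicitly requesting MXFP8 suppresses the BF16 default.
assert bf16_default_enabled({"VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8": 1}, True, True) is False
# Pre-Blackwell or missing FlashInfer -> no auto-enable.
assert bf16_default_enabled({}, False, True) is False
```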
```diff
@@ -87,12 +114,18 @@ def __init__(self, moe: FusedMoEConfig):
         self.moe = moe
         self.use_marlin = self._should_use_marlin()
 
+        if current_platform.is_device_capability(100) and not has_flashinfer():
+            logger.warning_once(
+                "MXFP4 MoE is enabled on Blackwell but FlashInfer "
+                "is not available. This may result in degraded performance. "
+                "Please `pip install vllm[flashinfer]` for best results.")
+
     def _should_use_marlin(self):
         if envs.VLLM_MXFP4_USE_MARLIN is not None:
             return envs.VLLM_MXFP4_USE_MARLIN
         if current_platform.is_cuda() and \
-                not current_platform.has_device_capability(100):
-            if not current_platform.is_device_capability(90):
+                not current_platform.is_device_capability(100):
+            if not current_platform.has_device_capability(90):
                 # marlin kernel has better performance on ampere
                 return True
         if not has_triton_kernels():
```
```diff
@@ -138,8 +171,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             layer.hidden_size = hidden_size
             layer.intermediate_size_per_partition = \
                 intermediate_size_per_partition_after_pad
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        elif should_use_flashinfer_mxfp4():
             # pad the intermediate size to be a multiple of 2 * mxfp4_block
             # for to hold non-uniform sharded tensor as well as swizzling
             # other padding to increase performance
```
```diff
@@ -230,8 +262,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
     def process_weights_after_loading(self, layer):
         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        elif should_use_flashinfer_mxfp4():
+            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
             layer.gemm1_alpha = Parameter(torch.tensor(
                 [1.702] * self.num_experts, dtype=torch.float32).cuda(),
                 requires_grad=False)
```
```diff
@@ -478,11 +510,11 @@ def apply(
                 logical_replica_count), (
                     "MXFP4 are not supported with this configuration.")
 
-        if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        if should_use_flashinfer_mxfp4():
+            from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
             assert not self.moe.use_ep, (
                 "EP is not supported for flashinfer mxfp4 moe backend yet.")
-            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
+            if _should_use_flashinfer_mxfp4_bf16():
                 assert x.dtype == torch.bfloat16
                 x_quant = x
                 x_scale = None
```
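
In apply(), both variants now share the should_use_flashinfer_mxfp4() gate and defer the FlashInfer imports until the branch is taken; they differ in how activations are fed to the kernel. A minimal sketch of that split (the MXFP8 branch is only described here, since the real quantization is done by FlashInfer's mxfp8_quantize):

```python
import torch

def activation_format(x: torch.Tensor, bf16_path: bool) -> str:
    """Sketch: how the two FlashInfer MXFP4 variants treat MoE activations."""
    if bf16_path:
        # BF16 variant: input must already be bfloat16 and is passed through
        # unquantized (x_quant = x, x_scale = None in the real code).
        assert x.dtype == torch.bfloat16
        return "bf16 passthrough, no activation scales"
    # MXFP8 variant: activations are quantized to MXFP8 first (FlashInfer's
    # mxfp8_quantize in the real code) and block scales accompany them.
    return "mxfp8-quantized activations + scales"

print(activation_format(torch.zeros(4, dtype=torch.bfloat16), bf16_path=True))
```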
