|
6 | 6 | from torch.nn.parameter import Parameter
|
7 | 7 |
|
8 | 8 | from vllm import envs
|
| 9 | +from vllm.logger import init_logger |
9 | 10 | from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
|
10 | 11 | FusedMoEMethodBase)
|
11 | 12 | from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
|
26 | 27 | from vllm.scalar_type import scalar_types
|
27 | 28 | from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
|
28 | 29 | next_power_of_2, round_up)
|
| 30 | +from vllm.utils.flashinfer import has_flashinfer |
29 | 31 |
|
30 | | -if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 |
31 | | - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): |
32 | | - # from flashinfer.fused_moe import cutlass_fused_moe |
33 | | - from flashinfer import (mxfp8_quantize, shuffle_matrix_a, |
34 | | - shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) |
| 32 | +logger = init_logger(__name__) |
| 33 | + |
| 34 | + |
| 35 | +def _should_use_flashinfer_mxfp4_bf16(): |
| 36 | + """Determine if FlashInfer MXFP4 BF16 should be used.""" |
| 37 | + # If explicitly set, respect the setting |
| 38 | + if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"): |
| 39 | + return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 |
| 40 | + |
| 41 | + # Enable by default on SM100 if MXFP8 is not explicitly enabled |
| 42 | + if (current_platform.is_device_capability(100) and has_flashinfer() |
| 43 | + and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")): |
| 44 | + logger.info_once( |
| 45 | + "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. " |
| 46 | + "For faster performance, consider setting " |
| 47 | + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, " |
| 48 | + "though this may impact accuracy.") |
| 49 | + return True |
| 50 | + |
| 51 | + return False |
| 52 | + |
| 53 | + |
| 54 | +def _should_use_flashinfer_mxfp4_mxfp8(): |
| 55 | + """Determine if FlashInfer MXFP4 MXFP8 should be used.""" |
| 56 | + return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 |
| 57 | + |
| 58 | + |
| 59 | +def should_use_flashinfer_mxfp4(): |
| 60 | + return (_should_use_flashinfer_mxfp4_mxfp8() |
| 61 | + or _should_use_flashinfer_mxfp4_bf16()) |
35 | 62 |
|
36 | 63 |
|
37 | 64 | class Mxfp4Config(QuantizationConfig):
|
@@ -87,12 +114,18 @@ def __init__(self, moe: FusedMoEConfig):
|
87 | 114 | self.moe = moe
|
88 | 115 | self.use_marlin = self._should_use_marlin()
|
89 | 116 |
|
| 117 | + if current_platform.is_device_capability(100) and not has_flashinfer(): |
| 118 | + logger.warning_once( |
| 119 | + "MXFP4 MoE is enabled on Blackwell but FlashInfer " |
| 120 | + "is not available. This may result in degraded performance. " |
| 121 | + "Please `pip install vllm[flashinfer]` for best results.") |
| 122 | + |
90 | 123 | def _should_use_marlin(self):
|
91 | 124 | if envs.VLLM_MXFP4_USE_MARLIN is not None:
|
92 | 125 | return envs.VLLM_MXFP4_USE_MARLIN
|
93 | 126 | if current_platform.is_cuda() and \
|
94 | | - not current_platform.has_device_capability(100): |
95 | | - if not current_platform.is_device_capability(90): |
| 127 | + not current_platform.is_device_capability(100): |
| 128 | + if not current_platform.has_device_capability(90): |
96 | 129 | # marlin kernel has better performance on ampere
|
97 | 130 | return True
|
98 | 131 | if not has_triton_kernels():
|
@@ -138,8 +171,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
138 | 171 | layer.hidden_size = hidden_size
|
139 | 172 | layer.intermediate_size_per_partition = \
|
140 | 173 | intermediate_size_per_partition_after_pad
|
141 | | - elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 |
142 | | - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): |
| 174 | + elif should_use_flashinfer_mxfp4(): |
143 | 175 | # pad the intermediate size to be a multiple of 2 * mxfp4_block
|
144 | 176 | # for to hold non-uniform sharded tensor as well as swizzling
|
145 | 177 | # other padding to increase performance
|
@@ -230,8 +262,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
230 | 262 | def process_weights_after_loading(self, layer):
|
231 | 263 | if self.use_marlin:
|
232 | 264 | prepare_moe_fp4_layer_for_marlin(layer)
|
233 | | - elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 |
234 | | - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): |
| 265 | + elif should_use_flashinfer_mxfp4(): |
| 266 | + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a |
235 | 267 | layer.gemm1_alpha = Parameter(torch.tensor(
|
236 | 268 | [1.702] * self.num_experts, dtype=torch.float32).cuda(),
|
237 | 269 | requires_grad=False)
|
@@ -478,11 +510,11 @@ def apply(
|
478 | 510 | logical_replica_count), (
|
479 | 511 | "MXFP4 are not supported with this configuration.")
|
480 | 512 |
|
481 | | - if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 |
482 | | - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): |
| 513 | + if should_use_flashinfer_mxfp4(): |
| 514 | + from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe |
483 | 515 | assert not self.moe.use_ep, (
|
484 | 516 | "EP is not supported for flashinfer mxfp4 moe backend yet.")
|
485 | | - if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: |
| 517 | + if _should_use_flashinfer_mxfp4_bf16(): |
486 | 518 | assert x.dtype == torch.bfloat16
|
487 | 519 | x_quant = x
|
488 | 520 | x_scale = None
|
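For reference, a minimal sketch (not part of this diff) of how a deployment could override the backend selection that the new helpers resolve. The environment variable names and the default-on-SM100 behavior come from the diff above; the override pattern itself is only an illustration and assumes the variables are set before vLLM reads its environment.

```python
import os

# Explicitly opt in to the MXFP8 activation path (noted in the diff as
# potentially faster, though it may impact accuracy). When this variable
# is set, the BF16 default for Blackwell is not applied.
os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"] = "1"

# Alternatively, explicitly disable the BF16 path that this change
# enables by default on SM100 when FlashInfer is available:
# os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"] = "0"
```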