
Commit c98c1db

typo
Signed-off-by: Duncan Moss <[email protected]>
1 parent f329657

File tree

1 file changed: +7 −7 lines changed

  • vllm/model_executor/layers/quantization/mxfp4.py


vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 7 additions & 7 deletions
@@ -39,10 +39,10 @@ def _should_use_flashinfer_mxfp4_bf16():
         return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
 
     # Enable by default on SM100 if MXFP8 is not explicitly enabled
-    if (current_platform.is_device_capability(100) and has_flashinfer()
+    if (current_platform.is_device_capability(100) or current_platform.is_device_capability(90) and has_flashinfer()
             and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
         logger.info_once(
-            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
+            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell and Hopper. "
            "For faster performance, consider setting "
            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
            "though this may impact accuracy.")
@@ -172,14 +172,14 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             layer.hidden_size = hidden_size
             layer.intermediate_size_per_partition = \
                 intermediate_size_per_partition_after_pad
-        elif should_use_flashinfer_mxfp4():
+        elif should_use_flashinfer_mxfp4() and current_platform.is_device_capability(100):
             # pad the intermediate size to be a multiple of 2 * mxfp4_block
             # for to hold non-uniform sharded tensor as well as swizzling
             # other padding to increase performance
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 256)
             hidden_size = round_up(hidden_size, 256)
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16) and current_platform.is_device_capability(90):
+        elif _should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90):
             intermediate_size_per_partition_after_pad = round_up(
                 intermediate_size_per_partition, 128)
         elif current_platform.is_rocm():
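The padding in this hunk rounds the per-partition intermediate size up to a hardware-friendly multiple: 256 on the SM100 path, 128 on the new SM90 path. A quick sketch of that arithmetic, with round_up written out locally and assumed to behave like vLLM's round_up helper:

def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x.
    return ((x + multiple - 1) // multiple) * multiple

# e.g. a per-partition intermediate size of 1440:
assert round_up(1440, 256) == 1536   # SM100 path pads to a multiple of 256
assert round_up(1440, 128) == 1536   # SM90 path pads to a multiple of 128
assert round_up(1536, 256) == 1536   # already-aligned sizes are unchanged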
@@ -388,7 +388,7 @@ def swap_every_two_rows(x, axis=-1):
             layer.w2_bias = Parameter(torch.stack(gemm2_bias_shuffled).reshape(
                 self.num_experts, -1),
                 requires_grad=False)
-        elif envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 and current_platform.is_device_capability(90):
+        elif _should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90):
             assert layer.w13_weight.dtype == torch.uint8, f"layer.w13_weight.dtype: {layer.w13_weight.dtype}, expected: {torch.uint8}"
             assert layer.w2_weight.dtype == torch.uint8, f"layer.w2_weight.dtype: {layer.w2_weight.dtype}, expected: {torch.uint8}"
             assert layer.w13_weight_scale.dtype == torch.uint8, f"layer.w13_weight_scale.dtype: {layer.w13_weight_scale.dtype}, expected: {torch.uint8}"
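The dtype asserts in this hunk reflect how MXFP4 tensors are typically stored: 4-bit values packed two per uint8 byte, with uint8 block scales. The snippet below only illustrates that packing assumption; shapes and names are made up, and nothing here is vLLM or FlashInfer API.

import torch

num_experts, n, k = 4, 8, 16
# FP4 weights: two 4-bit codes per byte, so the packed last dim is k // 2.
packed = torch.randint(0, 256, (num_experts, n, k // 2), dtype=torch.uint8)

# Unpacking the low and high nibbles recovers the k logical 4-bit codes.
low, high = packed & 0x0F, packed >> 4
codes = torch.stack((low, high), dim=-1).reshape(num_experts, n, k)
assert codes.dtype == torch.uint8 and codes.shape == (num_experts, n, k)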
@@ -604,7 +604,7 @@ def apply(
                 logical_replica_count), (
                     "MXFP4 are not supported with this configuration.")
 
-        if should_use_flashinfer_mxfp4():
+        if should_use_flashinfer_mxfp4() and current_platform.is_device_capability(100):
             from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
             assert not self.moe.use_ep, (
                 "EP is not supported for flashinfer mxfp4 moe backend yet.")
@@ -645,7 +645,7 @@ def apply(
                 True,  # do finalize
             )[0]
             return trtllm_gen_output
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16) and current_platform.is_device_capability(90):
+        elif _should_use_flashinfer_mxfp4_bf16() and current_platform.is_device_capability(90):
 
             assert x.dtype == torch.bfloat16
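Taken together, the last two hunks constrain the dispatch in apply(): the trtllm FP4 kernel path is reached only on SM100, and the BF16 path handles SM90. A condensed, hypothetical sketch of that ordering; DeviceInfo, dispatch, and the returned labels are placeholders, not vLLM APIs.

from dataclasses import dataclass

@dataclass
class DeviceInfo:
    capability: int  # e.g. 100 for Blackwell, 90 for Hopper

def dispatch(device: DeviceInfo, use_mxfp4: bool, use_bf16: bool) -> str:
    if use_mxfp4 and device.capability == 100:
        return "trtllm_fp4_block_scale_moe"   # SM100-only FP4 kernel path
    if use_bf16 and device.capability == 90:
        return "flashinfer_bf16_moe"          # SM90 BF16 path (bf16 inputs)
    return "default_moe"

assert dispatch(DeviceInfo(100), True, False) == "trtllm_fp4_block_scale_moe"
assert dispatch(DeviceInfo(90), False, True) == "flashinfer_bf16_moe"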
