
Commit 6d0734c

kaixih and mgoin authored
[NVIDIA] Add SM100 Flashinfer MoE blockscale fp8 backend for low latency (#20645)
Signed-off-by: kaixih <[email protected]>
Signed-off-by: mgoin <[email protected]>
Co-authored-by: mgoin <[email protected]>
1 parent 7d94577 commit 6d0734c

File tree

6 files changed, +187 -31 lines changed

vllm/envs.py

Lines changed: 8 additions & 3 deletions
@@ -119,7 +119,8 @@
     VLLM_TPU_BUCKET_PADDING_GAP: int = 0
     VLLM_TPU_MOST_MODEL_LEN: Optional[int] = None
     VLLM_USE_DEEP_GEMM: bool = False
-    VLLM_USE_FLASHINFER_MOE: bool = False
+    VLLM_USE_FLASHINFER_MOE_FP8: bool = False
+    VLLM_USE_FLASHINFER_MOE_FP4: bool = False
    VLLM_XGRAMMAR_CACHE_MB: int = 0
    VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
    VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False

@@ -854,9 +855,13 @@ def get_vllm_port() -> Optional[int]:
    "VLLM_USE_DEEP_GEMM":
    lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "0"))),

+    # Allow use of FlashInfer MoE kernels for fused moe ops.
+    "VLLM_USE_FLASHINFER_MOE_FP8":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0"))),
+
    # Allow use of FlashInfer CUTLASS kernels for fused moe ops.
-    "VLLM_USE_FLASHINFER_MOE":
-    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE", "0"))),
+    "VLLM_USE_FLASHINFER_MOE_FP4":
+    lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP4", "0"))),

    # Control the cache sized used by the xgrammar compiler. The default
    # of 512 MB should be enough for roughly 1000 JSON schemas.
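For reference, a minimal sketch of how the renamed flag is consumed (the pattern mirrors the lambda registered above; only the env var name and the parsing idiom come from this diff, the local variable name is illustrative):

import os

# Unset or "0" disables the FP8 MoE path; "1" enables it (when FlashInfer is available).
use_flashinfer_moe_fp8 = bool(int(os.getenv("VLLM_USE_FLASHINFER_MOE_FP8", "0")))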

vllm/model_executor/layers/fused_moe/config.py

Lines changed: 1 addition & 1 deletion
@@ -191,7 +191,7 @@ def use_deepep_ll_kernels(self):

    @property
    def use_flashinfer_cutlass_kernels(self):
-        return (envs.VLLM_USE_FLASHINFER_MOE
+        return (envs.VLLM_USE_FLASHINFER_MOE_FP4
                and has_flashinfer_cutlass_fused_moe())

    @staticmethod

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 99 additions & 1 deletion
@@ -28,7 +28,7 @@
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import (
-    _resize_cache, moe_kernel_quantize_input)
+    _resize_cache, moe_kernel_quantize_input, per_token_group_quant_fp8)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    dequant_mxfp4)
from vllm.platforms import current_platform

@@ -1061,6 +1061,104 @@ def inplace_fused_experts_fake(
)


+def next_positive_power_of_2(x: int) -> int:
+    if x < 1:
+        return 1
+    return 1 << (x - 1).bit_length()
+
+
+def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_positive_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
+def flashinfer_fused_moe_blockscale_fp8(
+        routing_logits: torch.Tensor,
+        routing_bias: torch.Tensor,
+        x: torch.Tensor,
+        w13_weight: torch.Tensor,
+        w13_weight_scale_inv: torch.Tensor,
+        w2_weight: torch.Tensor,
+        w2_weight_scale_inv: torch.Tensor,
+        global_num_experts: int,
+        top_k: int,
+        num_expert_group: int,
+        topk_group: int,
+        intermediate_size: int,
+        expert_offset: int,
+        local_num_experts: int,
+        block_shape: list[int],
+        routed_scaling: float = 1.0) -> torch.Tensor:
+    from vllm.utils.flashinfer import flashinfer_trtllm_fp8_block_scale_moe
+    assert top_k <= global_num_experts
+    assert top_k <= 8
+    assert topk_group <= 4
+    assert global_num_experts > num_expert_group
+    assert global_num_experts % num_expert_group == 0
+    assert global_num_experts % 4 == 0
+    assert top_k < (topk_group * global_num_experts / num_expert_group)
+    assert block_shape == [128, 128]
+
+    a_q, a_sf = per_token_group_quant_fp8(x, block_shape[1])
+    # NOTE: scales of hidden states have to be transposed!
+    a_sf_t = a_sf.t().contiguous()
+    return flashinfer_trtllm_fp8_block_scale_moe(
+        routing_logits=routing_logits,
+        routing_bias=routing_bias,
+        hidden_states=a_q,
+        hidden_states_scale=a_sf_t,
+        gemm1_weights=w13_weight,
+        gemm1_weights_scale=w13_weight_scale_inv,
+        gemm2_weights=w2_weight,
+        gemm2_weights_scale=w2_weight_scale_inv,
+        num_experts=global_num_experts,
+        top_k=top_k,
+        n_group=num_expert_group,
+        topk_group=topk_group,
+        intermediate_size=intermediate_size,
+        local_expert_offset=expert_offset,
+        local_num_experts=local_num_experts,
+        routed_scaling_factor=routed_scaling,
+        tile_tokens_dim=_get_tile_tokens_dim(x.shape[0], top_k,
+                                             global_num_experts),
+        routing_method_type=2,  # DeepSeek-styled routing method
+    )
+
+
+def flashinfer_fused_moe_blockscale_fp8_fake(
+        routing_logits: torch.Tensor,
+        routing_bias: torch.Tensor,
+        x: torch.Tensor,
+        w13_weight: torch.Tensor,
+        w13_weight_scale_inv: torch.Tensor,
+        w2_weight: torch.Tensor,
+        w2_weight_scale_inv: torch.Tensor,
+        global_num_experts: int,
+        top_k: int,
+        num_expert_group: int,
+        topk_group: int,
+        intermediate_size: int,
+        expert_offset: int,
+        local_num_experts: int,
+        block_shape: list[int],
+        routed_scaling: float = 1.0) -> torch.Tensor:
+    return torch.empty_like(x)
+
+
+direct_register_custom_op(
+    op_name="flashinfer_fused_moe_blockscale_fp8",
+    op_func=flashinfer_fused_moe_blockscale_fp8,
+    mutates_args=[],
+    fake_impl=flashinfer_fused_moe_blockscale_fp8_fake,
+    tags=(torch.Tag.needs_fixed_stride_order, ),
+)
+
+
def outplace_fused_experts(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
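To illustrate the CTA tile sizing the new op uses, here is a small self-contained sketch of _get_tile_tokens_dim (the function bodies are copied from the hunk above; the example inputs are illustrative, not from the diff):

def next_positive_power_of_2(x: int) -> int:
    if x < 1:
        return 1
    return 1 << (x - 1).bit_length()

def _get_tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Average tokens per expert under a perfectly balanced routing assumption,
    # padded to the next power of two and clamped to the kernel's 8-64 range.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    return min(max(next_positive_power_of_2(num_tokens_per_expert), 8), 64)

# Low-latency decode: very few tokens, so the tile is clamped up to the minimum of 8.
assert _get_tile_tokens_dim(num_tokens=4, top_k=8, num_experts=256) == 8
# Large batch: 128 tokens per expert on average, clamped down to the maximum of 64.
assert _get_tile_tokens_dim(num_tokens=4096, top_k=8, num_experts=256) == 64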

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 63 additions & 19 deletions
@@ -43,6 +43,7 @@
from vllm.scalar_type import scalar_types
from vllm.utils import has_deep_gemm
from vllm.utils.deep_gemm import is_blackwell_deep_gemm_used
+from vllm.utils.flashinfer import has_flashinfer_moe

if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

@@ -52,6 +53,11 @@
logger = init_logger(__name__)


+def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
+    return x.reshape(-1, 2, x.shape[-2] // 2,
+                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)
+
+
def _is_col_major(x: torch.Tensor) -> bool:
    assert x.dim() == 3
    b, m, n = x.shape

@@ -473,6 +479,11 @@ def __init__(self, quant_config: Fp8Config):
        self.quant_config = quant_config
        self.block_quant = self.quant_config.weight_block_size is not None

+        self.flashinfer_moe_enabled = False
+        if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe():
+            logger.info_once(
+                "Using FlashInfer MoE FP8 kernels for Fp8MoEMethod.")
+            self.flashinfer_moe_enabled = True
        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (not current_platform.has_device_capability(89)

@@ -674,6 +685,14 @@ def process_weights_after_loading(self, layer: Module) -> None:
                    normalize_e4m3fn_to_e4m3fnuz(
                        layer.w2_weight, layer.w2_weight_scale_inv,
                        layer.w2_input_scale)
+            elif self.flashinfer_moe_enabled:
+                # NOTE: weights have to be swapped since the activation is
+                # applied on different half for flashinfer vs vllm
+                w13_weight = _swap_w13_to_w31(layer.w13_weight.data)
+                w13_weight_scale_inv = _swap_w13_to_w31(
+                    layer.w13_weight_scale_inv.data)
+                w2_weight = layer.w2_weight.data
+                w2_weight_scale_inv = layer.w2_weight_scale_inv.data
            else:
                w13_weight = layer.w13_weight.data
                w13_weight_scale_inv = layer.w13_weight_scale_inv.data

@@ -915,25 +934,25 @@ def apply(
            assert logical_to_physical_map is not None
            assert logical_replica_count is not None
            assert isinstance(layer, FusedMoE)
-
-        topk_weights, topk_ids = FusedMoE.select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-            indices_type=self.topk_indices_dtype,
-            enable_eplb=enable_eplb,
-            expert_map=expert_map,
-            expert_load_view=expert_load_view,
-            logical_to_physical_map=logical_to_physical_map,
-            logical_replica_count=logical_replica_count,
-        )
+        if not self.flashinfer_moe_enabled:
+            topk_weights, topk_ids = FusedMoE.select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                custom_routing_function=custom_routing_function,
+                scoring_func=scoring_func,
+                e_score_correction_bias=e_score_correction_bias,
+                indices_type=self.topk_indices_dtype,
+                enable_eplb=enable_eplb,
+                expert_map=expert_map,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa: E501

@@ -971,6 +990,31 @@ def apply(
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)
+        elif self.flashinfer_moe_enabled:
+            # Currently only work with DS models
+            assert self.block_quant
+            assert (renormalize and use_grouped_topk
+                    and scoring_func == 'sigmoid'
+                    and custom_routing_function is None)
+            assert activation == "silu"
+            return torch.ops.vllm.flashinfer_fused_moe_blockscale_fp8(
+                routing_logits=router_logits.to(torch.float32),
+                routing_bias=e_score_correction_bias,
+                x=x,
+                w13_weight=layer.w13_weight,
+                w13_weight_scale_inv=layer.w13_weight_scale_inv,
+                w2_weight=layer.w2_weight,
+                w2_weight_scale_inv=layer.w2_weight_scale_inv,
+                global_num_experts=global_num_experts,
+                top_k=top_k,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+                intermediate_size=layer.intermediate_size_per_partition,
+                expert_offset=layer.ep_rank * layer.local_num_experts,
+                local_num_experts=layer.local_num_experts,
+                block_shape=self.quant_config.weight_block_size,
+                routed_scaling=1.0,
+            )
        else:
            return self.fused_experts(
                hidden_states=x,
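The w13-to-w31 swap above can be worked through with a tiny example. Only the helper itself comes from the diff; the shapes below (one expert, gate/up halves of 3 rows each, hidden size 4) are illustrative stand-ins:

import torch

def _swap_w13_to_w31(x: torch.Tensor) -> torch.Tensor:
    # Split the second-to-last dim into two halves and flip their order,
    # turning a [w1; w3] layout into the [w3; w1] layout FlashInfer expects.
    return x.reshape(-1, 2, x.shape[-2] // 2,
                     x.shape[-1]).flip(dims=[1]).reshape(x.shape)

w1 = torch.zeros(1, 3, 4)   # stand-in for the gate-projection rows
w3 = torch.ones(1, 3, 4)    # stand-in for the up-projection rows
w13 = torch.cat([w1, w3], dim=1)

w31 = _swap_w13_to_w31(w13)
assert torch.equal(w31[:, :3], w3)  # halves are swapped in place
assert torch.equal(w31[:, 3:], w1)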

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 4 additions & 5 deletions
@@ -721,7 +721,7 @@ def __init__(self, quant_config: ModelOptNvFp4Config):
        self.use_marlin = False
        self.allow_flashinfer_cutlass = False

-        if envs.VLLM_USE_FLASHINFER_MOE:
+        if envs.VLLM_USE_FLASHINFER_MOE_FP4:
            if self.cutlass_nvfp4_supported and current_platform.is_cuda() \
                and current_platform.is_device_capability(100):
                logger.info_once(

@@ -800,10 +800,9 @@ def select_gemm_impl(self, prepare_finalize,
            assert moe.dp_size > 1
            logger.debug_once("Using CutlassExpertsFp4")
            # Currently CutlassExpertsFp4 doesn't support DP
-            raise ValueError(
-                "CutlassExpertsFp4 doesn't support DP. "
-                "Use flashinfer CUTLASS FusedMoE(VLLM_USE_FLASHINFER_MOE)"
-                " backend instead.")
+            raise ValueError("CutlassExpertsFp4 doesn't support DP. "
+                             "Use flashinfer CUTLASS FusedMoE backend instead "
+                             "(set VLLM_USE_FLASHINFER_MOE_FP4=1)")

        return experts

vllm/utils/flashinfer.py

Lines changed: 12 additions & 2 deletions
@@ -64,6 +64,8 @@ def wrapper(*args, **kwargs):


# Create lazy wrappers for each function
+flashinfer_trtllm_fp8_block_scale_moe = _lazy_import_wrapper(
+    "flashinfer.fused_moe", "trtllm_fp8_block_scale_moe")
flashinfer_cutlass_fused_moe = _lazy_import_wrapper("flashinfer.fused_moe",
                                                    "cutlass_fused_moe")
fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize")

@@ -77,10 +79,16 @@ def wrapper(*args, **kwargs):
    fallback_fn=lambda *args, **kwargs: contextlib.nullcontext())


+@functools.cache
+def has_flashinfer_moe() -> bool:
+    """Return ``True`` if FlashInfer MoE module is available."""
+    return importlib.util.find_spec("flashinfer.fused_moe") is not None
+
+
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
    """Return ``True`` if FlashInfer CUTLASS fused MoE is available."""
-    if not has_flashinfer():
+    if not has_flashinfer_moe():
        return False

    # Check if all required functions are available

@@ -99,9 +107,11 @@ def has_flashinfer_cutlass_fused_moe() -> bool:

__all__ = [
    "has_flashinfer",
-    "has_flashinfer_cutlass_fused_moe",
+    "flashinfer_trtllm_fp8_block_scale_moe",
    "flashinfer_cutlass_fused_moe",
    "fp4_quantize",
    "fp4_swizzle_blockscale",
    "autotune",
+    "has_flashinfer_moe",
+    "has_flashinfer_cutlass_fused_moe",
]
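As a usage sketch, the new availability helper composes with the FP8 env flag the same way the fp8.py change above does (the local variable name here is illustrative; the condition itself mirrors the diff):

import vllm.envs as envs
from vllm.utils.flashinfer import has_flashinfer_moe

# Enable the FP8 blockscale path only when the flag is set and
# flashinfer.fused_moe is importable.
flashinfer_moe_enabled = bool(envs.VLLM_USE_FLASHINFER_MOE_FP8
                              and has_flashinfer_moe())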

0 commit comments
