
Commit 67c153b

Fix Llama4 FlashInfer FP4 MoE issues (#22511)
Signed-off-by: Po-Han Huang <[email protected]>
1 parent f7ad6a1 commit 67c153b

3 files changed: 9 additions, 5 deletions

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py

Lines changed: 0 additions & 2 deletions
@@ -170,8 +170,6 @@ def apply(
             "w1_scale and w2_scale must not "
             "be None for FlashInferExperts")
 
-        assert not apply_router_weight_on_input
-
         quant_scales = [
             a1_gscale,
             w1_scale.view(torch.int32),

vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py

Lines changed: 6 additions & 1 deletion
@@ -60,7 +60,12 @@ def prepare(
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor],
                Optional[torch.Tensor], Optional[torch.Tensor]]:
 
-        assert not apply_router_weight_on_input
+        if apply_router_weight_on_input:
+            topk = topk_ids.size(1)
+            # TODO: this only works for topK=1, will need to update for topK>1
+            assert topk == 1, \
+                "apply_router_weight_on_input is only implemented for topk=1"
+            a1.mul_(topk_weights.to(a1.dtype))
 
         (a1_gscale, use_dp, local_tokens) = extract_required_args(
             extra_prepare_args, ['a1_gscale', 'use_dp', 'local_tokens'])
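
A note on the change above: when apply_router_weight_on_input is set, the routing weight is now folded into the activations during prepare(), which is only valid for top-k = 1 because each token then has exactly one routing weight. Below is a minimal sketch of that in-place scaling, with illustrative tensor shapes and a hypothetical helper name (not a vLLM API):

    import torch

    def fold_router_weight_topk1(a1: torch.Tensor,
                                 topk_weights: torch.Tensor,
                                 topk_ids: torch.Tensor) -> None:
        # Hypothetical helper mirroring the new branch in prepare():
        # only valid when each token is routed to a single expert.
        topk = topk_ids.size(1)
        assert topk == 1, \
            "apply_router_weight_on_input is only implemented for topk=1"
        # Scaling the input once per token is equivalent to weighting that
        # single expert's output at finalize time.
        a1.mul_(topk_weights.to(a1.dtype))

    # Illustrative shapes: 4 tokens, hidden size 8, top-k = 1.
    a1 = torch.randn(4, 8, dtype=torch.bfloat16)
    topk_weights = torch.rand(4, 1, dtype=torch.float32)
    topk_ids = torch.zeros(4, 1, dtype=torch.int64)
    fold_router_weight_topk1(a1, topk_weights, topk_ids)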

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 3 additions & 2 deletions
@@ -1299,8 +1299,9 @@ def apply(
             output2_scale_scalar=layer.g2_alphas.data,
             num_experts=global_num_experts,
             top_k=top_k,
-            n_group=num_expert_group,
-            topk_group=topk_group,
+            n_group=num_expert_group
+            if num_expert_group is not None else 0,
+            topk_group=topk_group if topk_group is not None else 0,
             intermediate_size=layer.intermediate_size_per_partition,
             local_expert_offset=layer.ep_rank * layer.local_num_experts,
             local_num_experts=layer.local_num_experts,
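
A note on the change above: Llama4 does not use grouped top-k routing, so num_expert_group and topk_group arrive as None, and the diff suggests the FlashInfer FP4 MoE kernel expects plain integers here, with 0 standing for "no expert grouping". Below is a minimal sketch of that coalescing, with a hypothetical helper name (not a vLLM API):

    from typing import Optional

    def coalesce_group_args(num_expert_group: Optional[int],
                            topk_group: Optional[int]) -> tuple[int, int]:
        # Map the "no grouping" case (None, as passed by Llama4) to 0 so the
        # kernel call always receives integer arguments.
        n_group = num_expert_group if num_expert_group is not None else 0
        group_topk = topk_group if topk_group is not None else 0
        return n_group, group_topk

    assert coalesce_group_args(None, None) == (0, 0)  # Llama4-style, no groups
    assert coalesce_group_args(8, 4) == (8, 4)        # grouped-top-k passes through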
