Commit ea6ae8c

[Bugfix] Fix marlin moe fallback logic for llama4 (#18042)
Signed-off-by: mgoin <[email protected]>
1 parent 2ff297d commit ea6ae8c

3 files changed: +12 -5 lines changed

tests/weight_loading/models-large.txt

Lines changed: 2 additions & 1 deletion

@@ -4,4 +4,5 @@ compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
 compressed-tensors, nm-testing/test-w4a16-mixtral-actorder-group, main
 gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
 gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, gptq-8bit-128g-actorder_True
-awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
+awq_marlin, casperhansen/deepseek-coder-v2-instruct-awq, main
+compressed-tensors, RedHatAI/Llama-4-Scout-17B-16E-Instruct-quantized.w4a16, main
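Note: each entry in this test list appears to be a comma-separated triple of quantization method, model name, and revision. A minimal parsing sketch (the field names and NamedTuple are illustrative assumptions, not vLLM's actual test harness):

from typing import NamedTuple

class WeightLoadingCase(NamedTuple):
    quant_method: str
    model: str
    revision: str

def parse_case(line: str) -> WeightLoadingCase:
    # Split on commas and trim whitespace around each field.
    quant_method, model, revision = (part.strip() for part in line.split(","))
    return WeightLoadingCase(quant_method, model, revision)

case = parse_case(
    "compressed-tensors, RedHatAI/Llama-4-Scout-17B-16E-Instruct-quantized.w4a16, main")
print(case.quant_method, case.model, case.revision)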

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 1 addition & 1 deletion

@@ -480,6 +480,7 @@ def __init__(
         self.custom_routing_function = custom_routing_function
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
+        self.apply_router_weight_on_input = apply_router_weight_on_input
         self.activation = activation

         if self.scoring_func != "softmax" and not self.use_grouped_topk:
@@ -498,7 +499,6 @@ def __init__(
         self.quant_method = quant_config.get_quant_method(self, prefix)
         assert self.quant_method is not None

-        self.apply_router_weight_on_input = apply_router_weight_on_input
         moe_quant_params = {
             "num_experts": self.local_num_experts,
             "hidden_size": hidden_size,

vllm/model_executor/layers/quantization/utils/marlin_utils.py

Lines changed: 9 additions & 3 deletions

@@ -171,13 +171,19 @@ def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) \
         -> bool:
     hidden_size = layer.hidden_size
     intermediate_size_per_partition = layer.intermediate_size_per_partition
+    # apply_router_weight_on_input is not supported for moe marlin
+    supports_router_weight = not layer.apply_router_weight_on_input
+    # moe marlin requires the activation to be silu
+    supports_activation = layer.activation == "silu"

     # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
     # down: (n, k) = (hidden_size, intermediate_size_per_partition)
     # moe marlin requires n % 128 == 0 and k % 64 == 0
-    return hidden_size % 128 == 0 and \
-        intermediate_size_per_partition % max(64, group_size) == 0 and \
-        group_size in [-1, 32, 64, 128]
+    supports_shape = hidden_size % 128 == 0 and \
+        intermediate_size_per_partition % max(64, group_size) == 0
+    supports_group_size = group_size in [-1, 32, 64, 128]
+    return supports_shape and supports_group_size and \
+        supports_router_weight and supports_activation


 def marlin_make_workspace(output_size_per_partition: int,
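For reference, the updated predicate can be exercised standalone. The sketch below inlines the same checks over plain values instead of a layer object (the shapes are illustrative, not taken from a specific checkpoint):

def moe_marlin_supported(hidden_size: int,
                         intermediate_size_per_partition: int,
                         group_size: int,
                         apply_router_weight_on_input: bool,
                         activation: str) -> bool:
    # Same four conditions as check_moe_marlin_supports_layer above.
    supports_router_weight = not apply_router_weight_on_input
    supports_activation = activation == "silu"
    supports_shape = (hidden_size % 128 == 0 and
                      intermediate_size_per_partition % max(64, group_size) == 0)
    supports_group_size = group_size in [-1, 32, 64, 128]
    return (supports_shape and supports_group_size and
            supports_router_weight and supports_activation)

# A llama4-style MoE layer applies the router weight on the input, so the
# marlin path is now correctly rejected and the layer falls back:
print(moe_marlin_supported(5120, 8192, 128, True, "silu"))   # False -> fallback
print(moe_marlin_supported(5120, 8192, 128, False, "silu"))  # True  -> marlin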
