Skip to content

Commit 26b5ec8

Browse files
benchislett authored and southfreebird committed
[Bugfix] Allow skipping MoE in NVFP4 (fix for MTP) (vllm-project#25987)
Signed-off-by: Benjamin Chislett <[email protected]>
1 parent 43548ce commit 26b5ec8

File tree

5 files changed

+18
-5
lines changed

5 files changed

+18
-5
lines changed

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,6 +1194,8 @@ def __init__(
11941194
if quant_config is None
11951195
else quant_config.get_quant_method(self, prefix)
11961196
)
1197+
if quant_method is None:
1198+
quant_method = UnquantizedFusedMoEMethod(moe)
11971199

11981200
assert quant_method is not None
11991201
assert isinstance(quant_method, FusedMoEMethodBase)

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -884,8 +884,9 @@ def get_quant_method(
884884
) -> Optional["QuantizeMethodBase"]:
885885
from vllm.attention.layer import Attention # Avoid circular import
886886

887+
skip_layer = self.is_layer_excluded(prefix)
887888
if isinstance(layer, LinearBase):
888-
if self.is_layer_excluded(prefix):
889+
if skip_layer:
889890
return UnquantizedLinearMethod()
890891
# Check if this is a vision model layer that should not be quantized
891892
if "vision_tower" in prefix or "vision_model" in prefix:
@@ -894,6 +895,8 @@ def get_quant_method(
894895
elif isinstance(layer, Attention):
895896
return ModelOptFp8KVCacheMethod(self)
896897
elif isinstance(layer, FusedMoE):
898+
if skip_layer:
899+
return None
897900
return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
898901
return None
899902

vllm/model_executor/models/deepseek_eagle.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(
5555
DeepseekV2DecoderLayer(
5656
vllm_config,
5757
prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
58+
config=self.config,
5859
)
5960
for i in range(self.config.num_hidden_layers)
6061
]

vllm/model_executor/models/deepseek_mtp.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
4848
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
4949
super().__init__()
5050

51-
config = vllm_config.model_config.hf_config
51+
config = vllm_config.speculative_config.draft_model_config.hf_config
52+
self.config = config
5253
quant_config = vllm_config.quant_config
5354

5455
self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -66,11 +67,15 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
6667
)
6768
else:
6869
topk_indices_buffer = None
70+
6971
self.shared_head = SharedHead(
7072
config=config, prefix=prefix, quant_config=quant_config
7173
)
7274
self.mtp_block = DeepseekV2DecoderLayer(
73-
vllm_config, prefix, topk_indices_buffer
75+
vllm_config,
76+
prefix,
77+
config=self.config,
78+
topk_indices_buffer=topk_indices_buffer,
7479
)
7580

7681
def forward(

vllm/model_executor/models/deepseek_v2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,11 +1055,13 @@ def __init__(
10551055
self,
10561056
vllm_config: VllmConfig,
10571057
prefix: str,
1058+
config: Optional[DeepseekV2Config] = None,
10581059
topk_indices_buffer: Optional[torch.Tensor] = None,
10591060
) -> None:
10601061
super().__init__()
10611062

1062-
config = vllm_config.model_config.hf_config
1063+
if config is None:
1064+
config = vllm_config.model_config.hf_config
10631065
model_config = vllm_config.model_config
10641066
cache_config = vllm_config.cache_config
10651067
quant_config = vllm_config.quant_config
@@ -1200,7 +1202,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
12001202
self.start_layer, self.end_layer, self.layers = make_layers(
12011203
config.num_hidden_layers,
12021204
lambda prefix: DeepseekV2DecoderLayer(
1203-
vllm_config, prefix, topk_indices_buffer
1205+
vllm_config, prefix, topk_indices_buffer=topk_indices_buffer
12041206
),
12051207
prefix=f"{prefix}.layers",
12061208
)

0 commit comments

Comments (0)