
Commit 001e50c

[Model] MTP fallback to eager for DeepSeek v32 (#25982)
Signed-off-by: Lu Fang <[email protected]>
1 parent 96ebcaa commit 001e50c

5 files changed: 32 additions, 5 deletions

tests/v1/spec_decode/test_eagle.py

Lines changed: 10 additions & 1 deletion
@@ -337,13 +337,19 @@ def test_load_model(mock_get_model, mock_get_layers, mock_get_pp_group, method,
         "target_attn_1": mock.MagicMock(),
         "target_attn_2": mock.MagicMock()
     }
+    target_indx_layers: dict[str, mock.MagicMock] = {}
     # Draft model has one extra attention layer compared to target model
     all_attn_layers = {
         **target_attn_layers, "draft_extra_attn": mock.MagicMock()
     }
 
+    all_indx_layers: dict[str, mock.MagicMock] = {}
+
     # Make mock_get_layers return different values for each call
-    mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
+    mock_get_layers.side_effect = [
+        target_attn_layers, target_indx_layers, all_attn_layers,
+        all_indx_layers
+    ]
 
     # Setup mock for pp group to return the appropriate value for world size
     mock_pp_group = mock.MagicMock()

@@ -658,6 +664,9 @@ def create_deterministic_logits(token_ids, k: int):
     # Mock runner for attention metadata building.
     proposer.runner = mock.MagicMock()
     proposer.runner.attn_groups.append([mock.MagicMock()])
+    proposer.runner.attn_groups[0][0].metadata_builders = [
+        attn_metadata_builder
+    ]
     proposer.runner.attn_groups[0][0].get_metadata_builder.return_value = \
         attn_metadata_builder
     proposer._get_attention_metadata_builder = mock.MagicMock(
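
The test now expects get_layers to be invoked four times: attention layers,
then DeepSeek v3.2 indexer layers, first for the target model and then for
the draft. A minimal sketch of how a list-valued side_effect serves those
calls in order (the dict values below are placeholders, not real layer
objects):

    from unittest import mock

    get_layers = mock.MagicMock()
    get_layers.side_effect = [
        {"target_attn_1": "a1"},        # 1st call: target attention layers
        {},                             # 2nd call: target indexer layers (none)
        {"target_attn_1": "a1",
         "draft_extra_attn": "a2"},     # 3rd call: target + draft attention layers
        {},                             # 4th call: draft indexer layers (none)
    ]

    assert get_layers() == {"target_attn_1": "a1"}
    assert get_layers() == {}
    assert "draft_extra_attn" in get_layers()
    assert get_layers() == {}  # a fifth call would raise StopIteration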

tests/v1/spec_decode/test_mtp.py

Lines changed: 7 additions & 1 deletion
@@ -63,7 +63,13 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers,
 
     target_attn_layers = {"target_attn_1": mock.MagicMock()}
     all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}
-    mock_get_layers.side_effect = [target_attn_layers, all_attn_layers]
+    target_indexer_layers: dict = {}
+    all_indexer_layers: dict = {}
+
+    mock_get_layers.side_effect = [
+        target_attn_layers, target_indexer_layers, all_attn_layers,
+        all_indexer_layers
+    ]
 
     mock_pp_group = mock.MagicMock()
     mock_pp_group.world_size = 1

vllm/config/speculative.py

Lines changed: 7 additions & 1 deletion
@@ -41,7 +41,8 @@
 @dataclass
 class SpeculativeConfig:
     """Configuration for speculative decoding."""
-
+    enforce_eager: Optional[bool] = None
+    """Override the default enforce_eager from model_config"""
     # General speculative decoding control
     num_speculative_tokens: SkipValidation[int] = None  # type: ignore
     """The number of speculative tokens, if provided. It will default to the

@@ -219,6 +220,11 @@ def __post_init__(self):
             assert (
                 self.target_model_config
                 is not None), "target_model_config must be present for mtp"
+            if self.target_model_config.hf_text_config.model_type \
+                    == "deepseek_v32":
+                # FIXME(luccafong): cudagraph with v32 MTP is not supported;
+                # remove this when the issue is fixed.
+                self.enforce_eager = True
             # use the draft model from the same model:
             self.model = self.target_model_config.model
             # Align the quantization of draft model for cases such as
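
The net effect: when the target model's hf_text_config reports model_type
"deepseek_v32" and MTP is in use, the speculative config forces eager mode
no matter what was requested. A self-contained sketch of that fallback rule
(TinySpecConfig and its fields are illustrative stand-ins, not the real
vLLM class):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class TinySpecConfig:
        target_model_type: str
        enforce_eager: Optional[bool] = None  # None means "inherit the default"

        def __post_init__(self) -> None:
            # cudagraph with DeepSeek v3.2 MTP is not yet supported, so force
            # eager execution for that target regardless of other settings.
            if self.target_model_type == "deepseek_v32":
                self.enforce_eager = True

    assert TinySpecConfig("deepseek_v32").enforce_eager is True
    assert TinySpecConfig("llama").enforce_eager is None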

vllm/v1/attention/backends/mla/indexer.py

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
 
 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
     cudagraph_support: ClassVar[AttentionCGSupport] = \
-        AttentionCGSupport.UNIFORM_BATCH
+        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
 
     reorder_batch_threshold: int = 1
 
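The builder advertises its cudagraph capability through a class-level enum
flag; downgrading it from UNIFORM_BATCH to UNIFORM_SINGLE_TOKEN_DECODE tells
the runner that full-graph capture is only safe for uniform single-token
decode batches. A hedged sketch of the pattern, assuming a simplified enum
ordering (the real AttentionCGSupport has more members and richer semantics):

    import enum
    from typing import ClassVar

    class CGSupport(enum.IntEnum):
        NEVER = 0
        UNIFORM_SINGLE_TOKEN_DECODE = 1
        UNIFORM_BATCH = 2

    class IndexerBuilder:
        # Class-level flag: callers can inspect the capability without
        # instantiating the builder.
        cudagraph_support: ClassVar[CGSupport] = \
            CGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    def can_capture(builder_cls: type, uniform_single_token_decode: bool) -> bool:
        if builder_cls.cudagraph_support >= CGSupport.UNIFORM_BATCH:
            return True
        return (builder_cls.cudagraph_support
                >= CGSupport.UNIFORM_SINGLE_TOKEN_DECODE
                and uniform_single_token_decode)

    assert can_capture(IndexerBuilder, uniform_single_token_decode=True)
    assert not can_capture(IndexerBuilder, uniform_single_token_decode=False)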
vllm/v1/spec_decode/eagle.py

Lines changed: 7 additions & 1 deletion
@@ -50,6 +50,7 @@ def __init__(
     ):
         self.vllm_config = vllm_config
         self.speculative_config = vllm_config.speculative_config
+        assert self.speculative_config is not None
         self.draft_model_config = self.speculative_config.draft_model_config
         self.method = self.speculative_config.method
 
@@ -74,11 +75,16 @@ def __init__(
             vllm_config.model_config)
 
         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
+        self.draft_indexer_metadata_builder: Optional[
+            AttentionMetadataBuilder] = None
+        self.attn_layer_names: list[str] = []
+        self.indexer_layer_names: list[str] = []
 
         self.use_cuda_graph = (not current_platform.is_xpu()
                                and self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
-                               not self.vllm_config.model_config.enforce_eager)
+                               not self.vllm_config.model_config.enforce_eager
+                               and not self.speculative_config.enforce_eager)
         self.cudagraph_batch_sizes = list(
             reversed(self.vllm_config.compilation_config.
                      cudagraph_capture_sizes)) if self.use_cuda_graph else []
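
use_cuda_graph now also consults the speculative config's new enforce_eager
flag, so the DeepSeek v3.2 fallback set in SpeculativeConfig.__post_init__
disables capture even when the model config alone would allow it. A hedged
sketch of the combined gate (plain booleans stand in for the real config
objects):

    from typing import Optional

    def use_cuda_graph(is_xpu: bool, piecewise_compile: bool,
                       model_enforce_eager: bool,
                       spec_enforce_eager: Optional[bool]) -> bool:
        # Capture only when no veto applies; spec_enforce_eager=None (the
        # default for the new Optional[bool] field) does not veto.
        return (not is_xpu and piecewise_compile
                and not model_enforce_eager
                and not spec_enforce_eager)

    # DeepSeek v3.2 MTP: __post_init__ forced spec_enforce_eager=True, so
    # capture is skipped despite an otherwise cudagraph-friendly setup.
    assert use_cuda_graph(False, True, False, True) is False
    assert use_cuda_graph(False, True, False, None) is True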
