
Commit 3629bc4

ZhaoJiangJiang and 赵江江 authored
feat: add mtp ut and fix some bugs (#2453)
### What this PR does / why we need it?
Fix the MTP mode unit tests.

### Does this PR introduce _any_ user-facing change?
Nothing.

### How was this patch tested?
This can be tested in the same way as a unit test.

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@5341565

Signed-off-by: 赵江江 <[email protected]>
Co-authored-by: 赵江江 <[email protected]>
1 parent dd04a96 commit 3629bc4

File tree

10 files changed, +129 -75 lines changed


tests/e2e/singlecard/spec_decode_v1/test_v1_mtp_correctness.py

Lines changed: 39 additions & 52 deletions
@@ -1,43 +1,13 @@
 from __future__ import annotations
 
-import random
-from typing import Any
+import os
 
 import pytest
-from vllm import LLM, SamplingParams
+from vllm import SamplingParams
 
+from tests.e2e.conftest import VllmRunner
 
-@pytest.fixture
-def test_prompts():
-    prompt_types = ["repeat", "sentence"]
-    num_prompts = 10
-    prompts = []
-
-    random.seed(0)
-    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
-
-    # Generate a mixed batch of prompts, some of which can be easily
-    # predicted by n-gram matching and some which likely cannot.
-    for kind in random_prompt_type_choices:
-        word_choices = ["test", "temp", "hello", "where"]
-        word = random.choice(word_choices)
-        if kind == "repeat":
-            prompt = f"""
-            please repeat the word '{word}' 10 times.
-            give no other output than the word at least ten times in a row,
-            in lowercase with spaces between each word and without quotes.
-            """
-        elif kind == "sentence":
-            prompt = f"""
-            please give a ten-word sentence that
-            uses the word {word} at least once.
-            give no other output than that simple sentence without quotes.
-            """
-        else:
-            raise ValueError(f"Unknown prompt type: {kind}")
-        prompts.append([{"role": "user", "content": prompt}])
-
-    return prompts
+os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
 
 
 @pytest.fixture
@@ -50,39 +20,56 @@ def model_name():
     return "wemaster/deepseek_mtp_main_random_bf16"
 
 
-@pytest.mark.skipif(
-    True, reason="TODO: Enable me after test_mtp_correctness is fixed")
 def test_mtp_correctness(
-    test_prompts: list[list[dict[str, Any]]],
     sampling_config: SamplingParams,
     model_name: str,
 ):
+    example_prompts = [
+        "Hello, my name is",
+        "The president of the United States is",
+        "The capital of France is",
+        "The future of AI is",
+    ]
     '''
     Compare the outputs of a original LLM and a speculative LLM
     should be the same when using mtp speculative decoding.
     '''
-    ref_llm = LLM(model=model_name, max_model_len=256, enforce_eager=True)
-    ref_outputs = ref_llm.chat(test_prompts, sampling_config)
-    del ref_llm
+    with VllmRunner(model_name,
+                    tensor_parallel_size=1,
+                    gpu_memory_utilization=0.7,
+                    max_model_len=256,
+                    enforce_eager=True) as ref_llm:
+        ref_outputs = ref_llm.generate(example_prompts, sampling_config)
+
+    with VllmRunner(
+            model_name,
+            tensor_parallel_size=1,
+            max_num_seqs=256,
+            gpu_memory_utilization=0.7,
+            distributed_executor_backend="mp",
+            enable_expert_parallel=True,
+            speculative_config={
+                "method": "deepseek_mtp",
+                "num_speculative_tokens": 1,
+            },
+            enforce_eager=True,
+            max_model_len=2000,
+            additional_config={"ascend_scheduler_config": {
+                "enabled": False
+            }}) as spec_llm:
+        spec_outputs = spec_llm.generate(example_prompts, sampling_config)
 
-    spec_llm = LLM(model=model_name,
-                   trust_remote_code=True,
-                   speculative_config={
-                       "method": "deepseek_mtp",
-                       "num_speculative_tokens": 1,
-                   },
-                   max_model_len=256,
-                   enforce_eager=True)
-    spec_outputs = spec_llm.chat(test_prompts, sampling_config)
     matches = 0
     misses = 0
     for ref_output, spec_output in zip(ref_outputs, spec_outputs):
-        if ref_output.outputs[0].text == spec_output.outputs[0].text:
+        ref_token_ids = ref_output[0][0]
+        spec_token_ids = spec_output[0][0]
+        if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
            matches += 1
        else:
            misses += 1
-            print(f"ref_output: {ref_output.outputs[0].text}")
-            print(f"spec_output: {spec_output.outputs[0].text}")
+            print(f"ref_output: {ref_output[1][0]}")
+            print(f"spec_output: {spec_output[1][0]}")
 
     # Heuristic: expect at least 66% of the prompts to match exactly
     # Upon failure, inspect the outputs to check for inaccuracy.
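Note on the comparison logic above: the rewritten test checks token IDs rather than decoded text, and counts a prompt as matching when the reference tokens form a prefix of the speculative tokens (in the diff, `ref_output[0][0]` holds token IDs and `ref_output[1][0]` the decoded text). A minimal sketch of that heuristic on made-up data:

# Toy illustration of the prefix-match heuristic (hypothetical token IDs).
ref_outputs = [[11, 12, 13], [21, 22], [31, 32, 33]]
spec_outputs = [[11, 12, 13, 14], [21, 99], [31, 32, 33]]

matches = 0
misses = 0
for ref_token_ids, spec_token_ids in zip(ref_outputs, spec_outputs):
    if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
        matches += 1
    else:
        misses += 1

# Illustrative threshold check in the spirit of the test's 66% heuristic.
assert matches > int(0.66 * len(ref_outputs))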

tests/ut/quantization/test_quant_config.py

Lines changed: 1 addition & 0 deletions
@@ -113,6 +113,7 @@ def test_get_quant_method_for_attention(self):
     def test_get_quant_method_for_fused_moe(self):
         fused_moe_layer = MagicMock(spec=FusedMoE)
         fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
+        fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)
 
         # Test skipped layer
         with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \

tests/ut/torchair/test_torchair_mla.py

Lines changed: 64 additions & 0 deletions
@@ -1,11 +1,13 @@
 from unittest.mock import MagicMock, patch
 
 import torch
+from torch import nn
 from vllm.distributed.parallel_state import GroupCoordinator
 from vllm.model_executor.layers.linear import LinearBase
 
 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.torchair.torchair_mla import (
     AscendMLATorchairBackend, AscendMLATorchairDecodeMetadata,
     AscendMLATorchairImpl, AscendMLATorchairMetadata,
@@ -398,6 +400,68 @@ def test_build_dummy(self, mock_ascend_config):
         assert torch.equal(sin_golden, metadata.decode.sin)
         assert torch.equal(cos_golden, metadata.decode.cos)
 
+    @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
+    def test_build_decode(self, mock_ascend_config):
+        ascend_config = MagicMock()
+        mock_ascend_config.return_value = ascend_config
+        ascend_config.torchair_graph_config.enabled = False
+
+        mock_vllm_config = MagicMock()
+        mock_vllm_config.model_config.max_model_len = 1024
+        mock_vllm_config.cache_config.block_size = 16
+        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
+        mock_vllm_config.get_head_size.return_value = 64
+        mock_vllm_config.model_config.dtype = torch.float16
+        mock_device = 'cpu'
+        model = MagicMock(spec=nn.Module)
+        model.model = MagicMock(spec=nn.Module)
+
+        builder = AscendMLATorchairMetadataBuilder(
+            mock_vllm_config,
+            mock_device,
+            metadata_cls=AscendMLATorchairMetadata)
+        builder.rope_dim = 64
+
+        builder.sin_cache = torch.tensor([10, 10])
+        builder.cos_cache = torch.tensor([10, 10])
+
+        with patch.object(builder,
+                          "_get_graph_runner_block_tables",
+                          side_effect=lambda x, y: y):
+            common_attn_metadata = AscendCommonAttentionMetadata(
+                query_start_loc=torch.tensor([0, 1, 2, 3]),
+                query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
+                seq_lens_cpu=torch.tensor([1, 1, 1]),
+                num_reqs=3,
+                num_actual_tokens=3,
+                max_query_len=1,
+                decode_token_per_req=torch.tensor([1, 1, 1]),
+                block_table_tensor=torch.zeros((10, 10)),
+                slot_mapping_cpu=torch.tensor(range(20)),
+                actual_seq_lengths_q=torch.tensor([0, 1, 2]),
+                positions=torch.tensor([1, 1]),
+                attn_mask=torch.ones((15, 15)),
+                spec_attn_mask=None,
+                attn_state=AscendAttentionState.ChunkedPrefill)
+
+            metadata = builder.build(common_attn_metadata, model)
+
+        self.assertIsInstance(metadata, AscendMLATorchairMetadata)
+        self.assertEqual(metadata.num_input_tokens, 0)
+        self.assertEqual(metadata.num_actual_tokens, 3)
+        self.assertEqual(metadata.num_decodes, 3)
+        self.assertEqual(metadata.num_decode_tokens, 3)
+        self.assertEqual(metadata.num_prefills, 0)
+        self.assertEqual(metadata.attn_state,
+                         AscendAttentionState.ChunkedPrefill)
+        self.assertIsNone(metadata.prefill)
+        self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata)
+        self.assertEqual(metadata.block_tables.shape[0], 3)
+        self.assertEqual(metadata.block_tables.shape[1], 10)
+        self.assertEqual(metadata.seq_lens.shape[0], 3)
+        self.assertEqual(metadata.slot_mapping.shape[0], 3)
+        self.assertEqual(metadata.query_start_loc.shape[0], 4)
+
 
 class TestAscendMLATorchairImpl(TestBase):

vllm_ascend/attention/mla_v1.py

Lines changed: 1 addition & 7 deletions
@@ -374,18 +374,12 @@ def build(
 
         decode_metadata = None
         if num_decodes > 0:
-            actual_seq_lengths_q = query_start_loc[1:].tolist()
+            actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decode_tokens]
             input_positions = input_positions[:num_decode_tokens]
             block_table = block_table[:num_decode_tokens, ...]
             seq_lens_list = seq_lens.tolist()
-            # TODO(xyx): whether this block is necessary without torchair
-            # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
-            batch_size = slot_mapping.size(0)
-            if actual_seq_lengths_q[-1] != batch_size \
-                    and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
-                actual_seq_lengths_q[-1] = batch_size
 
             cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
                 1).unsqueeze(2)
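The narrowed slice above matters for mixed batches: `query_start_loc` spans every request in the batch, so `query_start_loc[1:]` also picks up prefill boundaries, while the decode metadata only needs the first `num_decodes` entries. A small sketch with made-up numbers:

import torch

# Hypothetical mixed batch: 3 decode requests (1 token each) followed by
# one prefill request of 5 tokens.
query_start_loc = torch.tensor([0, 1, 2, 3, 8])
num_decodes = 3

old_actual_seq_lengths_q = query_start_loc[1:].tolist()                 # [1, 2, 3, 8]
new_actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()  # [1, 2, 3]

# Only the decode requests remain in the decode-side metadata.
assert new_actual_seq_lengths_q == [1, 2, 3]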

vllm_ascend/models/deepseek_mtp.py

Lines changed: 1 addition & 1 deletion
@@ -215,4 +215,4 @@ def forward(
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    attn_metadata, previous_hidden_states,
                                    inputs_embeds, spec_step_idx)
-        return hidden_states
+        return hidden_states

(The removed and added lines are identical; this hunk most likely only adds the missing newline at the end of the file.)

vllm_ascend/ops/fused_moe.py

Lines changed: 4 additions & 2 deletions
@@ -1178,7 +1178,7 @@ def __init__(
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
-        self.moe = FusedMoEConfig.make(
+        moe = FusedMoEConfig.make(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
@@ -1188,8 +1188,10 @@ def __init__(
             in_dtype=params_dtype,
             quant_config=quant_config)
 
+        self.moe_config = moe
+
         if quant_config is None:
-            self.quant_method = AscendUnquantizedFusedMoEMethod(self.moe)
+            self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)

vllm_ascend/quantization/quant_config.py

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ def get_quant_method(self, layer: torch.nn.Module,
         elif isinstance(layer, FusedMoE):
             if self.is_layer_skipped_ascend(prefix,
                                             self.packed_modules_mapping):
-                return AscendUnquantizedFusedMoEMethod(layer.moe)
+                return AscendUnquantizedFusedMoEMethod(layer.moe_config)
             return AscendFusedMoEMethod(self, prefix,
                                         self.packed_modules_mapping)
         elif isinstance(layer, VocabParallelEmbedding):

vllm_ascend/torchair/torchair_mla.py

Lines changed: 7 additions & 6 deletions
@@ -492,17 +492,17 @@ def build(
         graph_pad_size = common_attn_metadata.graph_pad_size
         use_torchair_graph = graph_pad_size != -1
         if num_decodes > 0:
-            actual_seq_lengths_q = query_start_loc[1:].tolist()
+            actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].tolist()
             max_seq_lens = seq_lens[:num_decodes].max().item()
             seq_lens = seq_lens[:num_decode_tokens]
             input_positions = input_positions[:num_decode_tokens]
             block_table = block_table[:num_decode_tokens, ...]
+            num_token_pad_size = 0
             if use_torchair_graph and common_attn_metadata.attn_state in [
                     AscendAttentionState.DecodeOnly,
                     AscendAttentionState.SpecDecoding
             ]:
                 num_reqs_pad_size = 0
-                num_token_pad_size = 0
                 if graph_pad_size != 0:
                     pad_value = 0
                     num_token_pad_size = graph_pad_size - num_decode_tokens
@@ -535,13 +535,14 @@
                         device=input_positions.device)
                     input_positions = torch.cat(
                         [input_positions, position_padding])
-                    actual_seq_lengths_q = query_start_loc[1:].tolist(
-                    ) + common_attn_metadata.actual_seq_lengths_q[
-                        num_reqs:num_reqs + num_reqs_pad_size]
+                    actual_seq_lengths_q = (
+                        actual_seq_lengths_q + common_attn_metadata.
+                        actual_seq_lengths_q[num_reqs:num_reqs +
+                                             num_reqs_pad_size])
             else:
                 seq_lens_list = seq_lens.tolist()
                 # mtp torchair + PD scenario, last element of actual_seq_lengths_q must equal to batch_size(num_tokens)
-                batch_size = slot_mapping.size(0)
+                batch_size = num_decode_tokens + num_token_pad_size
                 if actual_seq_lengths_q[-1] != batch_size \
                         and common_attn_metadata.attn_state == AscendAttentionState.SpecDecoding:
                     actual_seq_lengths_q[-1] = batch_size
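The `batch_size` change above appears to target speculative decoding without torchair graph padding: `slot_mapping` spans every scheduled token (decode and prefill alike), while the last entry of `actual_seq_lengths_q` should line up with the decode token count plus any padding, which is zero in this branch. A toy illustration with entirely hypothetical numbers:

# Hypothetical SpecDecoding step: 2 decode requests with 2 tokens each
# (one verified token plus one MTP draft), plus a 6-token prefill chunk.
num_decode_tokens = 4
num_prefill_tokens = 6
num_token_pad_size = 0                     # no torchair graph padding here

actual_seq_lengths_q = [2, 4]              # cumulative decode query lengths

old_batch_size = num_decode_tokens + num_prefill_tokens   # what slot_mapping.size(0) could give -> 10
new_batch_size = num_decode_tokens + num_token_pad_size   # -> 4

# With the fix, the consistency check compares against the decode token count.
assert actual_seq_lengths_q[-1] == new_batch_size
assert actual_seq_lengths_q[-1] != old_batch_size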

vllm_ascend/worker/mtp_proposer_v1.py

Lines changed: 2 additions & 5 deletions
@@ -190,11 +190,6 @@ def propose(
         self.positions[:num_tokens] = target_positions
         self.hidden_states[:num_tokens] = target_hidden_states
 
-        if attn_metadata.prefill is not None:
-            attn_metadata.prefill.query_lens = query_lens.cpu()
-            attn_metadata.prefill.input_positions = target_positions
-            attn_metadata.prefill.seq_lens = seq_lens
-
         if not self.torchair_graph_enabled:
             # torch mode need to update num_tokens_across_dp
             # TODO: adapt enable_dbo later
@@ -213,6 +208,7 @@ def propose(
                 num_tokens=num_input_tokens,
                 with_prefill=with_prefill,
                 num_tokens_across_dp=num_tokens_across_dp,
+                reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 in_profile_run=self.runner.in_profile_run,
                 num_actual_tokens=num_tokens):
             with ProfileExecuteDuration().capture_async('mtp_forward'):
@@ -315,6 +311,7 @@ def dummy_run(self,
                 num_tokens=num_tokens,
                 with_prefill=with_prefill,
                 num_tokens_across_dp=num_tokens_across_dp,
+                reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 in_profile_run=self.runner.in_profile_run,
                 num_actual_tokens=0):
             if is_running_torchair:

vllm_ascend/worker/worker_v1.py

Lines changed: 9 additions & 1 deletion
@@ -47,9 +47,14 @@
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import (init_ascend_soc_version,
                                register_ascend_customop, sleep_mode_enabled,
-                               try_register_lib)
+                               try_register_lib, vllm_version_is)
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
+if not vllm_version_is("0.10.1.1"):
+    from vllm.v1.outputs import DraftTokenIds
+else:
+    DraftTokenIds = None
+
 
 class NPUWorker(WorkerBase):
 
@@ -343,3 +348,6 @@ def get_supported_pooling_tasks(self):
 
     def get_supported_tasks(self) -> "tuple[SupportedTask, ...]":
         return self.model_runner.get_supported_tasks()
+
+    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+        return self.model_runner.take_draft_token_ids()
