Skip to content

Commit 26fc36b

Browse files
authored
[V1] MTP supports torchair (#2145)
### What this PR does / why we need it?

Support MTP with:

- [x] V0 Scheduler
- [x] TorchAir
- [x] Single DP
- [x] Multi DP
- [x] Disaggregated PD

Known issues:

- [ ] The V1 Scheduler (chunked prefill) is not supported yet; support is planned within a few weeks.
- [ ] vllm v0.10.0 does not support metrics with `DP > 1` right now, so lines 171-175 of `vllm/vllm/v1/metrics/loggers.py` need to be commented out:

```
        if (len(self.engine_indexes) > 1
                and vllm_config.speculative_config is not None):
            raise NotImplementedError("Prometheus metrics with Spec Decoding "
                                      "with >1 EngineCore per AsyncLLM is not "
                                      "supported yet.")
```

To start an online server with torchair enabled, here is an example:

```
python -m vllm.entrypoints.openai.api_server \
    --model="/weights/DeepSeek-R1_w8a8/" \
    --trust-remote-code \
    --max-model-len 40000 \
    --tensor-parallel-size 4 \
    --data_parallel_size 4 \
    --max-num-seqs 16 \
    --no-enable-prefix-caching \
    --enable_expert_parallel \
    --served-model-name deepseekr1 \
    --speculative-config '{"num_speculative_tokens": 1, "method":"deepseek_mtp"}' \
    --quantization ascend \
    --host 0.0.0.0 \
    --port 1234 \
    --additional-config '{"ascend_scheduler_config":{"enabled":true,"enable_chunked_prefill":false},"torchair_graph_config":{"enabled":true,"graph_batch_sizes":[16]},"enable_weight_nz_layout":true}' \
    --gpu_memory_utilization 0.9
```

Here is an offline example with torchair enabled:

```
from vllm import LLM, SamplingParams

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(max_tokens=16, temperature=0)
# Create an LLM.
llm = LLM(
    model="/home/data/DeepSeek-R1_w8a8/",
    tensor_parallel_size=16,
    max_num_seqs=16,
    gpu_memory_utilization=0.9,
    distributed_executor_backend="mp",
    enable_expert_parallel=True,
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,
    },
    trust_remote_code=True,
    enforce_eager=False,
    max_model_len=2000,
    additional_config = {
        'torchair_graph_config': {
            'enabled': True,
            "graph_batch_sizes": [16],
            'enable_multistream_shared_expert': False,
        },
        "ascend_scheduler_config": {
            "enabled": True
        },
        # 'expert_tensor_parallel_size': 16,
    }
)

# Generate texts from the prompts.
# llm.start_profile()
outputs = llm.generate(prompts, sampling_params)
# llm.stop_profile()
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

- vLLM version: v0.10.0
- vLLM main: vllm-project/vllm@302962e

---------

Signed-off-by: xuyexiong <[email protected]>
1 parent bf84f2d commit 26fc36b

File tree

12 files changed

+542
-161
lines changed

12 files changed

+542
-161
lines changed

tests/ut/attention/test_mla_v1.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -188,6 +188,7 @@ def test_ascend_mla_metadata_builder_default(self):
188188
runner.chunked_prefill_enabled = False
189189
runner.device = "cpu"
190190
runner.block_size = 16
191+
runner.decode_token_per_req = 1
191192

192193
ascend_config = MagicMock()
193194
ascend_config.torchair_graph_config = MagicMock()
@@ -206,6 +207,7 @@ def test_ascend_mla_metadata_builder_default(self):
206207
def test_reorder_batch_with_torchair_graph(self, ascend_config):
207208
runner = MagicMock()
208209
runner.chunked_prefill_enabled = False
210+
runner.decode_token_per_req = 1
209211
ascend_config.torchair_graph_config = MagicMock()
210212
ascend_config.torchair_graph_config.enabled = True
211213

@@ -238,6 +240,7 @@ def test_reorder_batch_without_torchair_graph(self):
238240
ascend_config = MagicMock()
239241
runner = MagicMock()
240242
runner.chunked_prefill_enabled = False
243+
runner.decode_token_per_req = 1
241244
ascend_config.torchair_graph_config = MagicMock()
242245
ascend_config.torchair_graph_config.enabled = False
243246
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
@@ -275,6 +278,7 @@ def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
275278
runner = MagicMock()
276279
runner.graph_block_tables = torch.zeros((8, 64), dtype=torch.int32)
277280
runner.chunked_prefill_enabled = False
281+
runner.decode_token_per_req = 1
278282
builder = AscendMLAMetadataBuilder(runner=runner)
279283
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
280284

@@ -291,6 +295,7 @@ def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
291295
runner = MagicMock()
292296
runner.graph_block_tables = torch.zeros((8, 4), dtype=torch.int32)
293297
runner.chunked_prefill_enabled = False
298+
runner.decode_token_per_req = 1
294299
builder = AscendMLAMetadataBuilder(runner=runner)
295300
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
296301

@@ -308,6 +313,7 @@ def test_get_graph_runner_block_tables_from_numpy(self,
308313
runner = MagicMock()
309314
runner.graph_block_tables = np.zeros((8, 64), dtype=np.int32)
310315
runner.chunked_prefill_enabled = False
316+
runner.decode_token_per_req = 1
311317
builder = AscendMLAMetadataBuilder(runner=runner)
312318

313319
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
@@ -332,6 +338,7 @@ def test_build_dummy(self, mock_ascend_config):
332338
runner.attn_mask = torch.zeros((1, 1), dtype=torch.bool)
333339
runner.spec_attn_mask = torch.zeros((1, 1), dtype=torch.bool)
334340
runner.dtype = torch.float16
341+
runner.decode_token_per_req = 1
335342

336343
builder = AscendMLAMetadataBuilder(runner=runner,
337344
metadata_cls=AscendMLAMetadata)

tests/ut/models/test_deepseek_mtp.py

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,9 @@ def setup_predictor(self, mocker: MockerFixture):
7777
mock_vllm_config.model_config = mock_model_config
7878
mock_vllm_config.cache_config = CacheConfig()
7979
mock_vllm_config.quant_config = mocker.MagicMock()
80+
mocker.patch(
81+
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
82+
return_value=None)
8083
mocker.patch(
8184
"vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__init__",
8285
return_value=None)
@@ -90,10 +93,9 @@ def test_init(self, mocker: MockerFixture, setup_predictor):
9093
assert predictor.num_mtp_layers == 3
9194
assert isinstance(predictor, CustomDeepSeekMultiTokenPredictor)
9295

93-
@pytest.mark.parametrize('kv_caches, inputs_embeds', [
94-
(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]])),
95-
(None, None),
96-
])
96+
@pytest.mark.parametrize(
97+
'kv_caches, inputs_embeds',
98+
[(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))])
9799
def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches,
98100
inputs_embeds):
99101
predictor = setup_predictor
@@ -147,6 +149,9 @@ def setup_mtp(self, mocker: MockerFixture):
147149
mocker.patch("torch.nn.Module.__setattr__")
148150
mocker.patch("torch.nn.Module.__getattr__")
149151
mocker.patch("torch.nn.Module.__delattr__")
152+
mocker.patch(
153+
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
154+
return_value=None)
150155
mocker.patch(
151156
"vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__",
152157
return_value=None)
@@ -172,4 +177,4 @@ def test_forward(self, mocker: MockerFixture, setup_mtp):
172177
output = setup_mtp.forward(input_ids, positions, kv_caches, None,
173178
previous_hidden_states, inputs_embeds,
174179
spec_step_idx)
175-
assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
180+
assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))

tests/ut/quantization/test_quant_config.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import torch
44
from vllm.attention.layer import Attention
55
from vllm.model_executor.layers.fused_moe import FusedMoE
6+
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
67
from vllm.model_executor.layers.linear import (LinearBase,
78
UnquantizedLinearMethod)
89

@@ -111,6 +112,7 @@ def test_get_quant_method_for_attention(self):
111112

112113
def test_get_quant_method_for_fused_moe(self):
113114
fused_moe_layer = MagicMock(spec=FusedMoE)
115+
fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
114116

115117
# Test skipped layer
116118
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \

vllm_ascend/attention/attention_v1_torchair.py

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -156,7 +156,7 @@ def _get_graph_runner_block_tables(
156156
self, num_seqs: int, block_tables: torch.Tensor) -> torch.Tensor:
157157

158158
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
159-
assert max_batch_size >= num_seqs
159+
assert max_batch_size >= num_seqs, f"max_batch_size: {max_batch_size} should be bigger than cur_num_seqs: {num_seqs}"
160160

161161
if isinstance(self.runner.graph_block_tables, np.ndarray):
162162
graph_block_tables = torch.zeros((max_batch_size, max_blocks),
@@ -259,26 +259,34 @@ def build(self,
259259
if use_torchair_graph and self.runner.attn_state in [
260260
AscendAttentionState.DecodeOnly,
261261
]:
262+
num_reqs_pad_size = 0
263+
num_token_pad_size = 0
264+
if graph_pad_size != 0:
265+
pad_value = 0
266+
num_token_pad_size = graph_pad_size - num_actual_tokens
267+
num_reqs_pad_size = (
268+
graph_pad_size // self.runner.decode_token_per_req -
269+
num_reqs)
262270
pad_value = 1
263271
padded_seq_lens = seq_lens.tolist() + [pad_value
264-
] * graph_pad_size
272+
] * num_reqs_pad_size
265273

266274
seq_lens = torch.from_numpy(
267275
np.array(padded_seq_lens).astype(np.int32))
268-
padding = torch.full((graph_pad_size, ),
276+
padding = torch.full((num_token_pad_size, ),
269277
PAD_SLOT_ID,
270278
dtype=slot_mapping.dtype,
271279
device=slot_mapping.device)
272280
slot_mapping = torch.cat([slot_mapping, padding])
273281
block_table_padding = torch.zeros(
274-
(graph_pad_size, ) + block_table.shape[1:],
282+
(num_reqs_pad_size, ) + block_table.shape[1:],
275283
dtype=block_table.dtype,
276284
device=block_table.device)
277285
block_table = torch.cat([block_table, block_table_padding],
278286
dim=0)
279287
block_table = self._get_graph_runner_block_tables(
280-
num_seqs + graph_pad_size, block_table)
281-
padding_0 = torch.zeros(graph_pad_size,
288+
num_seqs + num_reqs_pad_size, block_table)
289+
padding_0 = torch.zeros(num_token_pad_size,
282290
dtype=input_positions.dtype,
283291
device=input_positions.device)
284292
input_positions = torch.cat([input_positions, padding_0])

0 commit comments

Comments
 (0)