
Commit 58959a1: [Feat]mtp aclgraph support
Signed-off-by: anon189Ty <[email protected]>
Parent: f2d8493

6 files changed (+26, -10 lines)


vllm_ascend/models/deepseek_mtp.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 import torch.nn as nn
 from transformers import PretrainedConfig
 from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -177,6 +178,7 @@ def compute_logits(
         return logits


+@support_torch_compile
 class CustomDeepSeekMTP(DeepSeekMTP):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
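
The functional change here is the @support_torch_compile decorator on CustomDeepSeekMTP, which registers the MTP drafter with vLLM's compilation machinery so its forward pass can be compiled and graph-captured like the main model's. A rough illustration of the decorator-compiles-forward pattern (a stand-in built on plain torch.compile, not the real vLLM decorator, which also hooks into the VllmConfig compilation settings):

```python
# Stand-in sketch only: vLLM's support_torch_compile integrates with its
# CompilationConfig; this shows just the shape of the pattern.
import torch
import torch.nn as nn


def compile_forward(cls):
    """Decorator that compiles each instance's forward on construction."""
    orig_init = cls.__init__

    def __init__(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        self.forward = torch.compile(self.forward)

    cls.__init__ = __init__
    return cls


@compile_forward
class TinyDraftHead(nn.Module):  # hypothetical toy stand-in for the MTP head
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


head = TinyDraftHead()
print(head(torch.randn(4, 16)).shape)  # torch.Size([4, 16])
```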

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 4 additions & 2 deletions
@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
 from vllm.attention.layer import Attention
-from vllm.config import (CompilationLevel, VllmConfig,
+from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import logger
@@ -115,7 +115,9 @@ def dummy_run(self,
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp: Optional[torch.Tensor] = None):
+                  num_tokens_across_dp: Optional[torch.Tensor] = None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         moe_comm_type = self.runner._select_moe_comm_method(
             num_tokens, with_prefill)
         with set_ascend_forward_context(None,
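
Note that both new parameters default to the old behaviour (CUDAGraphMode.NONE, no descriptor), so existing dummy_run call sites are untouched; only the graph-capture path in the model runner passes anything else. A self-contained sketch of that backward-compatible extension (stand-in enum and context manager, not the vllm-ascend API):

```python
# Sketch, not vllm-ascend code: new kwargs default to the no-graph path.
from contextlib import contextmanager
from enum import Enum, auto


class CUDAGraphMode(Enum):  # stand-in; vLLM defines the real enum
    NONE = auto()
    PIECEWISE = auto()
    FULL = auto()


@contextmanager
def forward_context(mode: CUDAGraphMode, batch_descriptor):
    # The real set_ascend_forward_context records these so downstream kernels
    # know whether they are running under graph capture.
    print(f"mode={mode.name} descriptor={batch_descriptor}")
    yield


def dummy_run(num_tokens: int,
              aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
              batch_descriptor=None) -> None:
    with forward_context(aclgraph_runtime_mode, batch_descriptor):
        pass  # run the model on dummy inputs


dummy_run(8)                                            # legacy call site
dummy_run(8, aclgraph_runtime_mode=CUDAGraphMode.FULL)  # capture warm-up
```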

vllm_ascend/spec_decode/interface.py

Lines changed: 4 additions & 2 deletions
@@ -2,7 +2,7 @@
 from typing import Optional

 import torch
-from vllm.config import VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -33,7 +33,9 @@ def dummy_run(self,
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp: Optional[torch.Tensor] = None):
+                  num_tokens_across_dp: Optional[torch.Tensor] = None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         """Called by dummy_run in model_runner"""
         raise NotImplementedError

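
The interface change is what forces the matching edits in every proposer: the model runner drives all proposers through this one dummy_run signature, so even implementations that ignore the new arguments (the ngram proposer below) must still accept them. A minimal sketch of that contract (hypothetical classes, not the vllm-ascend ones):

```python
# Sketch: one dummy_run signature across all proposers, with defaults so
# implementations that don't graph-capture can simply ignore the new kwargs.
from enum import Enum, auto


class CUDAGraphMode(Enum):  # stand-in for vllm.config.CUDAGraphMode
    NONE = auto()
    FULL = auto()


class Proposer:
    def dummy_run(self,
                  num_tokens: int,
                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
                  batch_descriptor=None) -> None:
        """Called by dummy_run in model_runner."""
        raise NotImplementedError


class NgramProposer(Proposer):
    def dummy_run(self, num_tokens, aclgraph_runtime_mode=CUDAGraphMode.NONE,
                  batch_descriptor=None) -> None:
        pass  # ngram drafting has no NPU warm-up, so this stays a no-op


NgramProposer().dummy_run(16, aclgraph_runtime_mode=CUDAGraphMode.FULL)
```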

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 9 additions & 4 deletions
@@ -5,8 +5,8 @@
 import torchair
 from torchair import patch_for_hcom
 from vllm.attention.layer import Attention
-from vllm.config import (VllmConfig, get_layers_from_vllm_config,
-                         set_current_vllm_config)
+from vllm.config import (CUDAGraphMode, VllmConfig,
+                         get_layers_from_vllm_config, set_current_vllm_config)
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import (
@@ -110,7 +110,9 @@ def dummy_run(self,
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp=None) -> None:
+                  num_tokens_across_dp=None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None) -> None:
         if not self.torchair_graph_enabled:
             # TODO: adapt enable_dbo later
             (num_tokens, num_tokens_across_dp, with_prefill,
@@ -152,7 +154,9 @@ def dummy_run(self,
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 moe_comm_type=moe_comm_type,
                 in_profile_run=self.runner.in_profile_run,
-                num_actual_tokens=0):
+                num_actual_tokens=0,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor):
             if is_running_torchair:
                 assert attn_metadata is not None
                 torch._dynamo.mark_static(input_ids)
@@ -446,6 +450,7 @@ def _propose(
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 moe_comm_type=moe_comm_type,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor,
                 in_profile_run=self.runner.in_profile_run,
                 num_actual_tokens=num_tokens):
             with ProfileExecuteDuration().capture_async('mtp_forward'):
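
The substantive edit is in _propose: the BatchDescriptor built for the real batch now travels into the forward context together with aclgraph_runtime_mode. A captured graph is only valid for the exact batch shape it was captured with, so a descriptor of that shape is the natural key for deciding which graph to replay. A self-contained sketch of the idea (field names and cache layout are illustrative, not vLLM's forward-context internals):

```python
# Sketch: a frozen descriptor keys one captured graph per batch shape.
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass(frozen=True)
class BatchDescriptor:  # the real one lives in vllm.forward_context
    num_tokens: int
    uniform_decode: bool


_graphs: Dict[BatchDescriptor, Callable[[], None]] = {}


def run(desc: BatchDescriptor, workload: Callable[[], None]) -> None:
    if desc not in _graphs:
        print(f"capture graph for {desc}")  # first sight of this shape
        _graphs[desc] = workload
    print(f"replay graph for {desc}")
    _graphs[desc]()


mtp_step = lambda: None  # stand-in for the drafter's forward pass
run(BatchDescriptor(num_tokens=8, uniform_decode=True), mtp_step)  # capture
run(BatchDescriptor(num_tokens=8, uniform_decode=True), mtp_step)  # replay
```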

vllm_ascend/spec_decode/ngram_proposer.py

Lines changed: 4 additions & 1 deletion
@@ -1,4 +1,5 @@
 import torch
+from vllm.config import CUDAGraphMode
 from vllm.v1.spec_decode.ngram_proposer import \
     NgramProposer as VllmNgramProposer

@@ -23,7 +24,9 @@ def dummy_run(self,
                   with_prefill=None,
                   skip_attn=None,
                   num_reqs=None,
-                  num_tokens_across_dp=None):
+                  num_tokens_across_dp=None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         pass

     def generate_token_ids(self,

vllm_ascend/worker/model_runner_v1.py

Lines changed: 3 additions & 1 deletion
@@ -2484,7 +2484,9 @@ def dummy_compute_logits(hidden_states):
             with_prefill=with_prefill,
             skip_attn=True,
             num_reqs=num_reqs,
-            num_tokens_across_dp=num_tokens_across_dp)
+            num_tokens_across_dp=num_tokens_across_dp,
+            aclgraph_runtime_mode=aclgraph_runtime_mode,
+            batch_descriptor=batch_descriptor)
         if need_dummy_logits:
             dummy_compute_logits(hidden_states)
         if self.in_profile_run and self.dynamic_eplb:
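
With this last edit, the runner forwards the same mode and descriptor it computed for the main model's dummy run into the drafter, so the MTP head is warmed up and captured under identical conditions. Roughly, and reusing stand-in types from the sketches above (names are illustrative, not the runner's actual code):

```python
# Sketch of the caller side: one (mode, descriptor) pair per warm-up size,
# shared between the main model and the drafter.
from dataclasses import dataclass
from enum import Enum, auto


class CUDAGraphMode(Enum):
    NONE = auto()
    FULL = auto()


@dataclass(frozen=True)
class BatchDescriptor:
    num_tokens: int
    uniform_decode: bool


class Drafter:  # hypothetical stand-in for the MTP proposer
    def dummy_run(self, num_tokens, *,
                  aclgraph_runtime_mode=CUDAGraphMode.NONE,
                  batch_descriptor=None):
        print(f"warm up {num_tokens} tokens under "
              f"{aclgraph_runtime_mode.name} with {batch_descriptor}")


drafter = Drafter()
for num_tokens in (1, 2, 4, 8):  # capture sizes, illustrative
    desc = BatchDescriptor(num_tokens=num_tokens, uniform_decode=True)
    drafter.dummy_run(num_tokens,
                      aclgraph_runtime_mode=CUDAGraphMode.FULL,
                      batch_descriptor=desc)
```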
