
Commit fd1994e

committed
[WIP] mtp aclgraph support
Signed-off-by: anon189Ty <[email protected]>
1 parent f2d8493 commit fd1994e

File tree

3 files changed: 14 additions (+), 5 deletions (-)

  vllm_ascend/models/deepseek_mtp.py
  vllm_ascend/spec_decode/mtp_proposer.py
  vllm_ascend/worker/model_runner_v1.py


vllm_ascend/models/deepseek_mtp.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
 import torch.nn as nn
 from transformers import PretrainedConfig
 from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.model_executor.layers.layernorm import RMSNorm

@@ -177,6 +178,7 @@ def compute_logits(
         return logits


+@support_torch_compile
 class CustomDeepSeekMTP(DeepSeekMTP):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 9 additions & 4 deletions
@@ -5,8 +5,8 @@
 import torchair
 from torchair import patch_for_hcom
 from vllm.attention.layer import Attention
-from vllm.config import (VllmConfig, get_layers_from_vllm_config,
-                         set_current_vllm_config)
+from vllm.config import (CUDAGraphMode, VllmConfig,
+                         get_layers_from_vllm_config, set_current_vllm_config)
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import (

@@ -110,7 +110,9 @@ def dummy_run(self,
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp=None) -> None:
+                  num_tokens_across_dp=None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None) -> None:
         if not self.torchair_graph_enabled:
             # TODO: adapt enable_dbo later
             (num_tokens, num_tokens_across_dp, with_prefill,

@@ -152,7 +154,9 @@ def dummy_run(self,
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 moe_comm_type=moe_comm_type,
                 in_profile_run=self.runner.in_profile_run,
-                num_actual_tokens=0):
+                num_actual_tokens=0,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor):
             if is_running_torchair:
                 assert attn_metadata is not None
                 torch._dynamo.mark_static(input_ids)

@@ -446,6 +450,7 @@ def _propose(
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 moe_comm_type=moe_comm_type,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor,
                 in_profile_run=self.runner.in_profile_run,
                 num_actual_tokens=num_tokens):
             with ProfileExecuteDuration().capture_async('mtp_forward'):

vllm_ascend/worker/model_runner_v1.py

Lines changed: 3 additions & 1 deletion
@@ -2484,7 +2484,9 @@ def dummy_compute_logits(hidden_states):
                 with_prefill=with_prefill,
                 skip_attn=True,
                 num_reqs=num_reqs,
-                num_tokens_across_dp=num_tokens_across_dp)
+                num_tokens_across_dp=num_tokens_across_dp,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor)
             if need_dummy_logits:
                 dummy_compute_logits(hidden_states)
             if self.in_profile_run and self.dynamic_eplb:
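Taken together, these hunks thread the ACL graph runtime mode and batch descriptor from the model runner's dummy pass into the MTP proposer and its forward context. Below is a minimal sketch of the assumed calling pattern on the runner side; the proposer handle (mtp_proposer), the positional num_tokens argument, the placeholder shape values, and the BatchDescriptor field choices are illustrative assumptions, not taken verbatim from this commit.

    # Minimal sketch (assumed call pattern, not part of this commit): driving the
    # new dummy_run parameters while capturing an ACL graph for the MTP draft model.
    # mtp_proposer is assumed to be an initialized proposer from
    # vllm_ascend/spec_decode/mtp_proposer.py.
    from vllm.config import CUDAGraphMode
    from vllm.forward_context import BatchDescriptor

    # Placeholder dummy-batch shape values for illustration.
    num_tokens, num_reqs = 64, 8
    num_tokens_across_dp = None

    # Key the captured graph on the padded dummy batch shape.
    batch_descriptor = BatchDescriptor(num_tokens=num_tokens, uniform_decode=True)

    mtp_proposer.dummy_run(
        num_tokens,                                 # assumed positional argument
        with_prefill=False,
        skip_attn=True,
        num_reqs=num_reqs,
        num_tokens_across_dp=num_tokens_across_dp,
        aclgraph_runtime_mode=CUDAGraphMode.FULL,   # graph capture/replay mode for this pass
        batch_descriptor=batch_descriptor,
    )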
