diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py
index 0c4f173946..4ccc36d9fb 100644
--- a/vllm_ascend/models/deepseek_mtp.py
+++ b/vllm_ascend/models/deepseek_mtp.py
@@ -23,6 +23,7 @@
 import torch.nn as nn
 from transformers import PretrainedConfig
 from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -177,6 +178,7 @@ def compute_logits(
         return logits
 
 
+@support_torch_compile
 class CustomDeepSeekMTP(DeepSeekMTP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index d14dc6d2a4..0d8d5cd2b0 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
 from vllm.attention.layer import Attention
-from vllm.config import (CompilationLevel, VllmConfig,
+from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import logger
@@ -115,7 +115,9 @@ def dummy_run(self,
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp: Optional[torch.Tensor] = None):
+                  num_tokens_across_dp: Optional[torch.Tensor] = None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         moe_comm_type = self.runner._select_moe_comm_method(
             num_tokens, with_prefill)
         with set_ascend_forward_context(None,
diff --git a/vllm_ascend/spec_decode/interface.py b/vllm_ascend/spec_decode/interface.py
index 0efe93de33..3f0a36b13c 100644
--- a/vllm_ascend/spec_decode/interface.py
+++ b/vllm_ascend/spec_decode/interface.py
@@ -2,7 +2,7 @@
 from typing import Optional
 
 import torch
-from vllm.config import VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -33,7 +33,9 @@ def dummy_run(self,
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp: Optional[torch.Tensor] = None):
+                  num_tokens_across_dp: Optional[torch.Tensor] = None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         """Called by dummy_run in modle_runner"""
         raise NotImplementedError
 
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index ed4e8870cf..208a154321 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -5,8 +5,8 @@
 import torchair
 from torchair import patch_for_hcom
 from vllm.attention.layer import Attention
-from vllm.config import (VllmConfig, get_layers_from_vllm_config,
-                         set_current_vllm_config)
+from vllm.config import (CUDAGraphMode, VllmConfig,
+                         get_layers_from_vllm_config, set_current_vllm_config)
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import (
@@ -110,7 +110,9 @@ def dummy_run(self,
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp=None) -> None:
+                  num_tokens_across_dp=None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None) -> None:
         if not self.torchair_graph_enabled:
             # TODO: adapt enable_dbo later
             (num_tokens, num_tokens_across_dp, with_prefill,
@@ -152,7 +154,9 @@ def dummy_run(self,
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 moe_comm_type=moe_comm_type,
                 in_profile_run=self.runner.in_profile_run,
-                num_actual_tokens=0):
+                num_actual_tokens=0,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor):
             if is_running_torchair:
                 assert attn_metadata is not None
                 torch._dynamo.mark_static(input_ids)
@@ -446,6 +450,7 @@ def _propose(
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 moe_comm_type=moe_comm_type,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor,
                 in_profile_run=self.runner.in_profile_run,
                 num_actual_tokens=num_tokens):
             with ProfileExecuteDuration().capture_async('mtp_forward'):
diff --git a/vllm_ascend/spec_decode/ngram_proposer.py b/vllm_ascend/spec_decode/ngram_proposer.py
index 9999f1f36d..34b5b95408 100644
--- a/vllm_ascend/spec_decode/ngram_proposer.py
+++ b/vllm_ascend/spec_decode/ngram_proposer.py
@@ -1,4 +1,5 @@
 import torch
+from vllm.config import CUDAGraphMode
 from vllm.v1.spec_decode.ngram_proposer import \
     NgramProposer as VllmNgramProposer
 
@@ -23,7 +24,9 @@ def dummy_run(self,
                   with_prefill=None,
                   skip_attn=None,
                   num_reqs=None,
-                  num_tokens_across_dp=None):
+                  num_tokens_across_dp=None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         pass
 
     def generate_token_ids(self,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 373a73e297..1f43bc93c1 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -2484,7 +2484,9 @@ def dummy_compute_logits(hidden_states):
                 with_prefill=with_prefill,
                 skip_attn=True,
                 num_reqs=num_reqs,
-                num_tokens_across_dp=num_tokens_across_dp)
+                num_tokens_across_dp=num_tokens_across_dp,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor)
             if need_dummy_logits:
                 dummy_compute_logits(hidden_states)
             if self.in_profile_run and self.dynamic_eplb:
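For reviewers, here is a minimal sketch (not part of the diff) of the call contract this patch establishes: every proposer's `dummy_run` gains `aclgraph_runtime_mode` and `batch_descriptor` keyword parameters, defaulting to `CUDAGraphMode.NONE` and `None` so existing eager callers are unaffected, and the model runner forwards both into the Ascend forward context. `SketchProposer` is a hypothetical stand-in for the proposers touched above; it assumes a vLLM build whose `vllm.config` exports `CUDAGraphMode`, as the imports in this diff do.

```python
# Hypothetical stand-in for the proposers patched above; it mirrors the new
# dummy_run signature from vllm_ascend/spec_decode/interface.py.
from typing import Optional

import torch
from vllm.config import CUDAGraphMode


class SketchProposer:

    def dummy_run(self,
                  num_tokens: int,
                  with_prefill: bool = False,
                  skip_attn: bool = False,
                  num_reqs: int = 0,
                  num_tokens_across_dp: Optional[torch.Tensor] = None,
                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
                  batch_descriptor=None) -> None:
        # A real proposer opens set_ascend_forward_context(...) here and
        # passes aclgraph_runtime_mode / batch_descriptor straight through,
        # as mtp_proposer.py does in the hunks above.
        pass


# Caller side, mirroring the model_runner_v1.py hunk: both new arguments are
# forwarded; the NONE default keeps plain eager dummy runs unchanged.
SketchProposer().dummy_run(num_tokens=8,
                           num_reqs=1,
                           aclgraph_runtime_mode=CUDAGraphMode.NONE,
                           batch_descriptor=None)
```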