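This diff makes two related changes: it decorates the DeepSeek MTP draft model with vLLM's `support_torch_compile`, and it widens every speculative-decode proposer's `dummy_run` with two keyword arguments, `aclgraph_runtime_mode` (defaulting to `CUDAGraphMode.NONE`) and `batch_descriptor` (defaulting to `None`), which the model runner now forwards into the Ascend forward context during dummy runs.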
vllm_ascend/models/deepseek_mtp.py (2 additions, 0 deletions)

```diff
@@ -23,6 +23,7 @@
 import torch.nn as nn
 from transformers import PretrainedConfig
 from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
                          get_current_vllm_config)
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -177,6 +178,7 @@ def compute_logits(
         return logits


+@support_torch_compile
 class CustomDeepSeekMTP(DeepSeekMTP):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
```
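`support_torch_compile` is vLLM's class decorator that routes a module's `forward` through `torch.compile` when the configured compilation level enables it; applying it to `CustomDeepSeekMTP` lets the MTP draft model be compiled and graph-captured like the target model. Below is a minimal, self-contained sketch of the decorator pattern; the decorator itself is a toy stand-in, not vLLM's implementation (which also consults the compilation config):

```python
# Toy stand-in for a support_torch_compile-style class decorator.
# vLLM's real decorator is gated on the compilation config; this
# sketch compiles unconditionally at construction time.
import torch
import torch.nn as nn


def toy_support_torch_compile(cls):
    orig_init = cls.__init__

    def __init__(self, *args, **kwargs):
        orig_init(self, *args, **kwargs)
        # Shadow the bound forward with a compiled version.
        self.forward = torch.compile(self.forward)

    cls.__init__ = __init__
    return cls


@toy_support_torch_compile
class TinyDraftModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


print(TinyDraftModel()(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```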
vllm_ascend/spec_decode/eagle_proposer.py (4 additions, 2 deletions)

```diff
@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
 from vllm.attention.layer import Attention
-from vllm.config import (CompilationLevel, VllmConfig,
+from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                          get_layers_from_vllm_config)
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import logger
@@ -115,7 +115,9 @@ def dummy_run(self,
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp: Optional[torch.Tensor] = None):
+                  num_tokens_across_dp: Optional[torch.Tensor] = None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         moe_comm_type = self.runner._select_moe_comm_method(
             num_tokens, with_prefill)
         with set_ascend_forward_context(None,
```
vllm_ascend/spec_decode/interface.py (4 additions, 2 deletions)

```diff
@@ -2,7 +2,7 @@
 from typing import Optional

 import torch
-from vllm.config import VllmConfig
+from vllm.config import CUDAGraphMode, VllmConfig
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -33,7 +33,9 @@ def dummy_run(self,
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp: Optional[torch.Tensor] = None):
+                  num_tokens_across_dp: Optional[torch.Tensor] = None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         """Called by dummy_run in modle_runner"""
         raise NotImplementedError

```
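The interface change follows a common pattern: widen an abstract method with keyword arguments that have safe defaults (`CUDAGraphMode.NONE`, `None`), so existing callers and implementations that ignore graph capture stay call-compatible while graph-aware proposers consume the extra state. A sketch of that contract, with toy stand-ins for `CUDAGraphMode` and the proposer hierarchy:

```python
# Sketch of the widened dummy_run contract. GraphMode and the classes
# below are illustrative stand-ins for vLLM's CUDAGraphMode / Proposer.
import enum


class GraphMode(enum.Enum):
    NONE = 0
    PIECEWISE = 1
    FULL = 2


class Proposer:
    def dummy_run(self,
                  num_tokens: int,
                  with_prefill: bool = False,
                  skip_attn: bool = False,
                  num_reqs: int = 0,
                  num_tokens_across_dp=None,
                  graph_mode: GraphMode = GraphMode.NONE,
                  batch_descriptor=None) -> None:
        """Run a shape-only forward pass (e.g. for warmup or capture)."""
        raise NotImplementedError


class NgramLikeProposer(Proposer):
    # N-gram proposing runs no model, so the graph arguments are
    # accepted for call compatibility and simply ignored.
    def dummy_run(self, num_tokens: int, **kwargs) -> None:
        pass


# Old call sites keep working; new ones can pass capture state.
NgramLikeProposer().dummy_run(128)
NgramLikeProposer().dummy_run(128, graph_mode=GraphMode.FULL)
```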
vllm_ascend/spec_decode/mtp_proposer.py (9 additions, 4 deletions)

```diff
@@ -5,8 +5,8 @@
 import torchair
 from torchair import patch_for_hcom
 from vllm.attention.layer import Attention
-from vllm.config import (VllmConfig, get_layers_from_vllm_config,
-                         set_current_vllm_config)
+from vllm.config import (CUDAGraphMode, VllmConfig,
+                         get_layers_from_vllm_config, set_current_vllm_config)
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import (
@@ -110,7 +110,9 @@ def dummy_run(self,
                   with_prefill: bool = False,
                   skip_attn: bool = False,
                   num_reqs: int = 0,
-                  num_tokens_across_dp=None) -> None:
+                  num_tokens_across_dp=None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None) -> None:
         if not self.torchair_graph_enabled:
             # TODO: adapt enable_dbo later
             (num_tokens, num_tokens_across_dp, with_prefill,
@@ -152,7 +154,9 @@ def dummy_run(self,
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 moe_comm_type=moe_comm_type,
                 in_profile_run=self.runner.in_profile_run,
-                num_actual_tokens=0):
+                num_actual_tokens=0,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor):
             if is_running_torchair:
                 assert attn_metadata is not None
                 torch._dynamo.mark_static(input_ids)
@@ -446,6 +450,7 @@ def _propose(
                 reserved_mc2_mask=self.runner.reserved_mc2_mask,
                 moe_comm_type=moe_comm_type,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor,
                 in_profile_run=self.runner.in_profile_run,
                 num_actual_tokens=num_tokens):
             with ProfileExecuteDuration().capture_async('mtp_forward'):
```
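In `mtp_proposer.py` the new arguments are not used directly: both `dummy_run` and `_propose` thread them into `set_ascend_forward_context`, whose state the graph machinery can read to decide whether to capture or replay a graph for the current batch shape. A rough sketch of that context-manager pattern, loosely modeled on `set_ascend_forward_context` (all names and fields below are illustrative assumptions, not vllm-ascend's actual API):

```python
# Rough sketch of a forward context carrying graph-capture state.
import contextlib
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class BatchDesc:
    num_tokens: int
    uniform_decode: bool


_context = {"graph_mode": "NONE", "batch": None}


@contextlib.contextmanager
def forward_context(graph_mode: str = "NONE",
                    batch: Optional[BatchDesc] = None):
    prev = dict(_context)
    _context.update(graph_mode=graph_mode, batch=batch)
    try:
        yield _context
    finally:
        _context.clear()
        _context.update(prev)


with forward_context(graph_mode="FULL", batch=BatchDesc(512, True)) as ctx:
    # A graph runner could key its capture/replay cache on ctx["batch"],
    # so the draft model is captured under the same shapes as the target.
    print(ctx["graph_mode"], ctx["batch"].num_tokens)
```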
vllm_ascend/spec_decode/ngram_proposer.py (4 additions, 1 deletion)

```diff
@@ -1,4 +1,5 @@
 import torch
+from vllm.config import CUDAGraphMode
 from vllm.v1.spec_decode.ngram_proposer import \
     NgramProposer as VllmNgramProposer

@@ -23,7 +24,9 @@ def dummy_run(self,
                   with_prefill=None,
                   skip_attn=None,
                   num_reqs=None,
-                  num_tokens_across_dp=None):
+                  num_tokens_across_dp=None,
+                  aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
+                  batch_descriptor=None):
         pass

     def generate_token_ids(self,
```
vllm_ascend/worker/model_runner_v1.py (3 additions, 1 deletion)

```diff
@@ -2484,7 +2484,9 @@ def dummy_compute_logits(hidden_states):
                 with_prefill=with_prefill,
                 skip_attn=True,
                 num_reqs=num_reqs,
-                num_tokens_across_dp=num_tokens_across_dp)
+                num_tokens_across_dp=num_tokens_across_dp,
+                aclgraph_runtime_mode=aclgraph_runtime_mode,
+                batch_descriptor=batch_descriptor)
             if need_dummy_logits:
                 dummy_compute_logits(hidden_states)
             if self.in_profile_run and self.dynamic_eplb:
```
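With this last hunk, the runner's dummy-run path hands its own `aclgraph_runtime_mode` and `batch_descriptor` to the active proposer, so a draft model's warmup and capture passes execute under the same forward-context state as the target model's, while proposers that run no model (n-gram) accept and ignore the arguments.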