Skip to content

Commit 68201ac

Browse files
committed
[WIP]mtp aclgraph support
Signed-off-by: anon189Ty <[email protected]>
1 parent f2d8493 commit 68201ac

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

vllm_ascend/models/deepseek_mtp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import torch.nn as nn
2424
from transformers import PretrainedConfig
2525
from vllm.attention.backends.abstract import AttentionMetadata
26+
from vllm.compilation.decorators import support_torch_compile
2627
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
2728
get_current_vllm_config)
2829
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -177,6 +178,7 @@ def compute_logits(
177178
return logits
178179

179180

181+
@support_torch_compile
180182
class CustomDeepSeekMTP(DeepSeekMTP):
181183

182184
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

vllm_ascend/spec_decode/mtp_proposer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,7 @@ def _propose(
446446
reserved_mc2_mask=self.runner.reserved_mc2_mask,
447447
moe_comm_type=moe_comm_type,
448448
aclgraph_runtime_mode=aclgraph_runtime_mode,
449+
batch_descriptor=batch_descriptor,
449450
in_profile_run=self.runner.in_profile_run,
450451
num_actual_tokens=num_tokens):
451452
with ProfileExecuteDuration().capture_async('mtp_forward'):

0 commit comments

Comments
 (0)