Skip to content

Commit 3e7ee1f

Browse files
committed
[Refactor] refactor spec decode
Signed-off-by: wangxiyuan <[email protected]>
1 parent 103654c commit 3e7ee1f

File tree

7 files changed

+571
-501
lines changed

7 files changed

+571
-501
lines changed

vllm_ascend/spec_decode/__init__.py

Whitespace-only changes.

vllm_ascend/worker/eagle_proposer_v1.py renamed to vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 222 additions & 134 deletions
Large diffs are not rendered by default.

vllm_ascend/spec_decode/interface.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from typing import Optional, Union
2+
3+
import torch
4+
from vllm.config import VllmConfig
5+
from vllm.v1.core.sched.output import SchedulerOutput
6+
from vllm.v1.sample.metadata import SamplingMetadata
7+
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
8+
9+
from vllm_ascend.attention.attention_v1 import AscendMetadata
10+
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
11+
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
12+
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
13+
14+
15+
class Proposer:
16+
17+
def __init__(self,
18+
vllm_config: VllmConfig,
19+
device: torch.device = None,
20+
runner: NPUModelRunner = None):
21+
pass
22+
23+
def load_model(self, model):
24+
"""Called by load_model in model_runner"""
25+
raise NotImplementedError
26+
27+
@torch.inference_mode()
28+
def dummy_run(self,
29+
num_tokens: int,
30+
with_prefill: bool = None,
31+
skip_attn: bool = None,
32+
num_reqs: int = None,
33+
num_tokens_across_dp: Optional[torch.Tensor] = None):
34+
"""Called by dummy_run in modle_runner"""
35+
raise NotImplementedError
36+
37+
def generate_token_ids(self,
38+
valid_sampled_token_ids: list[list[int]],
39+
sampling_metadata: SamplingMetadata = None,
40+
scheduler_output: SchedulerOutput = None,
41+
spec_decode_metadata: SpecDecodeMetadata = None,
42+
positions: torch.Tensor = None,
43+
num_scheduled_tokens: int = None,
44+
hidden_states: torch.Tensor = None,
45+
attn_metadata: Union[AscendMetadata,
46+
AscendMLAMetadata,
47+
AscendTorchairMetadata] = None,
48+
aux_hidden_states: torch.Tensor = None,
49+
attn_metadata_builder=None):
50+
"""Called by execute_model in model_runner"""
51+
raise NotImplementedError

0 commit comments

Comments
 (0)