Skip to content

Commit 7b68b19

Browse files
committed
[Refactor] refactor spec decode
Signed-off-by: wangxiyuan <[email protected]>
1 parent 103654c commit 7b68b19

File tree

7 files changed

+571
-501
lines changed

7 files changed

+571
-501
lines changed

vllm_ascend/spec_decode/__init__.py

Whitespace-only changes.

vllm_ascend/worker/eagle_proposer_v1.py renamed to vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 222 additions & 134 deletions
Large diffs are not rendered by default.

vllm_ascend/spec_decode/interface.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
from typing import Optional, Union
2+
3+
import torch
4+
from vllm.config import VllmConfig
5+
from vllm.v1.core.sched.output import SchedulerOutput
6+
from vllm.v1.sample.metadata import SamplingMetadata
7+
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
8+
9+
from vllm_ascend.attention.attention_v1 import AscendMetadata
10+
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
11+
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
12+
13+
14+
class Proposer:
15+
16+
def __init__(self,
17+
vllm_config: VllmConfig,
18+
device: torch.device = None,
19+
runner=None):
20+
pass
21+
22+
def load_model(self, model):
23+
"""Called by load_model in model_runner"""
24+
raise NotImplementedError
25+
26+
@torch.inference_mode()
27+
def dummy_run(self,
28+
num_tokens: int,
29+
with_prefill: bool = False,
30+
skip_attn: bool = False,
31+
num_reqs: int = 0,
32+
num_tokens_across_dp: Optional[torch.Tensor] = None):
33+
"""Called by dummy_run in modle_runner"""
34+
raise NotImplementedError
35+
36+
def generate_token_ids(
37+
self,
38+
valid_sampled_token_ids: list[list[int]],
39+
sampling_metadata: SamplingMetadata = None,
40+
scheduler_output: SchedulerOutput = None,
41+
spec_decode_metadata: SpecDecodeMetadata = None,
42+
positions: torch.Tensor = None,
43+
num_scheduled_tokens: int = 0,
44+
hidden_states: torch.Tensor = None,
45+
attn_metadata: Optional[Union[AscendMetadata, AscendMLAMetadata,
46+
AscendTorchairMetadata]] = None,
47+
aux_hidden_states: torch.Tensor = None,
48+
attn_metadata_builder=None):
49+
"""Called by execute_model in model_runner"""
50+
raise NotImplementedError

0 commit comments

Comments
 (0)