Commit c5b00ce

[Refactor] refactor spec decode

Signed-off-by: wangxiyuan <[email protected]>

Parent: 7e494e9

File tree: 7 files changed, +1018 −930 lines


vllm_ascend/spec_decode/__init__.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer


def get_spec_decode_method(method, vllm_config, device, runner):
    if method == "ngram":
        return NgramProposer(vllm_config, device, runner)
    elif method in ["eagle", "eagle3"]:
        return EagleProposer(vllm_config, device, runner)
    elif method == 'deepseek_mtp':
        return MtpProposer(vllm_config, device, runner)
    else:
        raise ValueError("Unknown speculative decoding method: "
                         f"{method}")

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 643 additions & 0 deletions
Large diffs are not rendered by default.

vllm_ascend/spec_decode/interface.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
import enum
from typing import Optional

import torch
from vllm.config import VllmConfig
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata


class SpecDcodeType(enum.Enum):
    NGRAM = 0
    EAGLE = 1
    EAGLE3 = 2
    MTP = 4


class Proposer:

    def __init__(self,
                 vllm_config: VllmConfig,
                 device: torch.device = None,
                 runner=None):
        pass

    def load_model(self, model):
        """Called by load_model in model_runner"""
        raise NotImplementedError

    @torch.inference_mode()
    def dummy_run(self,
                  num_tokens: int,
                  with_prefill: bool = False,
                  skip_attn: bool = False,
                  num_reqs: int = 0,
                  num_tokens_across_dp: Optional[torch.Tensor] = None):
        """Called by dummy_run in model_runner"""
        raise NotImplementedError

    def generate_token_ids(self,
                           valid_sampled_token_ids: list[list[int]],
                           sampling_metadata: SamplingMetadata = None,
                           scheduler_output: SchedulerOutput = None,
                           spec_decode_metadata: SpecDecodeMetadata = None,
                           positions: torch.Tensor = None,
                           num_scheduled_tokens: int = 0,
                           hidden_states: torch.Tensor = None,
                           attn_metadata=None,
                           aux_hidden_states: torch.Tensor = None):
        """Called by execute_model in model_runner"""
        raise NotImplementedError
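
For orientation, a hedged sketch of a concrete drafter built on this interface. The NoopProposer class below is illustrative and does not appear in this commit; the real implementations live in the sibling eagle_proposer.py, mtp_proposer.py, and ngram_proposer.py files.

# Illustrative only: a do-nothing proposer implementing the interface above.
class NoopProposer(Proposer):

    def __init__(self, vllm_config, device=None, runner=None):
        super().__init__(vllm_config, device, runner)
        # Tag which speculative method this drafter serves.
        self.name = SpecDcodeType.NGRAM

    def load_model(self, model):
        # A real drafter would load or wrap its draft model here.
        self.model = model

    @torch.inference_mode()
    def dummy_run(self, num_tokens, with_prefill=False, skip_attn=False,
                  num_reqs=0, num_tokens_across_dp=None):
        # A real drafter would warm up graphs or allocate buffers here.
        pass

    def generate_token_ids(self, valid_sampled_token_ids, **kwargs):
        # A real drafter returns draft token ids per request; returning
        # empty drafts effectively disables speculation for every request.
        return [[] for _ in valid_sampled_token_ids]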
