33 changes: 33 additions & 0 deletions vllm_ascend/spec_decode/__init__.py
@@ -0,0 +1,33 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.spec_decode.ngram_proposer import NgramProposer


def get_spec_decode_method(method, vllm_config, device, runner):
    if method == "ngram":
        return NgramProposer(vllm_config, device, runner)
    elif method in ["eagle", "eagle3"]:
        return EagleProposer(vllm_config, device, runner)
    elif method == 'deepseek_mtp':
        return MtpProposer(vllm_config, device, runner)
    else:
        raise ValueError("Unknown speculative decoding method: "
                         f"{method}")
643 changes: 643 additions & 0 deletions vllm_ascend/spec_decode/eagle_proposer.py

Large diffs are not rendered by default.

51 changes: 51 additions & 0 deletions vllm_ascend/spec_decode/interface.py
@@ -0,0 +1,51 @@
import enum
from typing import Optional

import torch
from vllm.config import VllmConfig
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata


class SpecDcodeType(enum.Enum):
    NGRAM = 0
    EAGLE = 1
    EAGLE3 = 2
    MTP = 4


class Proposer:

    def __init__(self,
                 vllm_config: VllmConfig,
                 device: torch.device = None,
                 runner=None):
        pass

    def load_model(self, model):
        """Called by load_model in model_runner"""
        raise NotImplementedError

    @torch.inference_mode()
    def dummy_run(self,
                  num_tokens: int,
                  with_prefill: bool = False,
                  skip_attn: bool = False,
                  num_reqs: int = 0,
                  num_tokens_across_dp: Optional[torch.Tensor] = None):
        """Called by dummy_run in model_runner"""
        raise NotImplementedError

    def generate_token_ids(self,
                           valid_sampled_token_ids: list[list[int]],
                           sampling_metadata: SamplingMetadata = None,
                           scheduler_output: SchedulerOutput = None,
                           spec_decode_metadata: SpecDecodeMetadata = None,
                           positions: torch.Tensor = None,
                           num_scheduled_tokens: int = 0,
                           hidden_states: torch.Tensor = None,
                           attn_metadata=None,
                           aux_hidden_states: torch.Tensor = None):
        """Called by execute_model in model_runner"""
        raise NotImplementedError
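For illustration only, a minimal sketch of a concrete subclass is shown here; the class name and its no-op behaviour are hypothetical and not part of this diff, and only demonstrate which methods a proposer is expected to override.

import torch
from vllm.config import VllmConfig

class NoopProposer(Proposer):
    # Hypothetical example subclass showing the override points of Proposer.

    def __init__(self,
                 vllm_config: VllmConfig,
                 device: torch.device = None,
                 runner=None):
        self.device = device
        self.runner = runner

    def load_model(self, model):
        # A real proposer would set up its draft model here.
        self.model = model

    @torch.inference_mode()
    def dummy_run(self, num_tokens, with_prefill=False, skip_attn=False,
                  num_reqs=0, num_tokens_across_dp=None):
        # A real proposer would run its draft model here for warm-up.
        return

    def generate_token_ids(self, valid_sampled_token_ids, **kwargs):
        # Propose no draft tokens: one empty list per request, mirroring the
        # list[list[int]] shape of valid_sampled_token_ids. A real proposer
        # returns its drafted token ids here.
        return [[] for _ in valid_sampled_token_ids]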