Skip to content

Commit ad7469d

Browse files
committed
fix
Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent 380d996 commit ad7469d

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

vllm_ascend/spec_decode/eagle_proposer.py

Lines changed: 12 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@
3030
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
3131
from vllm.v1.core.sched.output import SchedulerOutput
3232
from vllm.v1.sample.metadata import SamplingMetadata
33-
from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer as VllmSpecDecodeBaseProposer
33+
from vllm.v1.spec_decode.eagle import EagleProposer
3434
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
3535
from vllm.v1.spec_decode.utils import (
3636
PADDING_SLOT_ID,
@@ -85,11 +85,20 @@ def split_inputs_tp_to_sp(hidden_states, out):
8585
return out[:padded_num_tokens_per_rank]
8686

8787

88-
class SpecDecodeBaseProposer(VllmSpecDecodeBaseProposer):
88+
class SpecDecodeBaseProposer(EagleProposer):
8989
_runnable: ACLGraphWrapper | Callable
9090

9191
def __init__(self, vllm_config: VllmConfig, device: torch.device, pass_hidden_states_to_model: bool, runner=None):
92-
super().__init__(vllm_config, device, pass_hidden_states_to_model=pass_hidden_states_to_model, runner=runner)
92+
super().__init__(vllm_config, device, runner)
93+
94+
# EagleProposer.__init__ hardcodes pass_hidden_states_to_model=True, so
95+
# the derived values are incorrect when pass_hidden_states_to_model=False
96+
# (e.g. AscendDraftModelProposer). Recompute them with the correct value.
97+
self.pass_hidden_states_to_model = pass_hidden_states_to_model
98+
self.net_num_new_slots_per_request = self.extra_slots_per_request - (
99+
1 if self.pass_hidden_states_to_model else 0
100+
)
101+
self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0
93102

94103
self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
95104
self.decode_threshold = 1 + self.num_speculative_tokens

0 commit comments

Comments
 (0)