|
30 | 30 | from vllm.v1.attention.backends.utils import CommonAttentionMetadata |
31 | 31 | from vllm.v1.core.sched.output import SchedulerOutput |
32 | 32 | from vllm.v1.sample.metadata import SamplingMetadata |
33 | | -from vllm.v1.spec_decode.eagle import SpecDecodeBaseProposer as VllmSpecDecodeBaseProposer |
| 33 | +from vllm.v1.spec_decode.eagle import EagleProposer |
34 | 34 | from vllm.v1.spec_decode.metadata import SpecDecodeMetadata |
35 | 35 | from vllm.v1.spec_decode.utils import ( |
36 | 36 | PADDING_SLOT_ID, |
@@ -85,11 +85,20 @@ def split_inputs_tp_to_sp(hidden_states, out): |
85 | 85 | return out[:padded_num_tokens_per_rank] |
86 | 86 |
|
87 | 87 |
|
88 | | -class SpecDecodeBaseProposer(VllmSpecDecodeBaseProposer): |
| 88 | +class SpecDecodeBaseProposer(EagleProposer): |
89 | 89 | _runnable: ACLGraphWrapper | Callable |
90 | 90 |
|
91 | 91 | def __init__(self, vllm_config: VllmConfig, device: torch.device, pass_hidden_states_to_model: bool, runner=None): |
92 | | - super().__init__(vllm_config, device, pass_hidden_states_to_model=pass_hidden_states_to_model, runner=runner) |
| 92 | + super().__init__(vllm_config, device, runner) |
| 93 | + |
| 94 | + # EagleProposer.__init__ hardcodes pass_hidden_states_to_model=True, so |
| 95 | + # the derived values are incorrect when pass_hidden_states_to_model=False |
| 96 | + # (e.g. AscendDraftModelProposer). Recompute them with the correct value. |
| 97 | + self.pass_hidden_states_to_model = pass_hidden_states_to_model |
| 98 | + self.net_num_new_slots_per_request = self.extra_slots_per_request - ( |
| 99 | + 1 if self.pass_hidden_states_to_model else 0 |
| 100 | + ) |
| 101 | + self.needs_extra_input_slots = self.net_num_new_slots_per_request > 0 |
93 | 102 |
|
94 | 103 | self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling |
95 | 104 | self.decode_threshold = 1 + self.num_speculative_tokens |
|
0 commit comments