Skip to content

Commit 9f9c38c

Browse files
authored
[Speculators][Speculative Decoding] Add Qwen Eagle3 Support (#21835)
Signed-off-by: Dipika Sikka <[email protected]>
1 parent a65f46b commit 9f9c38c

File tree

4 files changed

+46
-11
lines changed

4 files changed

+46
-11
lines changed

tests/speculative_decoding/speculators/test_eagle3.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,21 @@
66

77
@pytest.mark.parametrize(
    "model_path",
    ["nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized"],
)
def test_llama(vllm_runner, example_prompts, model_path):
    """Smoke-test greedy generation with a converted Llama Eagle3 speculator.

    Loads the speculator checkpoint in bfloat16 and asserts that greedy
    decoding over the example prompts produces non-empty output.
    """
    with vllm_runner(model_path, dtype=torch.bfloat16) as runner:
        outputs = runner.generate_greedy(example_prompts, max_tokens=20)
        print(outputs)
        assert outputs
16+
17+
18+
@pytest.mark.parametrize(
    "model_path",
    ["nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized"],
)
def test_qwen(vllm_runner, example_prompts, model_path):
    """Smoke-test greedy generation with a converted Qwen Eagle3 speculator.

    Loads the speculator checkpoint in bfloat16 and asserts that greedy
    decoding over the example prompts produces non-empty output.
    """
    with vllm_runner(model_path, dtype=torch.bfloat16) as runner:
        outputs = runner.generate_greedy(example_prompts, max_tokens=20)
        print(outputs)
        assert outputs

vllm/config.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3175,10 +3175,19 @@ def _verify_args(self) -> Self:
31753175
"speculative decoding is > 1, but got "
31763176
f"{self.disable_by_batch_size=}")
31773177

3178-
if self.method == "eagle3" and self.target_model_config and \
3179-
"llama" not in self.target_model_config.hf_text_config.model_type:
3178+
from vllm.transformers_utils.configs import SpeculatorsConfig
3179+
3180+
eagle3_target_supported = ["llama"]
3181+
if self.draft_model_config and isinstance(
3182+
self.draft_model_config.hf_config, SpeculatorsConfig):
3183+
eagle3_target_supported.append("qwen")
3184+
3185+
if self.method == "eagle3" and self.target_model_config and not any(
3186+
supported_model in
3187+
self.target_model_config.hf_text_config.model_type
3188+
for supported_model in eagle3_target_supported):
31803189
raise ValueError(
3181-
"Eagle3 is only supported for Llama models. "
3190+
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
31823191
f"Got {self.target_model_config.hf_text_config.model_type=}")
31833192

31843193
return self

vllm/model_executor/models/qwen2.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,8 @@ def __init__(self,
330330
else:
331331
self.norm = PPMissingLayer()
332332

333+
self.aux_hidden_state_layers: tuple[int] = tuple()
334+
333335
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
334336
return self.embed_tokens(input_ids)
335337

@@ -350,18 +352,25 @@ def forward(
350352
assert intermediate_tensors is not None
351353
hidden_states = intermediate_tensors["hidden_states"]
352354
residual = intermediate_tensors["residual"]
353-
for layer in self.layers[self.start_layer:self.end_layer]:
354-
hidden_states, residual = layer(
355-
positions,
356-
hidden_states,
357-
residual,
358-
)
355+
356+
aux_hidden_states = []
357+
for idx, layer in enumerate(
358+
self.layers[self.start_layer:self.end_layer]):
359+
if idx in self.aux_hidden_state_layers:
360+
aux_hidden_states.append(hidden_states + residual)
361+
hidden_states, residual = layer(positions, hidden_states, residual)
362+
359363
if not get_pp_group().is_last_rank:
360364
return IntermediateTensors({
361365
"hidden_states": hidden_states,
362366
"residual": residual
363367
})
368+
364369
hidden_states, _ = self.norm(hidden_states, residual)
370+
371+
if len(aux_hidden_states) > 0:
372+
return hidden_states, aux_hidden_states
373+
365374
return hidden_states
366375

367376
def load_weights(self, weights: Iterable[tuple[str,

vllm/model_executor/models/qwen3.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
288288
self.make_empty_intermediate_tensors = (
289289
self.model.make_empty_intermediate_tensors)
290290

291+
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
    """Select which decoder layers should yield auxiliary hidden states.

    Stores the indices on the underlying model; the forward pass reads
    ``aux_hidden_state_layers`` to decide which intermediate activations
    to collect for Eagle3 speculative decoding.

    Args:
        layers: Indices into ``self.model.layers``. Annotated as a
            variable-length tuple (``tuple[int]`` would mean a 1-tuple,
            but callers pass e.g. a 3-tuple of layer indices).
    """
    self.model.aux_hidden_state_layers = layers
293+
294+
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, int, int]:
    """Return default layer indices for Eagle3 auxiliary hidden states.

    Picks an early, a middle, and a late decoder layer:
    ``(2, num_layers // 2, num_layers - 3)``. The return annotation is a
    3-tuple (the original ``tuple[int]`` denoted a 1-tuple, which did
    not match the returned value).

    Returns:
        Three indices into ``self.model.layers``.
    """
    num_layers = len(self.model.layers)
    return (2, num_layers // 2, num_layers - 3)
297+
291298
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
292299
return self.model.get_input_embeddings(input_ids)
293300

0 commit comments

Comments
 (0)