
Commit 3d7363e

[Config] add "qwen" as a native eagle3 target supported model (#22333)
Signed-off-by: lechen <[email protected]>
Signed-off-by: LeChen <[email protected]>
1 parent 0c5254b commit 3d7363e
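
For context, a minimal sketch of how this change is exercised from the user side, using the target/draft model pair added in the diffs below. This assumes vLLM's speculative_config dict API; the exact keys and the num_speculative_tokens value are illustrative, not taken from this commit:

from vllm import LLM, SamplingParams

# Target is a native Qwen model; "qwen" is now accepted as an eagle3 target.
llm = LLM(
    model="Qwen/Qwen3-8B",
    speculative_config={
        "method": "eagle3",
        "model": "AngelSlim/Qwen3-8B_eagle3",  # EAGLE-3 draft model
        "num_speculative_tokens": 3,  # illustrative value
    },
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)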

File tree

5 files changed: +30 -27 lines changed


tests/models/registry.py

Lines changed: 4 additions & 0 deletions

@@ -525,6 +525,10 @@ def check_available_online(
                                               trust_remote_code=True,
                                               speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
                                               tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
+    "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501
+                                              trust_remote_code=True,
+                                              speculative_model="AngelSlim/Qwen3-8B_eagle3",
+                                              tokenizer="Qwen/Qwen3-8B"),
     "EagleLlama4ForCausalLM": _HfExamplesInfo(
         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
         trust_remote_code=True,

tests/v1/e2e/test_spec_decode.py

Lines changed: 21 additions & 18 deletions

@@ -125,24 +125,27 @@ def test_ngram_correctness(
     cleanup_dist_env_and_memory()


-@pytest.mark.parametrize(
-    ["model_setup", "mm_enabled"], [
-        (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
-        (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
-          "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
-        pytest.param(
-            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-            False,
-            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-        pytest.param(
-            ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-             "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
-            True,
-            marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
-    ],
-    ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"])
+@pytest.mark.parametrize(["model_setup", "mm_enabled"], [
+    (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False),
+    (("eagle", "meta-llama/Llama-3.1-8B-Instruct",
+      "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False),
+    (("eagle3", "meta-llama/Llama-3.1-8B-Instruct",
+      "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False),
+    pytest.param(
+        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+        False,
+        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+    pytest.param(
+        ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct",
+         "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4),
+        True,
+        marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")),
+],
+                         ids=[
+                             "qwen3_eagle3", "llama3_eagle", "llama3_eagle3",
+                             "llama4_eagle", "llama4_eagle_mm"
+                         ])
 @pytest.mark.parametrize("attn_backend",
                          get_attn_backend_list_based_on_platform())
 def test_eagle_correctness(
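
The layout of the model_setup tuple is not shown in this hunk; inferred from the cases above, it appears to be (method, target_model, draft_model, tensor_parallel_size). A standalone sketch of that assumed unpacking, using the newly added qwen3_eagle3 case:

# Assumed layout: (method, target_model, draft_model, tensor_parallel_size).
model_setup = ("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1)
method, model_name, spec_model_name, tp_size = model_setup
assert method == "eagle3" and tp_size == 1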

vllm/config/__init__.py

Lines changed: 1 addition & 7 deletions

@@ -2852,13 +2852,7 @@ def _verify_args(self) -> Self:
                 "speculative decoding is > 1, but got "
                 f"{self.disable_by_batch_size=}")

-        from vllm.transformers_utils.configs import SpeculatorsConfig
-
-        eagle3_target_supported = ["llama"]
-        if self.draft_model_config and isinstance(
-                self.draft_model_config.hf_config, SpeculatorsConfig):
-            eagle3_target_supported.append("qwen")
-
+        eagle3_target_supported = ["llama", "qwen"]
         if self.method == "eagle3" and self.target_model_config and not any(
                 supported_model in
                 self.target_model_config.hf_text_config.model_type
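
The simplified check is a substring match against the target model's hf_text_config.model_type, which is why a single "qwen" entry covers qwen2, qwen3, and related model types. A minimal standalone sketch of that matching logic (the helper name is mine, not vLLM's):

eagle3_target_supported = ["llama", "qwen"]

def is_eagle3_target_supported(model_type: str) -> bool:
    # Substring match, mirroring the `supported_model in model_type` test above.
    return any(supported in model_type for supported in eagle3_target_supported)

assert is_eagle3_target_supported("qwen3")      # e.g. Qwen/Qwen3-8B
assert is_eagle3_target_supported("llama")      # e.g. Llama-3.1-8B-Instruct
assert not is_eagle3_target_supported("mistral")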

vllm/model_executor/models/registry.py

Lines changed: 1 addition & 0 deletions

@@ -259,6 +259,7 @@
     "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"),
     "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
     "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
+    "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"),
     "MedusaModel": ("medusa", "Medusa"),

vllm/transformers_utils/configs/eagle.py

Lines changed: 3 additions & 2 deletions

@@ -45,6 +45,7 @@ def __init__(self,

         # Eagle model name should follow naming convention of
         # LlamaForCausalLM -> EagleLlamaForCausalLM
+        # LlamaForCausalLM -> Eagle3LlamaForCausalLM / LlamaForCausalLMEagle3
         if method == "eagle":
             assert self.model is not None, \
                 "model should not be None when method is eagle"
@@ -56,8 +57,8 @@ def __init__(self,
             assert self.model is not None, \
                 "model should not be None when method is eagle3"
             kwargs["architectures"] = [
-                f"Eagle3{arch}" if not arch.startswith("Eagle3") \
-                    else arch for arch in self.model.architectures
+                arch if arch.startswith("Eagle3") or arch.endswith("Eagle3")
+                else f"Eagle3{arch}" for arch in self.model.architectures
             ]
         else:
             raise ValueError(f"Invalid method {method}. \
