diff --git a/tests/models/registry.py b/tests/models/registry.py index 09d62413feed..2879f309edaf 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -523,6 +523,10 @@ def check_available_online( trust_remote_code=True, speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", tokenizer="meta-llama/Llama-3.1-8B-Instruct"), + "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3", # noqa: E501 + trust_remote_code=True, + speculative_model="AngelSlim/Qwen3-8B_eagle3", + tokenizer="Qwen/Qwen3-8B"), "EagleLlama4ForCausalLM": _HfExamplesInfo( "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", trust_remote_code=True, diff --git a/tests/v1/e2e/test_spec_decode.py b/tests/v1/e2e/test_spec_decode.py index 4950faf826b8..cd383b58db2a 100644 --- a/tests/v1/e2e/test_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -125,24 +125,27 @@ def test_ngram_correctness( cleanup_dist_env_and_memory() -@pytest.mark.parametrize( - ["model_setup", "mm_enabled"], [ - (("eagle", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), - (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", - "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), - pytest.param( - ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - False, - marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), - pytest.param( - ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", - "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), - True, - marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), - ], - ids=["llama3_eagle", "llama3_eagle3", "llama4_eagle", "llama4_eagle_mm"]) +@pytest.mark.parametrize(["model_setup", "mm_enabled"], [ + (("eagle3", "Qwen/Qwen3-8B", "AngelSlim/Qwen3-8B_eagle3", 1), False), + (("eagle", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE-LLaMA3.1-Instruct-8B", 1), False), + (("eagle3", "meta-llama/Llama-3.1-8B-Instruct", + "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", 1), False), + pytest.param( + ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + False, + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), + pytest.param( + ("eagle", "meta-llama/Llama-4-Scout-17B-16E-Instruct", + "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct", 4), + True, + marks=pytest.mark.skip(reason="Skipping due to CI OOM issues")), +], + ids=[ + "qwen3_eagle3", "llama3_eagle", "llama3_eagle3", + "llama4_eagle", "llama4_eagle_mm" + ]) @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform()) def test_eagle_correctness( diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 7efab23f144a..b2826de93d49 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -2852,13 +2852,7 @@ def _verify_args(self) -> Self: "speculative decoding is > 1, but got " f"{self.disable_by_batch_size=}") - from vllm.transformers_utils.configs import SpeculatorsConfig - - eagle3_target_supported = ["llama"] - if self.draft_model_config and isinstance( - self.draft_model_config.hf_config, SpeculatorsConfig): - eagle3_target_supported.append("qwen") - + eagle3_target_supported = ["llama", "qwen"] if self.method == "eagle3" and self.target_model_config and not any( supported_model in self.target_model_config.hf_text_config.model_type diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 4aa958ecdcd1..4512bec63b0c 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -259,6 +259,7 @@ "EagleLlama4ForCausalLM": ("llama4_eagle", "EagleLlama4ForCausalLM"), "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "Glm4MoeMTPModel": ("glm4_moe_mtp", "Glm4MoeMTP"), "MedusaModel": ("medusa", "Medusa"), diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index 5445a333c493..01217eb19126 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -45,6 +45,7 @@ def __init__(self, # Eagle model name should follow naming convention of # LlamaForCausalLM -> EagleLlamaForCausalLM + # LlamaForCausalLM -> Eagle3LlamaForCausalLM / LlamaForCausalLMEagle3 if method == "eagle": assert self.model is not None, \ "model should not be None when method is eagle" @@ -56,8 +57,8 @@ def __init__(self, assert self.model is not None, \ "model should not be None when method is eagle3" kwargs["architectures"] = [ - f"Eagle3{arch}" if not arch.startswith("Eagle3") \ - else arch for arch in self.model.architectures + arch if arch.startswith("Eagle3") or arch.endswith("Eagle3") + else f"Eagle3{arch}" for arch in self.model.architectures ] else: raise ValueError(f"Invalid method {method}. \