Skip to content

Commit 0bafb59

Browse files
rahul-tuli authored and yiliu30 committed
Add: SupportsEagle3 interface for explicit EAGLE3 support (vllm-project#22642)
Signed-off-by: Rahul Tuli <[email protected]>
1 parent 8b2cfe3 commit 0bafb59

File tree

5 files changed

+81
-8
lines changed

5 files changed

+81
-8
lines changed

tests/speculative_decoding/speculators/test_eagle3.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,20 @@
33
import pytest
44
import torch
55

6+
from vllm.model_executor.models.interfaces import supports_eagle3
7+
68

79
@pytest.mark.parametrize(
810
"model_path",
911
[("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
10-
def test_llama(vllm_runner, example_prompts, model_path):
12+
def test_llama(vllm_runner, example_prompts, model_path, monkeypatch):
13+
# Set environment variable for V1 engine serialization
14+
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
15+
1116
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
17+
eagle3_supported = vllm_model.apply_model(supports_eagle3)
18+
assert eagle3_supported
19+
1220
vllm_outputs = vllm_model.generate_greedy(example_prompts,
1321
max_tokens=20)
1422
print(vllm_outputs)
@@ -18,8 +26,14 @@ def test_llama(vllm_runner, example_prompts, model_path):
1826
@pytest.mark.parametrize(
1927
"model_path",
2028
[("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
21-
def test_qwen(vllm_runner, example_prompts, model_path):
29+
def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch):
30+
# Set environment variable for V1 engine serialization
31+
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
32+
2233
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
34+
eagle3_supported = vllm_model.apply_model(supports_eagle3)
35+
assert eagle3_supported
36+
2337
vllm_outputs = vllm_model.generate_greedy(example_prompts,
2438
max_tokens=20)
2539
print(vllm_outputs)

vllm/model_executor/models/interfaces.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,3 +823,56 @@ def supports_v0_only(
823823
model: Union[type[object], object],
824824
) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
825825
return getattr(model, "supports_v0_only", False)
826+
827+
828+
@runtime_checkable
class SupportsEagle3(Protocol):
    """The interface required for models that support
    EAGLE3 speculative decoding."""

    supports_eagle3: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports EAGLE3
    speculative decoding.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """
        Set which layers should output auxiliary
        hidden states for EAGLE3.

        Args:
            layers: Tuple of layer indices that should output auxiliary
                hidden states.
        """
        ...

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """
        Get the layer indices that should output auxiliary hidden states
        for EAGLE3.

        Returns:
            Tuple of layer indices for auxiliary hidden state outputs.
        """
        ...
863+
864+
865+
@overload
def supports_eagle3(model: type[object]) -> TypeIs[type[SupportsEagle3]]:
    ...


@overload
def supports_eagle3(model: object) -> TypeIs[SupportsEagle3]:
    ...


def supports_eagle3(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsEagle3]], TypeIs[SupportsEagle3]]:
    """Return whether ``model`` satisfies the ``SupportsEagle3`` protocol.

    Args:
        model: A model class or a model instance to check.

    Returns:
        The result of a runtime ``isinstance`` check of ``model`` against
        ``SupportsEagle3``.
    """
    return isinstance(model, SupportsEagle3)

vllm/model_executor/models/llama.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from vllm.model_executor.sampling_metadata import SamplingMetadata
5050
from vllm.sequence import IntermediateTensors
5151

52-
from .interfaces import SupportsLoRA, SupportsPP
52+
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
5353
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
5454
is_pp_missing_parameter,
5555
make_empty_intermediate_tensors_factory, make_layers,
@@ -463,7 +463,7 @@ def load_weights(self, weights: Iterable[tuple[str,
463463
return loaded_params
464464

465465

466-
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
466+
class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
467467
packed_modules_mapping = {
468468
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
469469
"gate_up_proj": ["gate_proj", "up_proj"]

vllm/model_executor/models/qwen3.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from vllm.model_executor.sampling_metadata import SamplingMetadata
4545
from vllm.sequence import IntermediateTensors
4646

47-
from .interfaces import SupportsLoRA, SupportsPP
47+
from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
4848
from .qwen2 import Qwen2MLP as Qwen3MLP
4949
from .qwen2 import Qwen2Model
5050
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
@@ -261,7 +261,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
261261
decoder_layer_type=Qwen3DecoderLayer)
262262

263263

264-
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
264+
class Qwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
265265
packed_modules_mapping = {
266266
"qkv_proj": [
267267
"q_proj",

vllm/v1/worker/gpu_model_runner.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
3636
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
3737
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
38+
supports_eagle3,
3839
supports_transcription)
3940
from vllm.model_executor.models.interfaces_base import (
4041
VllmModelForPooling, is_pooling_model, is_text_generation_model)
@@ -1981,8 +1982,13 @@ def load_model(self, eep_scale_up: bool = False) -> None:
19811982
logger.info("Loading drafter model...")
19821983
self.drafter.load_model(self.model)
19831984
if self.use_aux_hidden_state_outputs:
1984-
self.model.set_aux_hidden_state_layers(
1985-
self.model.get_eagle3_aux_hidden_state_layers())
1985+
if supports_eagle3(self.model):
1986+
self.model.set_aux_hidden_state_layers(
1987+
self.model.get_eagle3_aux_hidden_state_layers())
1988+
else:
1989+
raise RuntimeError(
1990+
"Model does not support EAGLE3 interface but "
1991+
"aux_hidden_state_outputs was requested")
19861992
time_after_load = time.perf_counter()
19871993
self.model_memory_usage = m.consumed_memory
19881994
logger.info("Model loading took %.4f GiB and %.6f seconds",

0 commit comments

Comments
 (0)