
Commit e439c78

Add support for Eagle with separate lm-head and embed_tokens layers (#28549)
Signed-off-by: Eldar Kurtic <[email protected]>
1 parent 085a525 commit e439c78

12 files changed: +204 -63 lines changed

tests/v1/spec_decode/test_eagle.py

Lines changed: 17 additions & 16 deletions
@@ -324,6 +324,7 @@ def test_prepare_inputs_padded():
 @pytest.mark.parametrize("attn_backend", get_attn_backend_list_based_on_platform())
 @pytest.mark.parametrize("pp_size", [1, 2])
 @pytest.mark.parametrize("use_distinct_embed_tokens", [True, False])
+@pytest.mark.parametrize("use_distinct_lm_head", [True, False])
 @mock.patch("vllm.v1.spec_decode.eagle.get_pp_group")
 @mock.patch("vllm.v1.spec_decode.eagle.get_layers_from_vllm_config")
 @mock.patch("vllm.v1.spec_decode.eagle.get_model")
@@ -335,6 +336,7 @@ def test_load_model(
     attn_backend,
     pp_size,
     use_distinct_embed_tokens,
+    use_distinct_lm_head,
     monkeypatch,
 ):
     monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
@@ -350,12 +352,13 @@ def test_load_model(
 
     # Setup draft model mock
     mock_model = mock.MagicMock()
+    mock_model.model = mock.MagicMock()
+    mock_model.has_own_embed_tokens = use_distinct_embed_tokens
     if use_distinct_embed_tokens:
-        # Some models can have a different hidden size than the target model,
-        # so we test that their embed_tokens doesn't get overwritten
-        mock_model.model.embed_tokens.weight.shape = (131072, 2048)
-    else:
-        mock_model.model.embed_tokens.weight.shape = (131072, 4096)
+        mock_model.model.embed_tokens = mock.MagicMock()
+    mock_model.has_own_lm_head = use_distinct_lm_head
+    if use_distinct_lm_head:
+        mock_model.lm_head = mock.MagicMock()
 
     mock_get_model.return_value = mock_model
 
@@ -391,15 +394,13 @@ class _TargetModelStub(LlamaForCausalLM):
 
     target_model = mock.create_autospec(_TargetModelStub, instance=True)
     target_model.model = mock.MagicMock()
-    target_model.model.embed_tokens.weight.shape = (131072, 4096)
+    target_model.lm_head = mock.MagicMock()
+    target_model.model.embed_tokens = mock.MagicMock()
 
     from vllm.model_executor.models import SupportsMultiModal
 
     assert not isinstance(target_model, SupportsMultiModal)
 
-    if method == "eagle":
-        target_model.lm_head = mock.MagicMock()
-
     # Create proposer using the helper function
     proposer = _create_proposer(method, num_speculative_tokens=8)
 
@@ -409,18 +410,18 @@ class _TargetModelStub(LlamaForCausalLM):
     # Verify common interactions
     mock_get_model.assert_called_once()
 
-    # Verify that EAGLE models gain the lm head from the target model
-    if method == "eagle":
-        assert proposer.model.lm_head == target_model.lm_head
+    # Verify that the lm head is set correctly
+    if use_distinct_lm_head:
+        assert proposer.model.lm_head is not target_model.lm_head
+    else:
+        assert proposer.model.lm_head is target_model.lm_head
 
     # Verify that the embed tokens are set correctly
     # If pp_size is > 1, the embed tokens should be distinct
     if pp_size > 1 or use_distinct_embed_tokens:
-        assert proposer.model.model.embed_tokens != target_model.model.embed_tokens
+        assert proposer.model.model.embed_tokens is not target_model.model.embed_tokens
     else:
-        # When pp_size is 1 and the draft and target models have
-        # embed_tokens of the same shape, they should be shared.
-        assert proposer.model.model.embed_tokens == target_model.model.embed_tokens
+        assert proposer.model.model.embed_tokens is target_model.model.embed_tokens
 
 
 @pytest.mark.parametrize("method", ["eagle", "eagle3"])
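The sharing behaviour these assertions pin down is implemented in vllm/v1/spec_decode/eagle.py, whose hunk is not rendered on this page. Below is a minimal sketch of the decision the test exercises, assuming the proposer receives the loaded draft model (carrying the new has_own_lm_head / has_own_embed_tokens flags) and the target model:

# Sketch only; not the actual EagleProposer.load_model code from this commit.
def share_or_keep(draft, target, pp_size: int) -> None:
    # Reuse the target's lm_head unless the draft checkpoint ships its own head.
    if not getattr(draft, "has_own_lm_head", False):
        draft.lm_head = target.lm_head
    # Share input embeddings only when they live on this PP rank and the draft
    # did not train its own embed_tokens; otherwise keep them distinct.
    if pp_size == 1 and not getattr(draft, "has_own_embed_tokens", False):
        draft.model.embed_tokens = target.model.embed_tokens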

tests/v1/spec_decode/test_mtp.py

Lines changed: 4 additions & 0 deletions
@@ -67,6 +67,10 @@ def test_mtp_load_model_unified(mock_get_model, mock_get_layers, mock_get_pp_gro
     mock_model = mock.MagicMock()
     mock_model.model.embed_tokens.weight.shape = (131072, 4096)
     mock_get_model.return_value = mock_model
+    # MTP does not have its own embed_tokens or lm_head
+    # so it should share them with the target model
+    mock_model.has_own_embed_tokens = False
+    mock_model.has_own_lm_head = False
 
     target_attn_layers = {"target_attn_1": mock.MagicMock()}
     all_attn_layers = {**target_attn_layers, "draft_attn_1": mock.MagicMock()}

vllm/model_executor/models/deepseek_eagle.py

Lines changed: 2 additions & 1 deletion
@@ -26,7 +26,7 @@
 )
 from vllm.utils import init_logger
 
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
 
 logger = init_logger(__name__)
 
@@ -250,6 +250,7 @@ def transform(inputs):
             name, loaded_weight = inputs
             if "lm_head" not in name:
                 name = "model." + name
+            process_eagle_weight(self, name)
             return name, loaded_weight
 
         loader = AutoWeightsLoader(
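process_eagle_weight is a new helper imported from .utils (vllm/model_executor/models/utils.py), one of the changed files whose diff is not rendered on this page, so its exact body is not visible here. The following is a hypothetical sketch consistent with how the loaders and tests use it: called once per checkpoint weight name to flip the draft model's ownership flags.

# Hypothetical sketch; the real helper lives in vllm/model_executor/models/utils.py.
def process_eagle_weight(model, name: str) -> None:
    # If the draft checkpoint contains its own lm_head or embed_tokens weights,
    # record that so the proposer does not overwrite them with the target's layers.
    if "lm_head" in name:
        model.has_own_lm_head = True
    if "embed_tokens" in name:
        model.has_own_embed_tokens = True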

vllm/model_executor/models/deepseek_v2.py

Lines changed: 2 additions & 2 deletions
@@ -85,7 +85,7 @@
 )
 from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
 
-from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
 from .utils import (
     PPMissingLayer,
     is_pp_missing_parameter,
@@ -1311,7 +1311,7 @@ def update_physical_experts_metadata(
 
 
 class DeepseekV2ForCausalLM(
-    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA
+    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
 ):
     packed_modules_mapping = {
         "gate_up_proj": ["gate_proj", "up_proj"],

vllm/model_executor/models/interfaces.py

Lines changed: 65 additions & 5 deletions
@@ -932,13 +932,73 @@ def supports_transcription(
 
 
 @runtime_checkable
-class SupportsEagle3(Protocol):
+class SupportsEagleBase(Protocol):
+    """Base interface for models that support EAGLE-based speculative decoding."""
+
+    has_own_lm_head: bool = False
+    """
+    A flag that indicates this model has trained its own lm_head.
+    """
+
+    has_own_embed_tokens: bool = False
+    """
+    A flag that indicates this model has trained its own input embeddings.
+    """
+
+
+@overload
+def supports_any_eagle(model: type[object]) -> TypeIs[type[SupportsEagleBase]]: ...
+
+
+@overload
+def supports_any_eagle(model: object) -> TypeIs[SupportsEagleBase]: ...
+
+
+def supports_any_eagle(
+    model: type[object] | object,
+) -> TypeIs[type[SupportsEagleBase]] | TypeIs[SupportsEagleBase]:
+    """Check if model supports any EAGLE variant (1, 2, or 3)."""
+    return supports_eagle(model) or supports_eagle3(model)
+
+
+@runtime_checkable
+class SupportsEagle(SupportsEagleBase, Protocol):
+    """The interface required for models that support
+    EAGLE-1 and EAGLE-2 speculative decoding."""
+
+    supports_eagle: ClassVar[Literal[True]] = True
+    """
+    A flag that indicates this model supports EAGLE-1 and EAGLE-2
+    speculative decoding.
+
+    Note:
+        There is no need to redefine this flag if this class is in the
+        MRO of your model class.
+    """
+
+
+@overload
+def supports_eagle(model: type[object]) -> TypeIs[type[SupportsEagle]]: ...
+
+
+@overload
+def supports_eagle(model: object) -> TypeIs[SupportsEagle]: ...
+
+
+def supports_eagle(
+    model: type[object] | object,
+) -> TypeIs[type[SupportsEagle]] | TypeIs[SupportsEagle]:
+    return isinstance(model, SupportsEagle)
+
+
+@runtime_checkable
+class SupportsEagle3(SupportsEagleBase, Protocol):
     """The interface required for models that support
-    EAGLE3 speculative decoding."""
+    EAGLE-3 speculative decoding."""
 
     supports_eagle3: ClassVar[Literal[True]] = True
     """
-    A flag that indicates this model supports EAGLE3
+    A flag that indicates this model supports EAGLE-3
     speculative decoding.
 
     Note:
@@ -949,7 +1009,7 @@ class SupportsEagle3(Protocol):
     def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
         """
         Set which layers should output auxiliary
-        hidden states for EAGLE3.
+        hidden states for EAGLE-3.
 
         Args:
             layers: Tuple of layer indices that should output auxiliary
@@ -960,7 +1020,7 @@ def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
     def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
         """
         Get the layer indices that should output auxiliary hidden states
-        for EAGLE3.
+        for EAGLE-3.
 
         Returns:
             Tuple of layer indices for auxiliary hidden state outputs.
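Grounded in the interface above: a model opts in by listing SupportsEagle (and/or SupportsEagle3) among its bases, exactly as llama.py and minicpm_eagle.py do below, and the runtime checks then accept both classes and instances. A small illustrative usage, where MyForCausalLM is a made-up class name:

import torch.nn as nn

from vllm.model_executor.models.interfaces import (
    SupportsEagle,
    supports_any_eagle,
    supports_eagle,
)

class MyForCausalLM(nn.Module, SupportsEagle):  # illustrative only
    pass

assert supports_eagle(MyForCausalLM)      # EAGLE-1/2 drafts may attach
assert supports_any_eagle(MyForCausalLM)  # true for any EAGLE variant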

vllm/model_executor/models/llama.py

Lines changed: 4 additions & 2 deletions
@@ -58,7 +58,7 @@
 )
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle, SupportsEagle3, SupportsLoRA, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -529,7 +529,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         return loaded_params
 
 
-class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
+class LlamaForCausalLM(
+    nn.Module, SupportsLoRA, SupportsPP, SupportsEagle, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],

vllm/model_executor/models/llama4_eagle.py

Lines changed: 2 additions & 1 deletion
@@ -35,7 +35,7 @@
 from vllm.model_executor.models.utils import extract_layer_index
 
 from .interfaces import SupportsMultiModal
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
 
 logger = init_logger(__name__)
 
@@ -212,6 +212,7 @@ def transform(inputs):
             name, weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
             if "lm_head" not in name:
                 name = "model." + name
+            process_eagle_weight(self, name)
             return name, weight
 
         loader = AutoWeightsLoader(

vllm/model_executor/models/llama_eagle.py

Lines changed: 2 additions & 1 deletion
@@ -17,7 +17,7 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
 
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
 
 logger = init_logger(__name__)
 
@@ -179,6 +179,7 @@ def transform(inputs):
             name, loaded_weight = inputs
             if "lm_head" not in name:
                 name = "model." + name
+            process_eagle_weight(self, name)
             return name, loaded_weight
 
         loader = AutoWeightsLoader(

vllm/model_executor/models/llama_eagle3.py

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import NestedTensors
 
-from .utils import AutoWeightsLoader, maybe_prefix
+from .utils import AutoWeightsLoader, maybe_prefix, process_eagle_weight
 
 logger = init_logger(__name__)
 
@@ -324,6 +324,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
             if "embed_tokens" in name:
                 includes_embed_tokens = True
             model_weights[name] = loaded_weight
+            process_eagle_weight(self, name)
 
         skip_substrs = []
         if not includes_draft_id_mapping:

vllm/model_executor/models/minicpm_eagle.py

Lines changed: 9 additions & 3 deletions
@@ -43,7 +43,7 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsLoRA, SupportsPP
+from .interfaces import SupportsEagle, SupportsLoRA, SupportsPP
 from .minicpm import MiniCPMAttention as EagleMiniCPMAttention
 from .minicpm import MiniCPMMLP as EagleMiniCPMMLP
 from .minicpm import MiniCPMMoE as EagleMiniCPMMoE
@@ -52,6 +52,7 @@
     is_pp_missing_parameter,
     make_empty_intermediate_tensors_factory,
     maybe_prefix,
+    process_eagle_weight,
 )
 
 
@@ -289,7 +290,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         return loaded_params
 
 
-class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
+class EagleMiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -376,8 +377,13 @@ def compute_logits(
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        def transform(inputs):
+            name, loaded_weight = inputs
+            process_eagle_weight(self, name)
+            return name, loaded_weight
+
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=(["lm_head."] if self.config.tie_word_embeddings else None),
         )
-        return loader.load_weights(weights)
+        return loader.load_weights(map(transform, weights))
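Unlike the Llama and DeepSeek draft models above, EagleMiniCPM had no existing per-weight transform hook, so the commit wraps the incoming iterator with map(). A short sketch of why this pattern stays cheap, assuming weights is a lazy iterator of (name, tensor) pairs as in the other loaders:

# Sketch: map() composes lazily, so each name is inspected exactly when
# AutoWeightsLoader pulls that entry; no extra pass over the checkpoint
# and no intermediate list of tensors is built.
def tag_names(pairs, model):
    for name, tensor in pairs:             # equivalent to map(transform, weights)
        process_eagle_weight(model, name)  # flips has_own_* flags as names stream by
        yield name, tensor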
