Skip to content

Commit 05f6846

Browse files
authored
Support llama3 eagle3 head with llama4 verifier (#25961)
Signed-off-by: rahul-tuli <[email protected]> Signed-off-by: Rahul Tuli <[email protected]>
1 parent 20db99c commit 05f6846

File tree

5 files changed

+83
-8
lines changed

5 files changed

+83
-8
lines changed

vllm/model_executor/models/llama.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,11 @@ def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
         self.model.aux_hidden_state_layers = layers
 
     def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        """Override to return default layers for Llama
+
+        Note: The GPU model runner will override this with layers from
+        the speculative config if available, providing dynamic configuration.
+        """
         num_layers = len(self.model.layers)
         return (2, num_layers // 2, num_layers - 3)

vllm/model_executor/models/llama_eagle3.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
+from vllm.multimodal.inputs import NestedTensors
 
 from .utils import AutoWeightsLoader, maybe_prefix
 
@@ -241,7 +242,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             requires_grad=False,
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+        is_multimodal: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
     def forward(

vllm/model_executor/models/mllama4.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,12 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import (
+    MultiModalEmbeddings,
+    SupportsEagle3,
+    SupportsMultiModal,
+    SupportsPP,
+)
 from .llama4 import Llama4ForCausalLM
 from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
 from .vision import run_dp_sharded_vision_model
@@ -717,7 +722,9 @@ def get_dummy_mm_data(
     info=Mllama4ProcessingInfo,
     dummy_inputs=Mllama4DummyInputsBuilder,
 )
-class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+class Llama4ForConditionalGeneration(
+    nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -767,6 +774,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.language_model.make_empty_intermediate_tensors
         )
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        """Set which layers should output auxiliary hidden states for EAGLE3."""
+        # Delegate to underlying language model (Llama4ForCausalLM)
+        assert hasattr(self.language_model, "set_aux_hidden_state_layers")
+        self.language_model.set_aux_hidden_state_layers(layers)
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        """Get the layer indices for auxiliary hidden state outputs.
+
+        Note: The GPU model runner will override this with layers from
+        the speculative config if available, providing dynamic configuration.
+        """
+        # Delegate to underlying language model (Llama4ForCausalLM)
+        assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers")
+        return self.language_model.get_eagle3_aux_hidden_state_layers()
+
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> Optional[Llama4ImagePatchInputs]:

vllm/transformers_utils/configs/speculators/algos.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,18 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
     - draft_vocab_size: Size of the draft model's vocabulary
     - target_hidden_size: Hidden size of the target model
     - norm_before_residual: Whether to apply norm before residual connection
+    - eagle_aux_hidden_state_layer_ids: List of layer indices from the base
+      model to use as auxiliary inputs for the Eagle3 drafter. These layers
+      provide intermediate hidden states that help the drafter make better
+      predictions. This is the standard field used in Eagle3 checkpoints.
     """
 
     vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
     if config_dict.get("target_hidden_size") is not None:
         vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
     vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True)
     vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
+    if config_dict.get("eagle_aux_hidden_state_layer_ids"):
+        vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
+            "eagle_aux_hidden_state_layer_ids"
+        ]

vllm/v1/worker/gpu_model_runner.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2943,15 +2943,24 @@ def load_model(self, eep_scale_up: bool = False) -> None:
                 logger.info("Loading drafter model...")
                 self.drafter.load_model(self.model)
             if self.use_aux_hidden_state_outputs:
-                if supports_eagle3(self.model):
-                    self.model.set_aux_hidden_state_layers(
-                        self.model.get_eagle3_aux_hidden_state_layers()
-                    )
-                else:
+                if not supports_eagle3(self.model):
                     raise RuntimeError(
                         "Model does not support EAGLE3 interface but "
                         "aux_hidden_state_outputs was requested"
                     )
+
+                # Try to get auxiliary layers from speculative config,
+                # otherwise use model's default layers
+                aux_layers = self._get_eagle3_aux_layers_from_config()
+                if aux_layers:
+                    logger.info(
+                        "Using auxiliary layers from speculative config: %s",
+                        aux_layers,
+                    )
+                else:
+                    aux_layers = self.model.get_eagle3_aux_hidden_state_layers()
+
+                self.model.set_aux_hidden_state_layers(aux_layers)
         time_after_load = time.perf_counter()
         self.model_memory_usage = m.consumed_memory
         logger.info(
@@ -3006,6 +3015,30 @@ def load_model(self, eep_scale_up: bool = False) -> None:
                 self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
             )
 
+    def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
+        """Extract Eagle3 auxiliary layer indices from speculative config.
+
+        These indices specify which hidden states from the base model should
+        be used as auxiliary inputs for the Eagle3 drafter model during
+        speculative decoding.
+
+        Returns:
+            Tuple of layer indices if found in draft model config,
+            None otherwise.
+        """
+        if not (self.speculative_config and self.speculative_config.draft_model_config):
+            return None
+
+        hf_config = self.speculative_config.draft_model_config.hf_config
+        if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
+            return None
+
+        layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
+        if layer_ids and isinstance(layer_ids, (list, tuple)):
+            return tuple(layer_ids)
+
+        return None
+
     def reload_weights(self) -> None:
         assert getattr(self, "model", None) is not None, (
             "Cannot reload weights before model is loaded."

0 commit comments

Comments
 (0)