Skip to content

Commit 5233423

Browse files
rahul-tuli and southfreebird
authored and committed
Support llama3 eagle3 head with llama4 verifier (vllm-project#25961)
Signed-off-by: rahul-tuli <[email protected]> Signed-off-by: Rahul Tuli <[email protected]>
1 parent 2b21e38 commit 5233423

File tree

5 files changed

+83
-8
lines changed

5 files changed

+83
-8
lines changed

vllm/model_executor/models/llama.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,11 @@ def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
604604
self.model.aux_hidden_state_layers = layers
605605

606606
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
607+
"""Override to return default layers for Llama
608+
609+
Note: The GPU model runner will override this with layers from
610+
the speculative config if available, providing dynamic configuration.
611+
"""
607612
num_layers = len(self.model.layers)
608613
return (2, num_layers // 2, num_layers - 3)
609614

vllm/model_executor/models/llama_eagle3.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
2323
from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
24+
from vllm.multimodal.inputs import NestedTensors
2425

2526
from .utils import AutoWeightsLoader, maybe_prefix
2627

@@ -241,7 +242,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
241242
requires_grad=False,
242243
)
243244

244-
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
245+
def get_input_embeddings(
246+
self,
247+
input_ids: torch.Tensor,
248+
multimodal_embeddings: Optional[NestedTensors] = None,
249+
is_multimodal: Optional[torch.Tensor] = None,
250+
) -> torch.Tensor:
245251
return self.model.get_input_embeddings(input_ids)
246252

247253
def forward(

vllm/model_executor/models/mllama4.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,12 @@
6464
from vllm.sequence import IntermediateTensors
6565
from vllm.utils.tensor_schema import TensorSchema, TensorShape
6666

67-
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
67+
from .interfaces import (
68+
MultiModalEmbeddings,
69+
SupportsEagle3,
70+
SupportsMultiModal,
71+
SupportsPP,
72+
)
6873
from .llama4 import Llama4ForCausalLM
6974
from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
7075
from .vision import run_dp_sharded_vision_model
@@ -717,7 +722,9 @@ def get_dummy_mm_data(
717722
info=Mllama4ProcessingInfo,
718723
dummy_inputs=Mllama4DummyInputsBuilder,
719724
)
720-
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
725+
class Llama4ForConditionalGeneration(
726+
nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
727+
):
721728
packed_modules_mapping = {
722729
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
723730
"gate_up_proj": ["gate_proj", "up_proj"],
@@ -767,6 +774,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
767774
self.language_model.make_empty_intermediate_tensors
768775
)
769776

777+
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
778+
"""Set which layers should output auxiliary hidden states for EAGLE3."""
779+
# Delegate to underlying language model (Llama4ForCausalLM)
780+
assert hasattr(self.language_model, "set_aux_hidden_state_layers")
781+
self.language_model.set_aux_hidden_state_layers(layers)
782+
783+
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
784+
"""Get the layer indices for auxiliary hidden state outputs.
785+
786+
Note: The GPU model runner will override this with layers from
787+
the speculative config if available, providing dynamic configuration.
788+
"""
789+
# Delegate to underlying language model (Llama4ForCausalLM)
790+
assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers")
791+
return self.language_model.get_eagle3_aux_hidden_state_layers()
792+
770793
def _parse_and_validate_image_input(
771794
self, **kwargs: object
772795
) -> Optional[Llama4ImagePatchInputs]:

vllm/transformers_utils/configs/speculators/algos.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,18 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
2121
- draft_vocab_size: Size of the draft model's vocabulary
2222
- target_hidden_size: Hidden size of the target model
2323
- norm_before_residual: Whether to apply norm before residual connection
24+
- eagle_aux_hidden_state_layer_ids: List of layer indices from the base
25+
model to use as auxiliary inputs for the Eagle3 drafter. These layers
26+
provide intermediate hidden states that help the drafter make better
27+
predictions. This is the standard field used in Eagle3 checkpoints.
2428
"""
2529

2630
vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
2731
if config_dict.get("target_hidden_size") is not None:
2832
vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
2933
vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True)
3034
vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
35+
if config_dict.get("eagle_aux_hidden_state_layer_ids"):
36+
vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
37+
"eagle_aux_hidden_state_layer_ids"
38+
]

vllm/v1/worker/gpu_model_runner.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2945,15 +2945,24 @@ def load_model(self, eep_scale_up: bool = False) -> None:
29452945
logger.info("Loading drafter model...")
29462946
self.drafter.load_model(self.model)
29472947
if self.use_aux_hidden_state_outputs:
2948-
if supports_eagle3(self.model):
2949-
self.model.set_aux_hidden_state_layers(
2950-
self.model.get_eagle3_aux_hidden_state_layers()
2951-
)
2952-
else:
2948+
if not supports_eagle3(self.model):
29532949
raise RuntimeError(
29542950
"Model does not support EAGLE3 interface but "
29552951
"aux_hidden_state_outputs was requested"
29562952
)
2953+
2954+
# Try to get auxiliary layers from speculative config,
2955+
# otherwise use model's default layers
2956+
aux_layers = self._get_eagle3_aux_layers_from_config()
2957+
if aux_layers:
2958+
logger.info(
2959+
"Using auxiliary layers from speculative config: %s",
2960+
aux_layers,
2961+
)
2962+
else:
2963+
aux_layers = self.model.get_eagle3_aux_hidden_state_layers()
2964+
2965+
self.model.set_aux_hidden_state_layers(aux_layers)
29572966
time_after_load = time.perf_counter()
29582967
self.model_memory_usage = m.consumed_memory
29592968
logger.info(
@@ -3008,6 +3017,30 @@ def load_model(self, eep_scale_up: bool = False) -> None:
30083017
self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
30093018
)
30103019

3020+
def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
3021+
"""Extract Eagle3 auxiliary layer indices from speculative config.
3022+
3023+
These indices specify which hidden states from the base model should
3024+
be used as auxiliary inputs for the Eagle3 drafter model during
3025+
speculative decoding.
3026+
3027+
Returns:
3028+
Tuple of layer indices if found in draft model config,
3029+
None otherwise.
3030+
"""
3031+
if not (self.speculative_config and self.speculative_config.draft_model_config):
3032+
return None
3033+
3034+
hf_config = self.speculative_config.draft_model_config.hf_config
3035+
if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
3036+
return None
3037+
3038+
layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
3039+
if layer_ids and isinstance(layer_ids, (list, tuple)):
3040+
return tuple(layer_ids)
3041+
3042+
return None
3043+
30113044
def reload_weights(self) -> None:
30123045
assert getattr(self, "model", None) is not None, (
30133046
"Cannot reload weights before model is loaded."

0 commit comments

Comments (0)