|
64 | 64 | from vllm.sequence import IntermediateTensors
|
65 | 65 | from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
66 | 66 |
|
67 |
| -from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP |
| 67 | +from .interfaces import ( |
| 68 | + MultiModalEmbeddings, |
| 69 | + SupportsEagle3, |
| 70 | + SupportsMultiModal, |
| 71 | + SupportsPP, |
| 72 | +) |
68 | 73 | from .llama4 import Llama4ForCausalLM
|
69 | 74 | from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
|
70 | 75 | from .vision import run_dp_sharded_vision_model
|
@@ -717,7 +722,9 @@ def get_dummy_mm_data(
|
717 | 722 | info=Mllama4ProcessingInfo,
|
718 | 723 | dummy_inputs=Mllama4DummyInputsBuilder,
|
719 | 724 | )
|
720 |
| -class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): |
| 725 | +class Llama4ForConditionalGeneration( |
| 726 | + nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3 |
| 727 | +): |
721 | 728 | packed_modules_mapping = {
|
722 | 729 | "qkv_proj": ["q_proj", "k_proj", "v_proj"],
|
723 | 730 | "gate_up_proj": ["gate_proj", "up_proj"],
|
@@ -767,6 +774,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
767 | 774 | self.language_model.make_empty_intermediate_tensors
|
768 | 775 | )
|
769 | 776 |
|
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
    """Set which layers should output auxiliary hidden states for EAGLE3.

    The request is forwarded unchanged to ``self.language_model``.

    Args:
        layers: Layer indices, passed through as-is to the language model.

    Raises:
        AttributeError: If ``self.language_model`` does not provide
            ``set_aux_hidden_state_layers``.
    """
    # Explicit lookup instead of ``assert hasattr(...)``: asserts are removed
    # under ``python -O`` and give no hint about what is missing.
    setter = getattr(self.language_model, "set_aux_hidden_state_layers", None)
    if setter is None:
        raise AttributeError(
            f"{type(self.language_model).__name__} does not implement "
            "set_aux_hidden_state_layers"
        )
    # Delegate to the underlying language model.
    setter(layers)
| 782 | + |
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
    """Get the layer indices for auxiliary hidden state outputs.

    The value is obtained by calling the method of the same name on
    ``self.language_model`` and returned unchanged.

    Note: The GPU model runner will override this with layers from
    the speculative config if available, providing dynamic configuration.

    Returns:
        Whatever ``self.language_model.get_eagle3_aux_hidden_state_layers()``
        returns.

    Raises:
        AttributeError: If ``self.language_model`` does not provide
            ``get_eagle3_aux_hidden_state_layers``.
    """
    # Explicit lookup instead of ``assert hasattr(...)``: asserts are removed
    # under ``python -O`` and give no hint about what is missing.
    getter = getattr(
        self.language_model, "get_eagle3_aux_hidden_state_layers", None
    )
    if getter is None:
        raise AttributeError(
            f"{type(self.language_model).__name__} does not implement "
            "get_eagle3_aux_hidden_state_layers"
        )
    # Delegate to the underlying language model.
    return getter()
| 792 | + |
770 | 793 | def _parse_and_validate_image_input(
|
771 | 794 | self, **kwargs: object
|
772 | 795 | ) -> Optional[Llama4ImagePatchInputs]:
|
|
0 commit comments