from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv

-from .interfaces import SupportsPP
+from .interfaces import SupportsEagle3, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
@@ -238,6 +238,7 @@ def __init__(
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], self.config.hidden_size))
+        self.aux_hidden_state_layers = tuple[int, ...]()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embedding(input_ids)
@@ -261,15 +262,22 @@ def forward(
            x = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

+        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            layer = self.layers[i]
+            if i in self.aux_hidden_state_layers:
+                aux_hidden_states.append(x if residual is None else x +
+                                         residual)
            x, residual = layer(x, positions, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": x,
                "residual": residual
            })
        x, _ = self.norm(x, residual)
+
+        if len(aux_hidden_states) > 0:
+            return x, aux_hidden_states
        return x

    def _load_weights_mxfp4(
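The hunk above records, for each selected layer index, the hidden state with the pending residual folded back in, just before that layer runs. A minimal, self-contained sketch of the same pattern follows; it uses a toy decoder with illustrative names only, not the real GptOssModel (whose layers also take positions and fuse the residual add into their norms).

# Self-contained sketch of the aux-hidden-state capture pattern above
# (toy decoder, illustrative names only; not the actual vLLM module).
import torch
import torch.nn as nn


class ToyDecoder(nn.Module):

    def __init__(self, hidden_size: int = 16, num_layers: int = 6):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        self.norm = nn.LayerNorm(hidden_size)
        # Empty by default, so plain decoding keeps returning a single tensor.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def forward(self, x: torch.Tensor):
        residual = None
        aux_hidden_states = []
        for i, layer in enumerate(self.layers):
            if i in self.aux_hidden_state_layers:
                # Record the layer input with the pending residual applied,
                # mirroring `x if residual is None else x + residual` in the diff.
                aux_hidden_states.append(x if residual is None else x + residual)
            residual = x if residual is None else x + residual
            x = layer(residual)
        x = self.norm(x if residual is None else x + residual)
        if len(aux_hidden_states) > 0:
            return x, aux_hidden_states
        return x


model = ToyDecoder()
model.aux_hidden_state_layers = (2, 3, 5)  # layer indices chosen by the drafter setup
out, aux = model(torch.randn(1, 16))
print(out.shape, [a.shape for a in aux])  # torch.Size([1, 16]) plus three aux tensors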
@@ -610,7 +618,7 @@ def load_weights(self, weights: Iterable[tuple[str,
            weights, stacked_params_mapping)


-class GptOssForCausalLM(nn.Module, SupportsPP):
+class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
    packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}

    hf_to_vllm_mapper = WeightsMapper(
@@ -658,6 +666,13 @@ def __init__(
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)
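Assuming a loaded GptOssForCausalLM instance named `model`, and that vLLM's EAGLE3 speculative-decoding runner (outside this diff) drives the calls, the new hooks would be exercised roughly as sketched below; the layer depth is an example value for illustration only.

# Rough wiring sketch (hypothetical caller; the real call sites live in vLLM's
# EAGLE3 speculative-decoding runner and are not part of this diff):
#
#   model.set_aux_hidden_state_layers(model.get_eagle3_aux_hidden_state_layers())
#   hidden_states, aux_hidden_states = model.model(input_ids, positions, ...)
#
# The default selection picks an early, a middle, and a late decoder layer:
num_layers = 36  # example target-model depth, purely for illustration
print((2, num_layers // 2, num_layers - 3))  # (2, 18, 33)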