
Commit 21467f9

Enable Eagle3 speculative decoding for GPT-OSS model (#25246)
Signed-off-by: Eldar Kurtic <[email protected]>
1 parent f92d952

3 files changed: +41 −12 lines

vllm/config/speculative.py

Lines changed: 1 addition & 1 deletion
@@ -527,7 +527,7 @@ def _verify_args(self) -> Self:
                 "speculative decoding is > 1, but got "
                 f"{self.disable_by_batch_size=}")

-        eagle3_target_supported = ["llama", "qwen"]
+        eagle3_target_supported = ["llama", "qwen", "gpt_oss"]
         if self.method == "eagle3" and self.target_model_config and not any(
                 supported_model in
                 self.target_model_config.hf_text_config.model_type
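With "gpt_oss" added to eagle3_target_supported, the verification step no longer rejects GPT-OSS targets for the "eagle3" method. A minimal sketch of how this might be exercised through vLLM's offline API, assuming the speculative_config dict form; the draft (EAGLE3) head path below is a placeholder, not a real checkpoint:

from vllm import LLM, SamplingParams

# Target model whose hf_text_config.model_type is "gpt_oss"; the draft model
# path is a hypothetical EAGLE3 head trained against this target.
llm = LLM(
    model="openai/gpt-oss-20b",
    speculative_config={
        "method": "eagle3",                       # now accepted for gpt_oss targets
        "model": "path/to/gpt-oss-eagle3-head",   # hypothetical draft checkpoint
        "num_speculative_tokens": 3,
    },
)

out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)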

vllm/model_executor/models/gpt_oss.py

Lines changed: 17 additions & 2 deletions
@@ -27,7 +27,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils import cdiv

-from .interfaces import SupportsPP
+from .interfaces import SupportsEagle3, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
@@ -238,6 +238,7 @@ def __init__(
         self.make_empty_intermediate_tensors = (
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], self.config.hidden_size))
+        self.aux_hidden_state_layers = tuple[int, ...]()

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embedding(input_ids)
@@ -261,15 +262,22 @@ def forward(
             x = intermediate_tensors["hidden_states"]
             residual = intermediate_tensors["residual"]

+        aux_hidden_states = []
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
+            if i in self.aux_hidden_state_layers:
+                aux_hidden_states.append(x if residual is None else x +
+                                         residual)
             x, residual = layer(x, positions, residual)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": x,
                 "residual": residual
             })
         x, _ = self.norm(x, residual)
+
+        if len(aux_hidden_states) > 0:
+            return x, aux_hidden_states
         return x

     def _load_weights_mxfp4(
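A standalone toy (not vLLM code) of the pattern the forward() change above introduces: while iterating over the decoder layers, snapshot the running hidden state (with the residual added back in) at a chosen set of layer indices, so the EAGLE3 head can later consume early, middle, and late features alongside the final output. The layer count and tensor shapes here are placeholders.

import torch

num_layers = 24                        # placeholder layer count
aux_layers = (2, num_layers // 2, num_layers - 3)

x = torch.randn(1, 4, 8)               # (batch, seq, hidden) placeholder
residual = None
aux_hidden_states = []
for i in range(num_layers):
    if i in aux_layers:
        # Tap the full hidden state; re-add the residual when it is tracked
        # separately, as in the decoder loop in the diff above.
        aux_hidden_states.append(x if residual is None else x + residual)
    residual = x                        # stand-in for the layer's residual stream
    x = torch.tanh(x)                   # stand-in for layer(x, positions, residual)

print(len(aux_hidden_states))           # 3 tapped feature tensors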
@@ -610,7 +618,7 @@ def load_weights(self, weights: Iterable[tuple[str,
                              weights, stacked_params_mapping)


-class GptOssForCausalLM(nn.Module, SupportsPP):
+class GptOssForCausalLM(nn.Module, SupportsPP, SupportsEagle3):
     packed_modules_mapping = {"qkv": ["q_proj", "k_proj", "v_proj"]}

     hf_to_vllm_mapper = WeightsMapper(
@@ -658,6 +666,13 @@ def __init__(
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)

+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
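Sketch of how the two new SupportsEagle3 hooks are meant to fit together (an assumption about the caller, not code from this commit): the speculative decoder asks the target model which layers to tap by default, then registers them, after which the model's forward() also returns those auxiliary features. For a hypothetical 24-layer GPT-OSS config the defaults work out to an early, a middle, and a late layer.

num_layers = 24                                    # hypothetical GPT-OSS depth
default_layers = (2, num_layers // 2, num_layers - 3)
print(default_layers)                              # (2, 12, 21)

# Expected wiring on the caller side (method names as defined in the diff above):
# target_model.set_aux_hidden_state_layers(
#     target_model.get_eagle3_aux_hidden_state_layers())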

vllm/v1/spec_decode/eagle.py

Lines changed: 23 additions & 9 deletions
@@ -823,15 +823,29 @@ def load_model(self, target_model: nn.Module) -> None:
         else:
             target_language_model = target_model
         # share embed_tokens with the target model if needed
-        if get_pp_group().world_size == 1 \
-                and self.model.model.embed_tokens.weight.shape \
-                == target_language_model.model.embed_tokens.weight.shape:
-            logger.info(
-                "Assuming the EAGLE head shares the same vocab embedding"
-                " with the target model.")
-            del self.model.model.embed_tokens
-            self.model.model.embed_tokens = (
-                target_language_model.model.embed_tokens)
+        if get_pp_group().world_size == 1:
+            if hasattr(target_language_model.model, 'embed_tokens'):
+                target_embed_tokens = target_language_model.model.embed_tokens
+            elif hasattr(target_language_model.model, 'embedding'):
+                target_embed_tokens = target_language_model.model.embedding
+            else:
+                raise AttributeError(
+                    "Target model does not have 'embed_tokens' or 'embedding' "
+                    "attribute")
+
+            # Check if shapes match and we found the embedding
+            eagle_shape = self.model.model.embed_tokens.weight.shape
+            target_shape = target_embed_tokens.weight.shape
+            if eagle_shape == target_shape:
+                logger.info(
+                    "Assuming the EAGLE head shares the same vocab embedding"
+                    " with the target model.")
+                del self.model.model.embed_tokens
+                self.model.model.embed_tokens = target_embed_tokens
+            else:
+                logger.info(
+                    "The EAGLE head's vocab embedding will be loaded separately"
+                    " from the target model.")
         else:
             logger.info(
                 "The EAGLE head's vocab embedding will be loaded separately"