Skip to content

Commit c438b29

Browse files
rahul-tuliclaude
andauthored
feat: Enable engine-level arguments with speculators models (#25250)
Signed-off-by: Rahul Tuli <[email protected]> Co-authored-by: Claude <[email protected]>
1 parent 0ff8ebb commit c438b29

File tree

5 files changed

+121
-78
lines changed

5 files changed

+121
-78
lines changed

tests/speculative_decoding/speculators/test_eagle3.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,38 +3,52 @@
33
import pytest
44
import torch
55

6+
from vllm.config import SpeculativeConfig
67
from vllm.model_executor.models.interfaces import supports_eagle3
78

89

9-
@pytest.mark.parametrize(
10-
"model_path",
11-
[("nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized")])
12-
def test_llama(vllm_runner, example_prompts, model_path, monkeypatch):
10+
@pytest.mark.parametrize("model_path", [
11+
pytest.param(
12+
"nm-testing/SpeculatorLlama3-1-8B-Eagle3-converted-0717-quantized",
13+
id="llama3-eagle3-speculator"),
14+
pytest.param(
15+
"nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized",
16+
id="qwen3-eagle3-speculator"),
17+
])
18+
def test_eagle3_speculators_model(vllm_runner, example_prompts, model_path,
19+
monkeypatch):
20+
"""
21+
Test Eagle3 speculators models properly initialize speculative decoding.
22+
23+
This test verifies:
24+
1. Eagle3 support is detected for the model
25+
2. Speculative config is automatically initialized from embedded config
26+
3. The draft model path is correctly set to the speculators model
27+
4. Speculative tokens count is valid
28+
5. Text generation works with speculative decoding enabled
29+
"""
1330
# Set environment variable for V1 engine serialization
1431
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
1532

1633
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
34+
# Verify Eagle3 support is detected
1735
eagle3_supported = vllm_model.apply_model(supports_eagle3)
18-
assert eagle3_supported
36+
assert eagle3_supported, f"Eagle3 should be supported for {model_path}"
1937

20-
vllm_outputs = vllm_model.generate_greedy(example_prompts,
21-
max_tokens=20)
22-
print(vllm_outputs)
23-
assert vllm_outputs
38+
vllm_config = vllm_model.llm.llm_engine.vllm_config
2439

40+
assert isinstance(vllm_config.speculative_config, SpeculativeConfig), \
41+
"Speculative config should be initialized for speculators model"
2542

26-
@pytest.mark.parametrize(
27-
"model_path",
28-
[("nm-testing/Speculator-Qwen3-8B-Eagle3-converted-071-quantized")])
29-
def test_qwen(vllm_runner, example_prompts, model_path, monkeypatch):
30-
# Set environment variable for V1 engine serialization
31-
monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
43+
spec_config = vllm_config.speculative_config
44+
assert spec_config.num_speculative_tokens > 0, \
45+
(f"Expected positive speculative tokens, "
46+
f"got {spec_config.num_speculative_tokens}")
3247

33-
with vllm_runner(model_path, dtype=torch.bfloat16) as vllm_model:
34-
eagle3_supported = vllm_model.apply_model(supports_eagle3)
35-
assert eagle3_supported
48+
assert spec_config.model == model_path, \
49+
f"Draft model should be {model_path}, got {spec_config.model}"
3650

3751
vllm_outputs = vllm_model.generate_greedy(example_prompts,
3852
max_tokens=20)
39-
print(vllm_outputs)
40-
assert vllm_outputs
53+
assert vllm_outputs, \
54+
f"No outputs generated for speculators model {model_path}"

vllm/config/model.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
ConfigFormat, get_config, get_hf_image_processor_config,
2828
get_hf_text_config, get_pooling_config,
2929
get_sentence_transformer_tokenizer_config, is_encoder_decoder,
30-
is_interleaved, maybe_override_with_speculators_target_model,
31-
try_get_generation_config, try_get_safetensors_metadata,
30+
is_interleaved, try_get_generation_config, try_get_safetensors_metadata,
3231
try_get_tokenizer_config, uses_mrope)
3332
from vllm.transformers_utils.runai_utils import (ObjectStorageModel,
3433
is_runai_obj_uri)
@@ -416,15 +415,6 @@ def __post_init__(
416415

417416
self.maybe_pull_model_tokenizer_for_runai(self.model, self.tokenizer)
418417

419-
if self.runner != "draft":
420-
# If we're not running the draft model, check for speculators config
421-
# If speculators config, set model / tokenizer to be target model
422-
self.model, self.tokenizer = maybe_override_with_speculators_target_model( # noqa: E501
423-
model=self.model,
424-
tokenizer=self.tokenizer,
425-
revision=self.revision,
426-
trust_remote_code=self.trust_remote_code)
427-
428418
if (backend := envs.VLLM_ATTENTION_BACKEND
429419
) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
430420
raise ValueError(

vllm/engine/arg_utils.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
from vllm.ray.lazy_utils import is_ray_initialized
4242
from vllm.reasoning import ReasoningParserManager
4343
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
44-
from vllm.transformers_utils.config import get_model_path, is_interleaved
44+
from vllm.transformers_utils.config import (get_model_path, is_interleaved,
45+
maybe_override_with_speculators)
4546
from vllm.transformers_utils.utils import check_gguf_file
4647
from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
4748
GiB_bytes, get_ip, is_in_ray_actor)
@@ -1082,29 +1083,8 @@ def create_speculative_config(
10821083
provided as a JSON string input via CLI arguments or directly as a
10831084
dictionary from the engine.
10841085
"""
1085-
1086-
from vllm.transformers_utils.config import get_config
1087-
from vllm.transformers_utils.configs.speculators.base import (
1088-
SpeculatorsConfig)
1089-
10901086
if self.speculative_config is None:
1091-
hf_config = get_config(
1092-
self.hf_config_path or target_model_config.model,
1093-
self.trust_remote_code, self.revision, self.code_revision,
1094-
self.config_format)
1095-
1096-
# if loading a SpeculatorsConfig, load the speculative_config
1097-
# details from the config directly
1098-
# no user input required / expected
1099-
if isinstance(hf_config, SpeculatorsConfig):
1100-
# We create one since we don't create one
1101-
self.speculative_config = {}
1102-
self.speculative_config[
1103-
"num_speculative_tokens"] = hf_config.num_lookahead_tokens
1104-
self.speculative_config["model"] = target_model_config.model
1105-
self.speculative_config["method"] = hf_config.method
1106-
else:
1107-
return None
1087+
return None
11081088

11091089
# Note(Shangming): These parameters are not obtained from the cli arg
11101090
# '--speculative-config' and must be passed in when creating the engine
@@ -1139,6 +1119,15 @@ def create_engine_config(
11391119

11401120
device_config = DeviceConfig(
11411121
device=cast(Device, current_platform.device_type))
1122+
1123+
(self.model, self.tokenizer,
1124+
self.speculative_config) = maybe_override_with_speculators(
1125+
model=self.model,
1126+
tokenizer=self.tokenizer,
1127+
revision=self.revision,
1128+
trust_remote_code=self.trust_remote_code,
1129+
vllm_speculative_config=self.speculative_config,
1130+
)
11421131
model_config = self.create_model_config()
11431132

11441133
# * If VLLM_USE_V1 is unset, we enable V1 for "supported features"

vllm/transformers_utils/config.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -463,15 +463,29 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
463463
return config
464464

465465

466-
def maybe_override_with_speculators_target_model(
466+
def maybe_override_with_speculators(
467467
model: str,
468468
tokenizer: str,
469469
trust_remote_code: bool,
470470
revision: Optional[str] = None,
471+
vllm_speculative_config: Optional[dict[str, Any]] = None,
471472
**kwargs,
472-
) -> tuple[str, str]:
473+
) -> tuple[str, str, Optional[dict[str, Any]]]:
473474
"""
474-
If running a speculators config, override running model with target model
475+
Resolve model configuration when speculators are detected.
476+
477+
Checks if the provided model is a speculators model and if so, extracts
478+
the target model configuration and builds the speculative config.
479+
480+
Args:
481+
model: Model name or path
482+
tokenizer: Tokenizer name or path
483+
trust_remote_code: Whether to trust remote code
484+
revision: Model revision
485+
vllm_speculative_config: Existing vLLM speculative config
486+
487+
Returns:
488+
Tuple of (resolved_model, resolved_tokenizer, speculative_config)
475489
"""
476490
is_gguf = check_gguf_file(model)
477491
if is_gguf:
@@ -487,11 +501,27 @@ def maybe_override_with_speculators_target_model(
487501
token=_get_hf_token(),
488502
**kwargs,
489503
)
490-
spec_config = config_dict.get("speculators_config", None)
491-
# Return the target model
492-
if spec_config is not None:
493-
model = tokenizer = spec_config["verifier"]["name_or_path"]
494-
return model, tokenizer
504+
speculators_config = config_dict.get("speculators_config")
505+
506+
if speculators_config is None:
507+
# No speculators config found, return original values
508+
return model, tokenizer, vllm_speculative_config
509+
510+
# Speculators format detected - process overrides
511+
from vllm.transformers_utils.configs.speculators.base import (
512+
SpeculatorsConfig)
513+
514+
vllm_speculative_config = SpeculatorsConfig.extract_vllm_speculative_config(
515+
config_dict=config_dict)
516+
517+
# Set the draft model to the speculators model
518+
vllm_speculative_config["model"] = model
519+
520+
# Override model and tokenizer with the verifier model from config
521+
verifier_model = speculators_config["verifier"]["name_or_path"]
522+
model = tokenizer = verifier_model
523+
524+
return model, tokenizer, vllm_speculative_config
495525

496526

497527
def get_config(

vllm/transformers_utils/configs/speculators/base.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ def from_pretrained(
2424
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
2525
**kwargs)
2626

27+
vllm_config = cls.extract_vllm_speculative_config(config_dict)
28+
return cls(**vllm_config)
29+
30+
@classmethod
31+
def extract_vllm_speculative_config(
32+
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
2733
speculators_model_type = config_dict.get("speculators_model_type")
2834
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
2935
raise ValueError(
@@ -34,11 +40,12 @@ def from_pretrained(
3440
# TODO: @dsikka - use speculators pydantic model to validate
3541
cls.validate_speculators_config(config_dict=config_dict)
3642
# Convert from speculators config -> format that can be ingested by vLLM
37-
vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
43+
vllm_config = cls.build_vllm_speculative_config(
44+
config_dict=config_dict)
3845
# Apply anything specific to the supported algorithm
3946
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
4047
algo_updater(config_dict=config_dict, vllm_config=vllm_config)
41-
return cls(**vllm_config)
48+
return vllm_config
4249

4350
@classmethod
4451
def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
@@ -60,32 +67,45 @@ def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
6067
"'transformer_layer_config' must be a dictionary if provided")
6168

6269
@classmethod
63-
def convert_speculators_to_vllm(
70+
def build_vllm_speculative_config(
6471
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
6572
"""
66-
Convert speculators config format to vLLM format.
67-
68-
This method handles the translation of field names and structure
69-
between speculators and vLLM formats.
70-
73+
Build vLLM-compatible speculative configuration from speculators format.
74+
75+
This method extracts and transforms speculative configuration from the
76+
speculators format into the structure expected by vLLM.
77+
78+
Args:
79+
config_dict: Configuration dictionary in speculators format
80+
7181
Returns:
72-
Dictionary with vLLM-compatible configuration
82+
Dictionary with vLLM-compatible speculative configuration
7383
"""
74-
# Currently we only support one proposal method
84+
# Extract speculators configuration
7585
spec_config = config_dict["speculators_config"]
76-
first_method = spec_config.get("proposal_methods")[0]
77-
num_lookahead_tokens = first_method.get("speculative_tokens")
7886

79-
if num_lookahead_tokens is None:
87+
# Currently we only support one proposal method
88+
proposal_methods = spec_config.get("proposal_methods")
89+
if not proposal_methods:
90+
raise ValueError("No proposal methods found in speculators config")
91+
92+
first_method = proposal_methods[0]
93+
num_speculative_tokens = first_method.get("speculative_tokens")
94+
95+
if num_speculative_tokens is None:
8096
raise ValueError(
8197
"Missing 'speculative_tokens' in proposal method. "
8298
f"Got: {first_method}")
8399

84-
# Build base vLLM config
100+
# Build base vLLM speculative configuration
85101
vllm_config = {
86102
"method": config_dict.get("speculators_model_type"),
87-
"num_lookahead_tokens": num_lookahead_tokens,
103+
"num_speculative_tokens": num_speculative_tokens,
88104
"target_model": spec_config.get("verifier")["name_or_path"]
89105
}
90-
vllm_config.update(config_dict["transformer_layer_config"])
106+
107+
# Merge transformer layer configuration if present
108+
transformer_config = config_dict.get("transformer_layer_config", {})
109+
vllm_config.update(transformer_config)
110+
91111
return vllm_config

0 commit comments

Comments
 (0)