@@ -24,6 +24,12 @@ def from_pretrained(
24
24
config_dict , _ = cls .get_config_dict (pretrained_model_name_or_path ,
25
25
** kwargs )
26
26
27
+ vllm_config = cls .extract_vllm_speculative_config (config_dict )
28
+ return cls (** vllm_config )
29
+
30
+ @classmethod
31
+ def extract_vllm_speculative_config (
32
+ cls , config_dict : dict [str , Any ]) -> dict [str , Any ]:
27
33
speculators_model_type = config_dict .get ("speculators_model_type" )
28
34
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES :
29
35
raise ValueError (
@@ -34,11 +40,12 @@ def from_pretrained(
34
40
# TODO: @dsikka - use speculators pydantic model to validate
35
41
cls .validate_speculators_config (config_dict = config_dict )
36
42
# Convert from speculators config -> format that can be ingested by vLLM
37
- vllm_config = cls .convert_speculators_to_vllm (config_dict = config_dict )
43
+ vllm_config = cls .build_vllm_speculative_config (
44
+ config_dict = config_dict )
38
45
# Apply anything specific to the supported algorithm
39
46
algo_updater = SUPPORTED_SPECULATORS_TYPES [speculators_model_type ]
40
47
algo_updater (config_dict = config_dict , vllm_config = vllm_config )
41
- return cls ( ** vllm_config )
48
+ return vllm_config
42
49
43
50
@classmethod
44
51
def validate_speculators_config (cls , config_dict : dict [str , Any ]) -> None :
@@ -60,32 +67,45 @@ def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
60
67
"'transformer_layer_config' must be a dictionary if provided" )
61
68
62
69
@classmethod
63
- def convert_speculators_to_vllm (
70
+ def build_vllm_speculative_config (
64
71
cls , config_dict : dict [str , Any ]) -> dict [str , Any ]:
65
72
"""
66
- Convert speculators config format to vLLM format.
67
-
68
- This method handles the translation of field names and structure
69
- between speculators and vLLM formats.
70
-
73
+ Build vLLM-compatible speculative configuration from speculators format.
74
+
75
+ This method extracts and transforms speculative configuration from the
76
+ speculators format into the structure expected by vLLM.
77
+
78
+ Args:
79
+ config_dict: Configuration dictionary in speculators format
80
+
71
81
Returns:
72
- Dictionary with vLLM-compatible configuration
82
+ Dictionary with vLLM-compatible speculative configuration
73
83
"""
74
- # Currently we only support one proposal method
84
+ # Extract speculators configuration
75
85
spec_config = config_dict ["speculators_config" ]
76
- first_method = spec_config .get ("proposal_methods" )[0 ]
77
- num_lookahead_tokens = first_method .get ("speculative_tokens" )
78
86
79
- if num_lookahead_tokens is None :
87
+ # Currently we only support one proposal method
88
+ proposal_methods = spec_config .get ("proposal_methods" )
89
+ if not proposal_methods :
90
+ raise ValueError ("No proposal methods found in speculators config" )
91
+
92
+ first_method = proposal_methods [0 ]
93
+ num_speculative_tokens = first_method .get ("speculative_tokens" )
94
+
95
+ if num_speculative_tokens is None :
80
96
raise ValueError (
81
97
"Missing 'speculative_tokens' in proposal method. "
82
98
f"Got: { first_method } " )
83
99
84
- # Build base vLLM config
100
+ # Build base vLLM speculative configuration
85
101
vllm_config = {
86
102
"method" : config_dict .get ("speculators_model_type" ),
87
- "num_lookahead_tokens " : num_lookahead_tokens ,
103
+ "num_speculative_tokens " : num_speculative_tokens ,
88
104
"target_model" : spec_config .get ("verifier" )["name_or_path" ]
89
105
}
90
- vllm_config .update (config_dict ["transformer_layer_config" ])
106
+
107
+ # Merge transformer layer configuration if present
108
+ transformer_config = config_dict .get ("transformer_layer_config" , {})
109
+ vllm_config .update (transformer_config )
110
+
91
111
return vllm_config
0 commit comments