Skip to content

Commit b1c9aa3

Browse files
authored
[Bugfix] [SpecDecode] Default speculative_draft_tensor_parallel_size to 1 when using MLPSpeculator (#7105)
Signed-off-by: Thomas Parnell <[email protected]>
1 parent 179a6a3 commit b1c9aa3

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

vllm/config.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,7 +1068,7 @@ def maybe_create_spec_config(
10681068
draft_parallel_config = (
10691069
SpeculativeConfig.create_draft_parallel_config(
10701070
target_parallel_config,
1071-
speculative_draft_tensor_parallel_size))
1071+
speculative_draft_tensor_parallel_size, draft_hf_config))
10721072

10731073
if num_speculative_tokens is None:
10741074
raise ValueError(
@@ -1136,15 +1136,23 @@ def _maybe_override_draft_max_model_len(
11361136
@staticmethod
11371137
def create_draft_parallel_config(
11381138
target_parallel_config: ParallelConfig,
1139-
speculative_draft_tensor_parallel_size: Optional[int]
1139+
speculative_draft_tensor_parallel_size: Optional[int],
1140+
draft_hf_config: PretrainedConfig,
11401141
) -> ParallelConfig:
11411142
"""Create a parallel config for use by the draft worker.
11421143
11431144
This is mostly a copy of the target parallel config, except the tp_size.
11441145
"""
11451146
if speculative_draft_tensor_parallel_size is None:
1146-
speculative_draft_tensor_parallel_size = \
1147-
target_parallel_config.tensor_parallel_size
1147+
if draft_hf_config.model_type == "mlp_speculator":
1148+
speculative_draft_tensor_parallel_size = 1
1149+
if target_parallel_config.tensor_parallel_size > 1:
1150+
logger.warning(
1151+
"MLPSpeculator cannot currently be run with tp>1; "
1152+
"setting speculative_draft_tensor_parallel_size=1")
1153+
else:
1154+
speculative_draft_tensor_parallel_size = \
1155+
target_parallel_config.tensor_parallel_size
11481156
elif speculative_draft_tensor_parallel_size != 1:
11491157
# TODO(wooyeon): allow tp values larger than 1
11501158
raise ValueError(

0 commit comments

Comments
 (0)