@@ -1068,7 +1068,7 @@ def maybe_create_spec_config(
             draft_parallel_config = (
                 SpeculativeConfig.create_draft_parallel_config(
                     target_parallel_config,
-                    speculative_draft_tensor_parallel_size))
+                    speculative_draft_tensor_parallel_size, draft_hf_config))
 
         if num_speculative_tokens is None:
             raise ValueError(
@@ -1136,15 +1136,23 @@ def _maybe_override_draft_max_model_len(
     @staticmethod
     def create_draft_parallel_config(
         target_parallel_config: ParallelConfig,
-        speculative_draft_tensor_parallel_size: Optional[int]
+        speculative_draft_tensor_parallel_size: Optional[int],
+        draft_hf_config: PretrainedConfig,
     ) -> ParallelConfig:
         """Create a parallel config for use by the draft worker.
 
         This is mostly a copy of the target parallel config, except the tp_size.
         """
         if speculative_draft_tensor_parallel_size is None:
-            speculative_draft_tensor_parallel_size = \
-                target_parallel_config.tensor_parallel_size
+            if draft_hf_config.model_type == "mlp_speculator":
+                speculative_draft_tensor_parallel_size = 1
+                if target_parallel_config.tensor_parallel_size > 1:
+                    logger.warning(
+                        "MLPSpeculator cannot currently be run with tp>1; "
+                        "setting speculative_draft_tensor_parallel_size=1")
+            else:
+                speculative_draft_tensor_parallel_size = \
+                    target_parallel_config.tensor_parallel_size
         elif speculative_draft_tensor_parallel_size != 1:
             # TODO(wooyeon): allow tp values larger than 1
             raise ValueError(
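For reference, here is a minimal, self-contained sketch of the fallback behavior this diff introduces. The SimpleNamespace stand-ins for the target ParallelConfig and the draft's Hugging Face config, and the helper name pick_draft_tp, are illustrative assumptions rather than vLLM APIs; only the branching mirrors the hunk above.

# Illustrative sketch only: SimpleNamespace stands in for the real
# ParallelConfig / PretrainedConfig objects, and pick_draft_tp is a
# hypothetical helper that mirrors the branch added in this diff.
import logging
from types import SimpleNamespace
from typing import Optional

logger = logging.getLogger(__name__)

def pick_draft_tp(target_parallel_config,
                  speculative_draft_tensor_parallel_size: Optional[int],
                  draft_hf_config) -> int:
    if speculative_draft_tensor_parallel_size is None:
        if draft_hf_config.model_type == "mlp_speculator":
            # MLPSpeculator drafts are forced to tp=1 for now.
            speculative_draft_tensor_parallel_size = 1
            if target_parallel_config.tensor_parallel_size > 1:
                logger.warning(
                    "MLPSpeculator cannot currently be run with tp>1; "
                    "setting speculative_draft_tensor_parallel_size=1")
        else:
            # Any other draft model inherits the target's tp size.
            speculative_draft_tensor_parallel_size = (
                target_parallel_config.tensor_parallel_size)
    elif speculative_draft_tensor_parallel_size != 1:
        raise ValueError("Only draft tp=1 is currently supported.")
    return speculative_draft_tensor_parallel_size

# A tp=4 target with an MLPSpeculator draft falls back to draft tp=1,
# while any other draft model keeps the target's tp=4.
target = SimpleNamespace(tensor_parallel_size=4)
assert pick_draft_tp(target, None, SimpleNamespace(model_type="mlp_speculator")) == 1
assert pick_draft_tp(target, None, SimpleNamespace(model_type="llama")) == 4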