@@ -187,15 +187,16 @@ def _batch_complete_text(
         )

         from vllm import RequestOutput, SamplingParams
+        from vllm.inputs import TokensPrompt

-        prompt_token_ids: list[list[int]] = model_inputs.input_ids
+        batch_prompt_token_ids: list[list[int]] = model_inputs.input_ids
         sampling_params: list[SamplingParams] = []
         skip_flag_list: list[bool] = []
         for i, input_ids in enumerate(model_inputs.input_ids):
             remaining = self.model_limit_tokens - len(input_ids)
             instance_gen_kwargs = gen_kwargs.copy()
             if remaining <= 0:
-                prompt_token_ids[i] = input_ids[:1]
+                batch_prompt_token_ids[i] = input_ids[:1]
                 instance_gen_kwargs["max_tokens"] = 1
                 msg = (
                     f"Received input that is longer than `model_limit_tokens = {self.model_limit_tokens}`. "
@@ -208,7 +209,7 @@ def _batch_complete_text(
             skip_flag_list.append(remaining <= 0)

         vllm_outputs: list[RequestOutput] = self.llm.generate(
-            prompt_token_ids=prompt_token_ids,
+            [TokensPrompt(prompt_token_ids=prompt_token_ids) for prompt_token_ids in batch_prompt_token_ids],
             sampling_params=sampling_params,
             use_tqdm=False,
         )
@@ -309,6 +310,7 @@ def _batch_compute_log_probs(
         sequence_length = max([len(input_ids) for input_ids in batch_input_ids])

         from vllm import RequestOutput, SamplingParams
+        from vllm.inputs import TokensPrompt
         from vllm.sequence import Logprob

         sampling_params = SamplingParams(temperature=0.0, max_tokens=1, prompt_logprobs=1)
@@ -323,7 +325,7 @@ def _batch_compute_log_probs(
                 for chunk_input_ids in chunk_batch_input_ids
             ]
             chunk_batch_outputs: list[RequestOutput] = self.llm.generate(
-                prompt_token_ids=chunk_batch_input_ids,
+                [TokensPrompt(prompt_token_ids=prompt_token_ids) for prompt_token_ids in chunk_batch_input_ids],
                 sampling_params=sampling_params,
                 use_tqdm=False,
             )
0 commit comments