@@ -185,7 +185,7 @@ def __repr__(self) -> str:
 
     outputs = await generate_with_vllm(
         llm,
-        content.max_new_tokens,
+        [generations[iter[1]].remaining_tokens for iter in iter_prompts],
         content.temperature,
         content.stop_sequences,
         content.return_token_log_probs,
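This hunk swaps the single scalar budget for one budget per in-flight prompt, read from each generation item's `remaining_tokens`. A minimal sketch of the bookkeeping, assuming a hypothetical `GenerationItem` with the fields the diff touches and `iter_prompts` as (prompt, index) pairs:

```python
from dataclasses import dataclass

@dataclass
class GenerationItem:
    generated_text: str = ""
    remaining_tokens: int = 100  # token budget left for this item
    completed: bool = False

generations = [GenerationItem(remaining_tokens=64), GenerationItem(remaining_tokens=8)]
iter_prompts = [("prompt A", 0), ("prompt B", 1)]  # assumed (prompt, index) shape

# Same expression as the new diff line; `it` is used here because the
# diff's loop variable `iter` shadows the builtin.
max_new_tokens = [generations[it[1]].remaining_tokens for it in iter_prompts]
assert max_new_tokens == [64, 8]
```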
@@ -260,7 +260,10 @@ def tool_func(text: str, past_context: Optional[str]):
             gen_item.generated_text += new_text
 
             # If we didn't just execute a tool, we're done
-            if not gen_item.generated_text.endswith(tool.tool_context_end):
+            if (
+                not gen_item.generated_text.endswith(tool.tool_context_end)
+                or gen_item.remaining_tokens <= 0
+            ):
                 gen_item.completed = True
                 continue
 
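The added `or gen_item.remaining_tokens <= 0` guard ends the tool loop for an item once its budget is spent, even if the text still ends with the tool-invocation marker. A hedged sketch of the predicate, using the same assumed `GenerationItem` fields:

```python
def should_complete(gen_item, tool_context_end: str) -> bool:
    """Mirror of the new check: an item is done when it did not just
    invoke a tool, or when its token budget is exhausted."""
    return (
        not gen_item.generated_text.endswith(tool_context_end)
        or gen_item.remaining_tokens <= 0
    )

# e.g. an item whose text ends with "</tool>" but has 0 tokens left is
# now marked completed instead of looping back into tool execution:
# should_complete(GenerationItem("...</tool>", remaining_tokens=0), "</tool>") -> True
```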
@@ -316,7 +319,7 @@ async def batch_inference():
 
     outputs = await generate_with_vllm(
         llm,
-        content.max_new_tokens,
+        [content.max_new_tokens] * len(prompts),
         content.temperature,
         content.stop_sequences,
         content.return_token_log_probs,
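Since `generate_with_vllm` now takes a list of budgets, the plain batch path broadcasts its scalar so every prompt gets the same limit. A one-line illustration:

```python
prompts = ["a", "b", "c"]  # illustrative batch
max_new_tokens = 128       # scalar budget from the request
assert [max_new_tokens] * len(prompts) == [128, 128, 128]
```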
@@ -358,24 +361,23 @@ async def generate_with_vllm(
     top_p,
     prompts,
     bar,
-) -> List[CompletionOutput]:
+) -> List[CompletionOutput]:  # pragma: no cover
     from vllm import SamplingParams
 
     # Add the requests to the engine.
-    sampling_params = SamplingParams(
-        max_tokens=max_new_tokens,
-        temperature=temperature,
-        stop=stop_sequences,
-        logprobs=1 if return_token_log_probs else None,
-        presence_penalty=presence_penalty or 0.0,
-        frequency_penalty=frequency_penalty or 0.0,
-        top_k=top_k or -1,
-        top_p=top_p or 1.0,
-    )
-
     results_generators = []
-    for prompt in prompts:
+    for idx, prompt in enumerate(prompts):
         request_id = random_uuid()
+        sampling_params = SamplingParams(
+            max_tokens=max_new_tokens[idx],
+            temperature=temperature,
+            stop=stop_sequences,
+            logprobs=1 if return_token_log_probs else None,
+            presence_penalty=presence_penalty or 0.0,
+            frequency_penalty=frequency_penalty or 0.0,
+            top_k=top_k or -1,
+            top_p=top_p or 1.0,
+        )
         results_generator = await engine.add_request(
             request_id, prompt, sampling_params, None, time.monotonic()
         )
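Because `max_tokens` is baked into `SamplingParams`, one shared params object can no longer serve prompts with different budgets, so the diff rebuilds it per request inside the loop. A sketch of just that pattern; `SamplingParams` and its `max_tokens` field are real vLLM API, while the helper name and defaults here are illustrative:

```python
from vllm import SamplingParams

def per_request_params(max_new_tokens: list, idx: int, temperature: float = 0.0) -> SamplingParams:
    # One SamplingParams per prompt: max_tokens varies, the rest is shared.
    return SamplingParams(
        max_tokens=max_new_tokens[idx],
        temperature=temperature,
    )

# per_request_params([64, 8], 0).max_tokens == 64
# per_request_params([64, 8], 1).max_tokens == 8
```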