Commit 5b0aaf1

Tool completion respect num new tokens (#469)
* Tool completion respect num new tokens
* more fix
* remove unused import
* format
* empty
* no cover
1 parent 234f0a4 commit 5b0aaf1

File tree

  • model-engine/model_engine_server/inference/batch_inference

1 file changed: +18 -16 lines


model-engine/model_engine_server/inference/batch_inference/vllm_batch.py

Lines changed: 18 additions & 16 deletions
@@ -185,7 +185,7 @@ def __repr__(self) -> str:
 
         outputs = await generate_with_vllm(
             llm,
-            content.max_new_tokens,
+            [generations[iter[1]].remaining_tokens for iter in iter_prompts],
             content.temperature,
             content.stop_sequences,
             content.return_token_log_probs,
@@ -260,7 +260,10 @@ def tool_func(text: str, past_context: Optional[str]):
             gen_item.generated_text += new_text
 
             # If we didn't just execute a tool, we're done
-            if not gen_item.generated_text.endswith(tool.tool_context_end):
+            if (
+                not gen_item.generated_text.endswith(tool.tool_context_end)
+                or gen_item.remaining_tokens <= 0
+            ):
                 gen_item.completed = True
                 continue
 
@@ -316,7 +319,7 @@ async def batch_inference():
 
     outputs = await generate_with_vllm(
         llm,
-        content.max_new_tokens,
+        [content.max_new_tokens] * len(prompts),
         content.temperature,
         content.stop_sequences,
         content.return_token_log_probs,
@@ -358,24 +361,23 @@ async def generate_with_vllm(
     top_p,
     prompts,
     bar,
-) -> List[CompletionOutput]:
+) -> List[CompletionOutput]:  # pragma: no cover
     from vllm import SamplingParams
 
     # Add the requests to the engine.
-    sampling_params = SamplingParams(
-        max_tokens=max_new_tokens,
-        temperature=temperature,
-        stop=stop_sequences,
-        logprobs=1 if return_token_log_probs else None,
-        presence_penalty=presence_penalty or 0.0,
-        frequency_penalty=frequency_penalty or 0.0,
-        top_k=top_k or -1,
-        top_p=top_p or 1.0,
-    )
-
     results_generators = []
-    for prompt in prompts:
+    for idx, prompt in enumerate(prompts):
         request_id = random_uuid()
+        sampling_params = SamplingParams(
+            max_tokens=max_new_tokens[idx],
+            temperature=temperature,
+            stop=stop_sequences,
+            logprobs=1 if return_token_log_probs else None,
+            presence_penalty=presence_penalty or 0.0,
+            frequency_penalty=frequency_penalty or 0.0,
+            top_k=top_k or -1,
+            top_p=top_p or 1.0,
+        )
         results_generator = await engine.add_request(
             request_id, prompt, sampling_params, None, time.monotonic()
         )
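
What the diff amounts to: instead of a single content.max_new_tokens shared by every prompt, generate_with_vllm now receives one token budget per prompt, and the tool-completion loop marks a generation finished as soon as its budget runs out, even if it just emitted a tool call. The sketch below is a minimal, self-contained illustration of that bookkeeping only; GenerationItem, fake_generate, run_tool_completion, and TOOL_CONTEXT_END are invented names for this example, while the real code drives vLLM with a per-request SamplingParams(max_tokens=...) as shown above.

from dataclasses import dataclass
from typing import List, Tuple

# Illustrative stand-in for tool.tool_context_end in the real code.
TOOL_CONTEXT_END = "</tool>"


@dataclass
class GenerationItem:
    # Hypothetical per-prompt state mirroring the remaining_tokens / completed
    # bookkeeping the commit adds around generate_with_vllm.
    prompt: str
    remaining_tokens: int
    generated_text: str = ""
    completed: bool = False


def fake_generate(items: List[GenerationItem], budgets: List[int]) -> List[Tuple[str, int]]:
    # Stand-in for generate_with_vllm: each prompt gets its own max-tokens budget,
    # so no round can use more tokens than that prompt has left.
    outputs = []
    for item, budget in zip(items, budgets):
        text = "<tool>search()</tool>" if not item.generated_text else " final answer"
        tokens_used = min(len(text.split()) + 2, budget)
        outputs.append((text, tokens_used))
    return outputs


def run_tool_completion(items: List[GenerationItem]) -> None:
    # Keep generating until every item is done, passing the per-prompt remaining
    # budgets each round (the diff's [generations[...].remaining_tokens ...] list).
    while not all(item.completed for item in items):
        active = [item for item in items if not item.completed]
        budgets = [item.remaining_tokens for item in active]
        for item, (new_text, used) in zip(active, fake_generate(active, budgets)):
            item.generated_text += new_text
            item.remaining_tokens -= used
            # Same shape as the new check: finish if no tool was just invoked
            # OR the token budget is exhausted. (The real loop would execute
            # the tool here before starting the next round.)
            if not item.generated_text.endswith(TOOL_CONTEXT_END) or item.remaining_tokens <= 0:
                item.completed = True


if __name__ == "__main__":
    gens = [GenerationItem("q1", remaining_tokens=8), GenerationItem("q2", remaining_tokens=3)]
    run_tool_completion(gens)
    for g in gens:
        print(g.prompt, g.remaining_tokens, repr(g.generated_text))

With the budgets above, "q1" gets a tool round plus a final answer, while "q2" runs out of tokens right after its tool call and is cut off, which is the behavior the per-prompt check is meant to guarantee.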

0 commit comments
