
Commit ec7cb19

WoosukKwon, LiuXiaoxuanPKU, simon-mo, heheda12345, and hongxiayang authored
[gpt-oss] Add loop for built-in tool call (#22374)
Signed-off-by: Woosuk Kwon <[email protected]>
Co-authored-by: LiuXiaoxuanPKU <[email protected]>
Co-authored-by: simon-mo <[email protected]>
Co-authored-by: Chen Zhang <[email protected]>
Co-authored-by: Hongxia Yang <[email protected]>
Co-authored-by: Minseok Lee <[email protected]>
Co-authored-by: Yongye Zhu <[email protected]>
1 parent 2435ea7 commit ec7cb19

File tree: 2 files changed, +73 / -16 lines


vllm/entrypoints/openai/serving_engine.py

Lines changed: 56 additions & 0 deletions
@@ -35,6 +35,7 @@
                                           apply_mistral_chat_template,
                                           parse_chat_messages_futures,
                                           resolve_chat_template_content_format)
+from vllm.entrypoints.context import ConversationContext
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                               ChatCompletionResponse,
@@ -948,6 +949,61 @@ async def _preprocess_chat(
 
         return conversation, [request_prompt], [engine_prompt]
 
+    async def _generate_with_builtin_tools(
+        self,
+        request_id: str,
+        request_prompt: RequestPrompt,
+        engine_prompt: EngineTokensPrompt,
+        sampling_params: SamplingParams,
+        context: ConversationContext,
+        lora_request: Optional[LoRARequest] = None,
+        priority: int = 0,
+        **kwargs,
+    ):
+        orig_priority = priority
+        while True:
+            self._log_inputs(
+                request_id,
+                request_prompt,
+                params=sampling_params,
+                lora_request=lora_request,
+            )
+            generator = self.engine_client.generate(
+                engine_prompt,
+                sampling_params,
+                request_id,
+                lora_request=lora_request,
+                priority=priority,
+                **kwargs,
+            )
+            async for res in generator:
+                context.append_output(res)
+                # NOTE(woosuk): The stop condition is handled by the engine.
+                yield context
+
+            if not context.need_builtin_tool_call():
+                # The model did not ask for a tool call, so we're done.
+                break
+
+            # Call the tool and update the context with the result.
+            tool_output = await context.call_tool()
+            context.append_output(tool_output)
+
+            # TODO: uncomment this and enable tool output streaming
+            # yield context
+
+            # Create inputs for the next turn.
+            # Render the next prompt token ids.
+            prompt_token_ids = context.render_for_completion()
+            engine_prompt = EngineTokensPrompt(
+                prompt_token_ids=prompt_token_ids)
+            request_prompt = prompt_token_ids
+            # Update the sampling params.
+            sampling_params.max_tokens = (self.max_model_len -
+                                          len(prompt_token_ids))
+            # OPTIMIZATION
+            priority = orig_priority - 1
+
     def _load_prompt_embeds(
         self,
         prompt_embeds: Optional[Union[bytes, list[bytes]]],
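The helper above drives the engine in a loop against a ConversationContext object. Its contract can be read off the calls in the diff: append_output, need_builtin_tool_call, call_tool, and render_for_completion. The following is a minimal sketch of that interface as inferred from this diff; it is illustrative only and is not the actual definition in vllm/entrypoints/context.py, whose details may differ.

# Minimal sketch of the ConversationContext contract that
# _generate_with_builtin_tools relies on, inferred from the calls in the
# diff above. Illustrative only; the real class in
# vllm/entrypoints/context.py may differ.
from abc import ABC, abstractmethod


class ConversationContextSketch(ABC):
    """Accumulates model/tool output and mediates built-in tool calls."""

    @abstractmethod
    def append_output(self, output) -> None:
        """Record an engine output (or a tool result) in the conversation."""

    @abstractmethod
    def need_builtin_tool_call(self) -> bool:
        """Return True if the last model turn requested a built-in tool."""

    @abstractmethod
    async def call_tool(self):
        """Execute the requested built-in tool and return its output."""

    @abstractmethod
    def render_for_completion(self) -> list[int]:
        """Render the conversation as prompt token ids for the next turn."""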

vllm/entrypoints/openai/serving_responses.py

Lines changed: 17 additions & 16 deletions
@@ -16,6 +16,7 @@
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          ChatTemplateContentFormatOption)
+from vllm.entrypoints.context import ConversationContext, SimpleContext
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -29,7 +30,6 @@
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
-from vllm.outputs import RequestOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -187,29 +187,27 @@ async def create_responses(
         raw_request.state.request_metadata = request_metadata
 
         # Schedule the request and get the result generator.
-        generators: list[AsyncGenerator[RequestOutput, None]] = []
+        generators: list[AsyncGenerator[ConversationContext, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
                 sampling_params = request.to_sampling_params(
                     default_max_tokens, self.default_sampling_params)
 
-                self._log_inputs(request.request_id,
-                                 request_prompts[i],
-                                 params=sampling_params,
-                                 lora_request=lora_request)
-
                 trace_headers = (None if raw_request is None else await
                                  self._get_trace_headers(raw_request.headers))
 
-                generator = self.engine_client.generate(
-                    engine_prompt,
-                    sampling_params,
-                    request.request_id,
+                context = SimpleContext()
+                generator = self._generate_with_builtin_tools(
+                    request_id=request.request_id,
+                    request_prompt=request_prompts[i],
+                    engine_prompt=engine_prompt,
+                    sampling_params=sampling_params,
+                    context=context,
                     lora_request=lora_request,
-                    trace_headers=trace_headers,
                     priority=request.priority,
+                    trace_headers=trace_headers,
                 )
                 generators.append(generator)
         except ValueError as e:
@@ -277,25 +275,28 @@ async def responses_full_generator(
         self,
         request: ResponsesRequest,
         sampling_params: SamplingParams,
-        result_generator: AsyncIterator[RequestOutput],
+        result_generator: AsyncIterator[ConversationContext],
         model_name: str,
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
         created_time: Optional[int] = None,
     ) -> Union[ErrorResponse, ResponsesResponse]:
         if created_time is None:
             created_time = int(time.time())
-        final_res: Optional[RequestOutput] = None
 
+        context: Optional[ConversationContext] = None
         try:
-            async for res in result_generator:
-                final_res = res
+            async for context in result_generator:
+                pass
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
+        assert context is not None
+        assert isinstance(context, SimpleContext)
+        final_res = context.last_output
         assert final_res is not None
        assert len(final_res.outputs) == 1
         final_output = final_res.outputs[0]
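With this change, the Responses endpoint consumes ConversationContext objects instead of raw RequestOutputs and reads the final engine result from context.last_output. As a hedged illustration of what a context satisfying this code path might look like (the actual SimpleContext lives in vllm/entrypoints/context.py and may differ), consider:

# Illustrative sketch of a context compatible with the usage in
# responses_full_generator above: it keeps only the most recent engine
# output and never requests a built-in tool call, so the generation loop
# in _generate_with_builtin_tools runs exactly one turn. The real
# SimpleContext in vllm/entrypoints/context.py may differ.
from typing import Optional

from vllm.outputs import RequestOutput


class SimpleContextSketch:

    def __init__(self) -> None:
        self.last_output: Optional[RequestOutput] = None

    def append_output(self, output: RequestOutput) -> None:
        self.last_output = output

    def need_builtin_tool_call(self) -> bool:
        # No built-in tools are involved, so a single engine turn suffices.
        return False

    async def call_tool(self):
        raise NotImplementedError("This sketch does not execute tools.")

    def render_for_completion(self) -> list[int]:
        raise NotImplementedError("Not needed when no tool call is made.")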
