@@ -16,6 +16,7 @@
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          ChatTemplateContentFormatOption)
+from vllm.entrypoints.context import ConversationContext, SimpleContext
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -29,7 +30,6 @@
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.logger import init_logger
-from vllm.outputs import RequestOutput
 from vllm.reasoning import ReasoningParser, ReasoningParserManager
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -187,29 +187,27 @@ async def create_responses(
         raw_request.state.request_metadata = request_metadata
 
         # Schedule the request and get the result generator.
-        generators: list[AsyncGenerator[RequestOutput, None]] = []
+        generators: list[AsyncGenerator[ConversationContext, None]] = []
         try:
             for i, engine_prompt in enumerate(engine_prompts):
                 default_max_tokens = self.max_model_len - len(
                     engine_prompt["prompt_token_ids"])
                 sampling_params = request.to_sampling_params(
                     default_max_tokens, self.default_sampling_params)
 
-                self._log_inputs(request.request_id,
-                                 request_prompts[i],
-                                 params=sampling_params,
-                                 lora_request=lora_request)
-
                 trace_headers = (None if raw_request is None else await
                                  self._get_trace_headers(raw_request.headers))
 
-                generator = self.engine_client.generate(
-                    engine_prompt,
-                    sampling_params,
-                    request.request_id,
+                context = SimpleContext()
+                generator = self._generate_with_builtin_tools(
+                    request_id=request.request_id,
+                    request_prompt=request_prompts[i],
+                    engine_prompt=engine_prompt,
+                    sampling_params=sampling_params,
+                    context=context,
                     lora_request=lora_request,
-                    trace_headers=trace_headers,
                     priority=request.priority,
+                    trace_headers=trace_headers,
                 )
                 generators.append(generator)
         except ValueError as e:
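
The hunk above swaps the direct `self.engine_client.generate(...)` call (and the inline `self._log_inputs(...)` call) for a single `self._generate_with_builtin_tools(...)` call that yields a shared `ConversationContext` rather than raw `RequestOutput`s. The helper's body is not part of this hunk, so the following is only a sketch of the shape implied by the call site and by the consumer further down; the `append_output` method name is an assumption.

# Rough sketch only, not the PR's implementation: a method on this serving class
# that takes over input logging, forwards the prompt to the engine, records each
# engine output on the context, and yields the context after every step.
async def _generate_with_builtin_tools(
    self,
    request_id: str,
    request_prompt,
    engine_prompt,
    sampling_params: SamplingParams,
    context: ConversationContext,
    **kwargs,
) -> AsyncGenerator[ConversationContext, None]:
    self._log_inputs(request_id,
                     request_prompt,
                     params=sampling_params,
                     lora_request=kwargs.get("lora_request"))
    async for res in self.engine_client.generate(
            engine_prompt,
            sampling_params,
            request_id,
            **kwargs,
    ):
        # `append_output` is a hypothetical name for however the context
        # stores the latest engine output (see SimpleContext.last_output below).
        context.append_output(res)
        yield context
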
@@ -277,25 +275,28 @@ async def responses_full_generator(
         self,
         request: ResponsesRequest,
         sampling_params: SamplingParams,
-        result_generator: AsyncIterator[RequestOutput],
+        result_generator: AsyncIterator[ConversationContext],
         model_name: str,
         tokenizer: AnyTokenizer,
         request_metadata: RequestResponseMetadata,
         created_time: Optional[int] = None,
     ) -> Union[ErrorResponse, ResponsesResponse]:
         if created_time is None:
             created_time = int(time.time())
-        final_res: Optional[RequestOutput] = None
 
+        context: Optional[ConversationContext] = None
         try:
-            async for res in result_generator:
-                final_res = res
+            async for context in result_generator:
+                pass
         except asyncio.CancelledError:
            return self.create_error_response("Client disconnected")
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
+        assert context is not None
+        assert isinstance(context, SimpleContext)
+        final_res = context.last_output
         assert final_res is not None
         assert len(final_res.outputs) == 1
         final_output = final_res.outputs[0]
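
The consumer now simply drains the context generator and recovers the final `RequestOutput` from `context.last_output`, so the rest of the function (`final_res.outputs[0]`, etc.) is unchanged. `SimpleContext` itself lives in `vllm.entrypoints.context` and is not shown in this diff; a minimal sketch of the state it would need for the asserts above to hold is:

# Minimal sketch, not the actual class: just enough state for the hunk above,
# i.e. remembering the most recent RequestOutput produced by the engine.
class SimpleContext(ConversationContext):

    def __init__(self):
        self.last_output: Optional[RequestOutput] = None

    def append_output(self, output: RequestOutput) -> None:
        # Hypothetical hook called by _generate_with_builtin_tools on each
        # engine step; only the latest output is kept.
        self.last_output = output
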