Commit bd6658d

fix: multi-agent and multi llm_response scenario
1 parent 45a6070 commit bd6658d

File tree: 2 files changed, 98 additions and 38 deletions

veadk/models/ark_llm.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -245,9 +245,12 @@ async def generate_content_async(
 
         else:
             raw_response = await self.llm_client.aresponse(**response_args)
-            yield self.transform_handler.openai_response_to_generate_content_response(
+            for (
+                llm_response
+            ) in self.transform_handler.openai_response_to_generate_content_response(
                 raw_response
-            )
+            ):
+                yield llm_response
 
 
         # before_model_callback
```
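With this change, `openai_response_to_generate_content_response` returns a list of `LlmResponse` objects (for example, one for an assistant message and one for a tool call), so the call site iterates and yields each item instead of yielding the whole result. A minimal standalone sketch of that pattern follows; `FakeTransformHandler` and the string payloads are stand-ins, not veadk code.

```python
import asyncio


class FakeTransformHandler:
    """Stand-in for the real transform handler, which now returns a list."""

    def openai_response_to_generate_content_response(self, raw_response):
        # e.g. one entry for the assistant message, one for the tool call
        return [f"llm_response({part})" for part in raw_response]


async def generate_content_async(raw_response, handler):
    # Mirrors the new call site: iterate the list and yield each response.
    for llm_response in handler.openai_response_to_generate_content_response(
        raw_response
    ):
        yield llm_response


async def main():
    async for r in generate_content_async(["message", "tool_call"], FakeTransformHandler()):
        print(r)


asyncio.run(main())
```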

veadk/models/ark_transform.py

Lines changed: 93 additions & 36 deletions
```diff
@@ -30,6 +30,8 @@
     Response as OpenAITypeResponse,
     ResponseStreamEvent,
     ResponseTextDeltaEvent,
+    ResponseOutputMessage,
+    ResponseFunctionToolCall,
 )
 from openai.types.responses import (
     ResponseCompletedEvent,
```
```diff
@@ -144,7 +146,21 @@ def __init__(self):
     def transform_request(
         self, model: str, messages: list, tools: Optional[list], **kwargs
     ):
-        messages = messages[:1] + messages[-1:]
+        # Keep the first message and all consecutive user messages from the end
+        filtered_messages = messages[:1]
+
+        # Collect all consecutive user messages from the end
+        user_messages_from_end = []
+        for message in reversed(messages[1:]):  # Skip the first message
+            if message.get("role") and message.get("role") in {"user", "tool"}:
+                user_messages_from_end.append(message)
+            else:
+                break  # Stop when we encounter a non-user message
+
+        # Reverse to maintain original order and add to filtered messages
+        filtered_messages.extend(reversed(user_messages_from_end))
+
+        messages = filtered_messages
         # completion_request to responses api request
         # 1. model and llm_custom_provider
         model, custom_llm_provider, _, _ = get_llm_provider(model=model)
```
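To make the new history filter concrete, here is a standalone copy of the same logic applied to a made-up multi-agent conversation (the message dicts are hypothetical; only the roles matter). Unlike the old `messages[:1] + messages[-1:]`, it preserves every trailing user/tool message, which is what a sub-agent needs after a tool round-trip.

```python
def filter_messages(messages: list[dict]) -> list[dict]:
    # Keep the first (e.g. system) message plus the consecutive run of
    # user/tool messages at the end of the history.
    filtered = messages[:1]
    tail = []
    for message in reversed(messages[1:]):
        if message.get("role") in {"user", "tool"}:
            tail.append(message)
        else:
            break  # stop at the first non-user/tool message from the end
    filtered.extend(reversed(tail))
    return filtered


history = [
    {"role": "system", "content": "You are agent A."},
    {"role": "user", "content": "Plan the trip."},
    {"role": "assistant", "content": "Calling the weather tool."},
    {"role": "tool", "content": "Sunny, 25°C"},
    {"role": "user", "content": "Now book the hotel."},
]
# Keeps the system message plus the trailing tool + user messages:
print(filter_messages(history))
```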
```diff
@@ -190,45 +206,51 @@ def transform_request
 
     def transform_response(
         self, openai_response: OpenAITypeResponse, stream: bool = False
-    ) -> ModelResponse:
+    ) -> list[ModelResponse]:
         # openai_type_response -> responses_api_response -> completion_response
-        raw_response = ResponsesAPIResponse(**openai_response.model_dump())
-
-        model_response = ModelResponse(stream=stream)
-        setattr(model_response, "usage", litellm.Usage())
-        response = self.litellm_handler.transform_response(
-            model=raw_response.model,
-            raw_response=raw_response,
-            model_response=model_response,
-            logging_obj=None,
-            request_data={},
-            messages=[],
-            optional_params={},
-            litellm_params={},
-            encoding=None,
-        )
-        if raw_response and hasattr(raw_response, "id"):
-            response.id = raw_response.id
-        return response
+        result_list = []
+        raw_response_list = construct_responses_api_response(openai_response)
+        for raw_response in raw_response_list:
+            model_response = ModelResponse(stream=stream)
+            setattr(model_response, "usage", litellm.Usage())
+            response = self.litellm_handler.transform_response(
+                model=raw_response.model,
+                raw_response=raw_response,
+                model_response=model_response,
+                logging_obj=None,
+                request_data={},
+                messages=[],
+                optional_params={},
+                litellm_params={},
+                encoding=None,
+            )
+            if raw_response and hasattr(raw_response, "id"):
+                response.id = raw_response.id
+            result_list.append(response)
+
+        return result_list
 
     def openai_response_to_generate_content_response(
         self, raw_response: OpenAITypeResponse
-    ) -> LlmResponse:
+    ) -> list[LlmResponse]:
         """
         OpenAITypeResponse -> litellm.ModelResponse -> LlmResponse
         instead of `_model_response_to_generate_content_response`,
         """
         # no stream response
-        model_response = self.transform_response(
+        model_response_list = self.transform_response(
             openai_response=raw_response, stream=False
         )
-        llm_response = _model_response_to_generate_content_response(model_response)
+        llm_response_list = []
+        for model_response in model_response_list:
+            llm_response = _model_response_to_generate_content_response(model_response)
 
-        llm_response = self.adapt_responses_api(
-            model_response,
-            llm_response,
-        )
-        return llm_response
+            llm_response = self.adapt_responses_api(
+                model_response,
+                llm_response,
+            )
+            llm_response_list.append(llm_response)
+        return llm_response_list
 
     def adapt_responses_api(
         self,
```
```diff
@@ -284,14 +306,15 @@ def stream_event_to_chunk(
             yield model_response, chunk, None
         elif isinstance(event, ResponseCompletedEvent):
             response = event.response
-            model_response = self.transform_response(response, stream=True)
-            model_response = fix_model_response(model_response)
-
-            for chunk, finish_reason in _model_response_to_chunk(model_response):
-                if isinstance(chunk, TextChunk):
-                    yield model_response, None, finish_reason
-                else:
-                    yield model_response, chunk, finish_reason
+            model_response_list = self.transform_response(response, stream=True)
+            for model_response in model_response_list:
+                model_response = fix_model_response(model_response)
+
+                for chunk, finish_reason in _model_response_to_chunk(model_response):
+                    if isinstance(chunk, TextChunk):
+                        yield model_response, None, finish_reason
+                    else:
+                        yield model_response, chunk, finish_reason
         else:
             # Ignore other event types like ResponseOutputItemAddedEvent, etc.
             pass
```
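A rough illustration of the streaming change, using plain dicts and hypothetical names rather than the real `ModelResponse`/chunk types: a completed event may now transform into several model responses, and the existing chunk loop runs once per response instead of once overall.

```python
def fan_out_chunks(model_response_list):
    # Outer loop is the new per-response loop; inner loop is the existing
    # chunk walk that previously ran on a single model response.
    for model_response in model_response_list:
        for chunk, finish_reason in model_response["chunks"]:
            yield model_response["id"], chunk, finish_reason


completed = [
    {"id": "msg_0", "chunks": [("Here is the plan.", None), (None, "stop")]},
    {"id": "call_0", "chunks": [("get_weather(city='Paris')", "tool_calls")]},
]
for response_id, chunk, finish_reason in fan_out_chunks(completed):
    print(response_id, chunk, finish_reason)
```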
```diff
@@ -308,3 +331,37 @@ def fix_model_response(model_response: ModelResponse) -> ModelResponse:
                 model_response.choices[i].message.tool_calls[idx].index = 0
 
     return model_response
+
+
+def construct_responses_api_response(
+    openai_response: OpenAITypeResponse,
+) -> list[ResponsesAPIResponse]:
+    output = openai_response.output
+
+    # Check if we need to split the response
+    if len(output) >= 2:
+        # Check if output contains both ResponseOutputMessage and ResponseFunctionToolCall types
+        has_message = any(isinstance(item, ResponseOutputMessage) for item in output)
+        has_tool_call = any(
+            isinstance(item, ResponseFunctionToolCall) for item in output
+        )
+
+        if has_message and has_tool_call:
+            # Split into separate responses for each item
+            raw_response_list = []
+            for item in output:
+                if isinstance(item, (ResponseOutputMessage, ResponseFunctionToolCall)):
+                    raw_response_list.append(
+                        ResponsesAPIResponse(
+                            **{
+                                k: v
+                                for k, v in openai_response.model_dump().items()
+                                if k != "output"
+                            },
+                            output=[item],
+                        )
+                    )
+            return raw_response_list
+
+    # Otherwise, return the original response structure
+    return [ResponsesAPIResponse(**openai_response.model_dump())]
```
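The split rule can be seen in isolation with simple stand-in dataclasses (not the actual openai/litellm types): when a response's output mixes an assistant message with a function tool call, it is fanned out into one single-item response per output item, each keeping the original metadata, so each item can later become its own `ModelResponse`/`LlmResponse`.

```python
from dataclasses import dataclass, field


@dataclass
class Message:  # stands in for ResponseOutputMessage
    text: str


@dataclass
class ToolCall:  # stands in for ResponseFunctionToolCall
    name: str


@dataclass
class Response:  # stands in for OpenAITypeResponse / ResponsesAPIResponse
    id: str
    output: list = field(default_factory=list)


def split_response(response: Response) -> list[Response]:
    output = response.output
    has_message = any(isinstance(item, Message) for item in output)
    has_tool_call = any(isinstance(item, ToolCall) for item in output)
    if len(output) >= 2 and has_message and has_tool_call:
        # one response per output item, sharing the original metadata
        return [Response(id=response.id, output=[item]) for item in output]
    return [response]


mixed = Response(
    id="resp_1",
    output=[Message("I'll check the weather."), ToolCall("get_weather")],
)
for part in split_response(mixed):
    print(part)
```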
