
Commit 13564de

fix: async openai
1 parent 0ab4d40 commit 13564de

1 file changed (+24, -15 lines)


veadk/models/ark_llm.py

Lines changed: 24 additions & 15 deletions
@@ -6,7 +6,8 @@
 import litellm
 from google.adk.agents.callback_context import CallbackContext
 from google.adk.models.cache_metadata import CacheMetadata
-from openai import OpenAI
+from litellm.responses.streaming_iterator import BaseResponsesAPIStreamingIterator
+from openai import AsyncOpenAI
 from google.adk.models import LlmRequest, LlmResponse
 from google.adk.models.lite_llm import (
     LiteLlm,
@@ -96,7 +97,9 @@ def _add_response_data_to_llm_response(
     return llm_response
 
 
-async def openai_response_async(request_data: dict):
+async def openai_response_async(
+    request_data: dict,
+) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]:
     # Filter out fields that are not supported by OpenAI SDK
     filtered_request_data = {
         key: value
@@ -108,7 +111,7 @@ async def openai_response_async(request_data: dict):
     )
     filtered_request_data["model"] = model_name  # remove custom_llm_provider
 
-    # Remove tools in subsequent rounds (when previous_response_id is present)
+    # [Note: Ark Limitations] Remove tools in subsequent rounds (when previous_response_id is present)
     if (
         "tools" in filtered_request_data
         and "previous_response_id" in filtered_request_data
@@ -117,7 +120,7 @@ async def openai_response_async(request_data: dict):
         # Remove tools in subsequent rounds regardless of caching status
         del filtered_request_data["tools"]
 
-    # Ensure thinking field consistency for cache usage
+    # [Note: Ark Limitations] Ensure thinking field consistency for cache usage
     if (
         "thinking" in filtered_request_data
         and "extra_body" in filtered_request_data
@@ -133,7 +136,7 @@ async def openai_response_async(request_data: dict):
         # Note: This is a placeholder - actual consistency check requires state tracking
         pass
 
-    # Ensure store field is true or default when caching is enabled
+    # [Note: Ark Limitations] Ensure store field is true or default when caching is enabled
     if (
         "extra_body" in filtered_request_data
         and isinstance(filtered_request_data["extra_body"], dict)
@@ -160,13 +163,20 @@ async def openai_response_async(request_data: dict):
         }
     ] + filtered_request_data["input"]
 
-    client = OpenAI(
+    client = AsyncOpenAI(
         base_url=request_data["api_base"],
         api_key=request_data["api_key"],
     )
-    openai_response = client.responses.create(**filtered_request_data)
-    raw_response = ResponsesAPIResponse(**openai_response.model_dump())
-    return raw_response
+    if "stream" in filtered_request_data and filtered_request_data["stream"]:
+        # For streaming responses, return the streaming iterator directly
+        # stream_response = await client.responses.create(**filtered_request_data)
+        # return stream_response
+        raise NotImplementedError("Not implemented streaming responses api")
+    else:
+        # For non-streaming responses, return the ResponsesAPIResponse
+        openai_response = await client.responses.create(**filtered_request_data)
+        raw_response = ResponsesAPIResponse(**openai_response.model_dump())
+        return raw_response
 
 
 class ArkLlmClient(LiteLLMClient):
@@ -194,18 +204,17 @@ async def acompletion(
         # [NOTE] Cannot be called directly; there is a litellm bug,
         # Therefore, we cannot directly call litellm.aresponses:
         # https://github.com/BerriAI/litellm/issues/16267
-        # raw_response = await aresponses(
+        # raw_response = await litellm.aresponses(
         #     **request_data,
         # )
         raw_response = await openai_response_async(request_data)
 
         # 4. Transform ResponsesAPIResponse
-        # 4.1 Create model_response object
-        model_response = ModelResponse()
-        setattr(model_response, "usage", litellm.Usage())
-
-        # 4.2 Transform ResponsesAPIResponse to ModelResponses
+        # 4.1 Transform ResponsesAPIResponse to ModelResponses
         if isinstance(raw_response, ResponsesAPIResponse):
+            # 4.2 Create model_response object
+            model_response = ModelResponse()
+            setattr(model_response, "usage", litellm.Usage())
             response = self.transformation_handler.transform_response(
                 model=model,
                 raw_response=raw_response,
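
For context only (not part of the commit): the core of the fix is replacing the synchronous OpenAI client with AsyncOpenAI, since ArkLlmClient.acompletion is a coroutine and a blocking responses.create call would stall the event loop. Below is a minimal sketch of the awaited call pattern, using a placeholder endpoint, key, and model name rather than values from this repository:

```python
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    # Placeholder endpoint, key, and model name (not taken from veadk).
    client = AsyncOpenAI(
        base_url="https://example.com/api/v3",
        api_key="sk-placeholder",
    )

    # The old code built openai.OpenAI(...) and called client.responses.create(...)
    # synchronously; inside an async method that blocks the running event loop.
    # With AsyncOpenAI the same call is awaited instead:
    response = await client.responses.create(model="example-model", input="ping")

    # ark_llm.py then re-wraps the SDK object via model_dump(), i.e.
    # ResponsesAPIResponse(**response.model_dump()).
    print(response.model_dump())


asyncio.run(main())
```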
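Likewise, a hedged sketch of how a caller might handle the widened return type Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]: with this commit only the non-streaming branch is implemented, and stream=True raises NotImplementedError. The helper name fetch_ark_response and the retry-without-streaming behaviour are illustrative assumptions, not code from the repository:

```python
from litellm.types.llms.openai import ResponsesAPIResponse  # import path assumed

from veadk.models.ark_llm import openai_response_async


async def fetch_ark_response(request_data: dict):  # hypothetical helper
    try:
        raw = await openai_response_async(request_data)
    except NotImplementedError:
        # stream=True is rejected by this commit; retry without streaming.
        raw = await openai_response_async({**request_data, "stream": False})

    if isinstance(raw, ResponsesAPIResponse):
        # Non-streaming branch: a fully materialized Responses API payload.
        return raw

    # Once streaming is implemented, a BaseResponsesAPIStreamingIterator
    # would be returned here instead.
    return raw
```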
