Commit 3603050

feat: openai sdk for responses api
1 parent d2b75fe commit 3603050

File tree

1 file changed (+89, -4 lines)

veadk/models/ark_llm.py

Lines changed: 89 additions & 4 deletions
@@ -4,6 +4,7 @@
 from typing import Any, Dict, Union, AsyncGenerator
 
 import litellm
+from openai import OpenAI
 from google.adk.models import LlmRequest, LlmResponse
 from google.adk.models.lite_llm import (
     LiteLlm,
@@ -17,7 +18,7 @@
     _model_response_to_generate_content_response,
 )
 from google.genai import types
-from litellm import Logging, aresponses, ChatCompletionAssistantMessage
+from litellm import Logging, ChatCompletionAssistantMessage
 from litellm.completion_extras.litellm_responses_transformation.transformation import (
     LiteLLMResponsesTransformationHandler,
 )
@@ -41,6 +42,40 @@
 
 logger = get_logger(__name__)
 
+openai_supported_fields = [
+    "stream",
+    "background",
+    "include",
+    "input",
+    "instructions",
+    "max_output_tokens",
+    "max_tool_calls",
+    "metadata",
+    "model",
+    "parallel_tool_calls",
+    "previous_response_id",
+    "prompt",
+    "prompt_cache_key",
+    "reasoning",
+    "safety_identifier",
+    "service_tier",
+    "store",
+    "stream",
+    "stream_options",
+    "temperature",
+    "text",
+    "tool_choice",
+    "tools",
+    "top_logprobs",
+    "top_p",
+    "truncation",
+    "user",
+    "extra_headers",
+    "extra_query",
+    "extra_body",
+    "timeout",
+]
+
 
 def _add_response_id_to_llm_response(
     llm_response: LlmResponse, response: ModelResponse
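
The allow-list above mirrors the keyword arguments accepted by the OpenAI SDK's responses.create ("stream" appears twice, which is redundant but harmless for a membership check). Everything else litellm threads through request_data, notably api_base and api_key, is excluded and consumed later when the client is constructed. A minimal sketch of the filtering effect, using hypothetical keys and a short excerpt of the list:

# Excerpt of the allow-list; the full list is in the diff above.
openai_supported_fields = ["model", "input", "temperature", "top_p"]

# Hypothetical request_data; keys and values are illustrative only.
request_data = {
    "model": "openai/gpt-4o-mini",         # hypothetical provider-prefixed name
    "input": "ping",
    "temperature": 0.2,
    "api_base": "https://example.com/v1",  # not allow-listed: dropped here
    "api_key": "sk-...",                   # not allow-listed: dropped here
    "top_p": None,                         # allow-listed but None: dropped
}

filtered = {
    key: value
    for key, value in request_data.items()
    if key in openai_supported_fields and value is not None
}
print(filtered)  # {'model': 'openai/gpt-4o-mini', 'input': 'ping', 'temperature': 0.2}
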
@@ -52,6 +87,52 @@ def _add_response_id_to_llm_response(
     return llm_response
 
 
+async def openai_response_async(request_data: dict):
+    # Filter out fields that are not supported by the OpenAI SDK
+    filtered_request_data = {
+        key: value
+        for key, value in request_data.items()
+        if key in openai_supported_fields and value is not None
+    }
+    model_name, custom_llm_provider, _, _ = get_llm_provider(
+        model=request_data["model"]
+    )
+    filtered_request_data["model"] = model_name  # remove custom_llm_provider prefix
+
+    if (
+        "tools" in filtered_request_data
+        and "extra_body" in filtered_request_data
+        and isinstance(filtered_request_data["extra_body"], dict)
+        and "caching" in filtered_request_data["extra_body"]
+        and isinstance(filtered_request_data["extra_body"]["caching"], dict)
+        and filtered_request_data["extra_body"]["caching"].get("type") == "enabled"
+        and "previous_response_id" in filtered_request_data
+        and filtered_request_data["previous_response_id"] is not None
+    ):
+        # Remove tools when caching is enabled and previous_response_id is present
+        del filtered_request_data["tools"]
+
+    # Remove instructions when caching is enabled with specific configuration
+    if (
+        "instructions" in filtered_request_data
+        and "extra_body" in filtered_request_data
+        and isinstance(filtered_request_data["extra_body"], dict)
+        and "caching" in filtered_request_data["extra_body"]
+        and isinstance(filtered_request_data["extra_body"]["caching"], dict)
+        and filtered_request_data["extra_body"]["caching"].get("type") == "enabled"
+    ):
+        # Remove instructions when caching is enabled
+        del filtered_request_data["instructions"]
+
+    client = OpenAI(
+        base_url=request_data["api_base"],
+        api_key=request_data["api_key"],
+    )
+    openai_response = client.responses.create(**filtered_request_data)
+    raw_response = ResponsesAPIResponse(**openai_response.model_dump())
+    return raw_response
+
+
 class ArkLlmClient(LiteLLMClient):
     def __init__(self):
         super().__init__()
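
One design note on the helper above: it is declared async but builds the synchronous OpenAI client, so client.responses.create blocks the event loop while the request is in flight. A non-blocking variant is a small change; the sketch below is an assumption (it swaps in AsyncOpenAI from the same SDK) and is not what this commit ships:

from openai import AsyncOpenAI

# Sketch only: AsyncOpenAI exposes the same responses.create surface as the
# sync client, but awaitable, so the event loop keeps servicing other tasks.
# Assumes filtered_request_data was already built with the allow-list filter.
async def openai_response_async_nonblocking(
    filtered_request_data: dict, api_base: str, api_key: str
):
    client = AsyncOpenAI(base_url=api_base, api_key=api_key)
    return await client.responses.create(**filtered_request_data)
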
@@ -74,9 +155,13 @@ async def acompletion(
         ) = self._get_request_data(model, messages, tools, **kwargs)
 
         # 3. Call litellm.aresponses with the transformed request data
-        raw_response = await aresponses(
-            **request_data,
-        )
+        # Cannot be called directly; there is a litellm bug:
+        # https://github.com/BerriAI/litellm/issues/16267
+        # raw_response = await aresponses(
+        #     **request_data,
+        # )
+        raw_response = await openai_response_async(request_data)
+
         # 4. Transform ResponsesAPIResponse
         # 4.1 Create model_response object
         model_response = ModelResponse()
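
For reference, a hypothetical end-to-end invocation of the new code path. The request_data keys mirror what the transformed request is expected to carry; the model name, endpoint URL, and API key are placeholders, not values from this commit:

import asyncio

from veadk.models.ark_llm import openai_response_async

async def main():
    request_data = {
        "model": "openai/doubao-seed-1.6",  # hypothetical provider-prefixed name
        "input": [{"role": "user", "content": "ping"}],
        "api_base": "https://ark.cn-beijing.volces.com/api/v3",  # assumed Ark endpoint
        "api_key": "YOUR_ARK_API_KEY",      # placeholder
    }
    # Returns litellm's ResponsesAPIResponse built from the raw OpenAI response.
    raw_response = await openai_response_async(request_data)
    print(raw_response)

asyncio.run(main())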
