from typing import Any, Dict, Union, AsyncGenerator

import litellm
+from openai import AsyncOpenAI
from google.adk.models import LlmRequest, LlmResponse
from google.adk.models.lite_llm import (
    LiteLlm,
    _model_response_to_generate_content_response,
)
from google.genai import types
-from litellm import Logging, aresponses, ChatCompletionAssistantMessage
+from litellm import Logging, ChatCompletionAssistantMessage
from litellm.completion_extras.litellm_responses_transformation.transformation import (
    LiteLLMResponsesTransformationHandler,
)

logger = get_logger(__name__)

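+# Allow-list of keyword arguments the OpenAI Responses API create() call
+# accepts; openai_response_async drops everything else from the request.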
+openai_supported_fields = [
+    "background",
+    "include",
+    "input",
+    "instructions",
+    "max_output_tokens",
+    "max_tool_calls",
+    "metadata",
+    "model",
+    "parallel_tool_calls",
+    "previous_response_id",
+    "prompt",
+    "prompt_cache_key",
+    "reasoning",
+    "safety_identifier",
+    "service_tier",
+    "store",
+    "stream",
+    "stream_options",
+    "temperature",
+    "text",
+    "tool_choice",
+    "tools",
+    "top_logprobs",
+    "top_p",
+    "truncation",
+    "user",
+    "extra_headers",
+    "extra_query",
+    "extra_body",
+    "timeout",
+]
+

def _add_response_id_to_llm_response(
    llm_response: LlmResponse, response: ModelResponse
@@ -52,6 +87,52 @@ def _add_response_id_to_llm_response(
    return llm_response


+async def openai_response_async(request_data: dict) -> ResponsesAPIResponse:
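+    """Call the OpenAI Responses API directly with litellm's prepared request.
+
+    Workaround for https://github.com/BerriAI/litellm/issues/16267:
+    litellm.aresponses cannot consume this transformed request data, so we
+    invoke the OpenAI SDK ourselves and re-wrap the result for litellm.
+    """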
+    # Drop fields the OpenAI SDK does not accept, along with None values.
+    filtered_request_data = {
+        key: value
+        for key, value in request_data.items()
+        if key in openai_supported_fields and value is not None
+    }
+    model_name, custom_llm_provider, _, _ = get_llm_provider(
+        model=request_data["model"]
+    )
+    # Send the bare model name; the custom_llm_provider prefix is litellm-only.
+    filtered_request_data["model"] = model_name
+
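+    # Caching-specific adjustments: both checks below key off the caller
+    # setting extra_body["caching"]["type"] == "enabled".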
+    if (
+        "tools" in filtered_request_data
+        and "extra_body" in filtered_request_data
+        and isinstance(filtered_request_data["extra_body"], dict)
+        and "caching" in filtered_request_data["extra_body"]
+        and isinstance(filtered_request_data["extra_body"]["caching"], dict)
+        and filtered_request_data["extra_body"]["caching"].get("type") == "enabled"
+        and "previous_response_id" in filtered_request_data
+        and filtered_request_data["previous_response_id"] is not None
+    ):
+        # Remove tools when caching is enabled and previous_response_id is present
+        del filtered_request_data["tools"]
+
+    # Remove instructions whenever caching is enabled (unlike tools, this does
+    # not also require a previous_response_id).
+    if (
+        "instructions" in filtered_request_data
+        and "extra_body" in filtered_request_data
+        and isinstance(filtered_request_data["extra_body"], dict)
+        and "caching" in filtered_request_data["extra_body"]
+        and isinstance(filtered_request_data["extra_body"]["caching"], dict)
+        and filtered_request_data["extra_body"]["caching"].get("type") == "enabled"
+    ):
+        del filtered_request_data["instructions"]
+
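+    # api_base and api_key are not in openai_supported_fields, so they are read
+    # from the raw request_data rather than forwarded to responses.create().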
+    client = AsyncOpenAI(
+        base_url=request_data["api_base"],
+        api_key=request_data["api_key"],
+    )
+    openai_response = await client.responses.create(**filtered_request_data)
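+    # Round-trip through model_dump() so callers get litellm's
+    # ResponsesAPIResponse type instead of the raw OpenAI SDK object.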
+    raw_response = ResponsesAPIResponse(**openai_response.model_dump())
+    return raw_response
+
+
class ArkLlmClient(LiteLLMClient):
    def __init__(self):
        super().__init__()
@@ -74,9 +155,13 @@ async def acompletion(
        ) = self._get_request_data(model, messages, tools, **kwargs)

-        # 3. Call litellm.aresponses with the transformed request data
-        raw_response = await aresponses(
-            **request_data,
-        )
+        # 3. Call the Responses API with the transformed request data.
+        # litellm.aresponses cannot be called directly; there is a litellm bug:
+        # https://github.com/BerriAI/litellm/issues/16267
+        # raw_response = await aresponses(
+        #     **request_data,
+        # )
+        raw_response = await openai_response_async(request_data)
+
        # 4. Transform ResponsesAPIResponse
        # 4.1 Create model_response object
        model_response = ModelResponse()