@@ -1,22 +1,36 @@
+import json
 import uuid
 from datetime import datetime
-from typing import Any, Dict, Union
+from typing import Any, Dict, Union, AsyncGenerator

 import litellm
+from google.adk.models import LlmRequest, LlmResponse
 from google.adk.models.lite_llm import (
     LiteLlm,
     LiteLLMClient,
+    _get_completion_inputs,
+    _model_response_to_chunk,
+    FunctionChunk,
+    TextChunk,
+    _message_to_generate_content_response,
+    UsageMetadataChunk,
+    _model_response_to_generate_content_response,
 )
-from litellm import Logging
-from litellm import aresponses
+from google.genai import types
+from litellm import Logging, aresponses, ChatCompletionAssistantMessage
 from litellm.completion_extras.litellm_responses_transformation.transformation import (
     LiteLLMResponsesTransformationHandler,
 )
 from litellm.litellm_core_utils.get_litellm_params import get_litellm_params
 from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
 from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
 from litellm.types.llms.openai import ResponsesAPIResponse
-from litellm.types.utils import ModelResponse, LlmProviders
+from litellm.types.utils import (
+    ModelResponse,
+    LlmProviders,
+    ChatCompletionMessageToolCall,
+    Function,
+)
 from litellm.utils import get_optional_params, ProviderConfigManager
 from pydantic import Field

@@ -28,6 +42,16 @@
 logger = get_logger(__name__)


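+# Note: litellm's chat-completion ids start with "chatcmpl"; an id that doesn't is
+# assumed to come from the Responses API, so the helper below stashes it in
+# custom_metadata, letting a later turn hand it back as previous_response_id.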
+def _add_response_id_to_llm_response(
+    llm_response: LlmResponse, response: ModelResponse
+) -> LlmResponse:
+    if not response.id.startswith("chatcmpl"):
+        if llm_response.custom_metadata is None:
+            llm_response.custom_metadata = {}
+        llm_response.custom_metadata["response_id"] = response.id
+    return llm_response
+
+
 class ArkLlmClient(LiteLLMClient):
     def __init__(self):
         super().__init__()
@@ -37,7 +61,7 @@ async def acompletion(
         self, model, messages, tools, **kwargs
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         # 1.1. Get optional_params using get_optional_params function
-        optional_params = get_optional_params(model=model, **kwargs)
+        optional_params = get_optional_params(model=model, tools=tools, **kwargs)
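+        # tools is a named parameter here, so it must be forwarded explicitly for
+        # get_optional_params to fold it into the provider params.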

         # 1.2. Get litellm_params using get_litellm_params function
         litellm_params = get_litellm_params(**kwargs)
@@ -74,8 +98,10 @@ async def acompletion(
         messages = provider_config.translate_developer_role_to_system_role(
             messages=messages
         )
-
-        # 1.6 Transform request to responses api format
+        # 1.6 Trim the history: keep the leading system prompt and the latest
+        # message; earlier turns ride along server-side via previous_response_id.
+        messages = messages[:1] + messages[-1:]
+        # 1.7 Transform request to responses api format
         request_data = self.transformation_handler.transform_request(
             model=model,
             messages=messages,
@@ -87,7 +113,7 @@ async def acompletion(
         )

         # 2. Call litellm.aresponses with the transformed request data
-        result = await aresponses(
+        raw_response = await aresponses(
             **request_data,
         )

@@ -96,10 +122,10 @@ async def acompletion(
         setattr(model_response, "usage", litellm.Usage())

         # 3.2 Transform ResponsesAPIResponse to ModelResponse
-        if isinstance(result, ResponsesAPIResponse):
-            return self.transformation_handler.transform_response(
+        if isinstance(raw_response, ResponsesAPIResponse):
+            response = self.transformation_handler.transform_response(
                 model=model,
-                raw_response=result,
+                raw_response=raw_response,
                 model_response=model_response,
                 logging_obj=logging_obj,
                 request_data=request_data,
@@ -110,9 +136,14 @@ async def acompletion(
                 api_key=kwargs.get("api_key"),
                 json_mode=kwargs.get("json_mode"),
             )
+            # 3.2.1 Carry the Responses API response id over to the ModelResponse
+            if raw_response and hasattr(raw_response, "id"):
+                response.id = raw_response.id
+            return response
+
         else:
             completion_stream = self.transformation_handler.get_model_response_iterator(
-                streaming_response=result,  # type: ignore
+                streaming_response=raw_response,  # type: ignore
                 sync_stream=True,
                 json_mode=kwargs.get("json_mode"),
             )
@@ -132,7 +163,152 @@ class ArkLlm(LiteLlm):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    # async def generate_content_async(
-    #     self, llm_request: LlmRequest, stream: bool = False
-    # ) -> AsyncGenerator[LlmResponse, None]:
-    #     pass
+    async def generate_content_async(
+        self, llm_request: LlmRequest, stream: bool = False
+    ) -> AsyncGenerator[LlmResponse, None]:
+        """Generates content asynchronously.
+
+        Args:
+            llm_request: The LlmRequest to send to the model.
+            stream: Whether to make a streaming call.
+
+        Yields:
+            LlmResponse: The model response.
+        """
+
+        self._maybe_append_user_content(llm_request)
+        # logger.debug(_build_request_log(llm_request))
+
+        messages, tools, response_format, generation_params = _get_completion_inputs(
+            llm_request
+        )
+
+        if "functions" in self._additional_args:
+            # LiteLLM does not support both tools and functions together.
+            tools = None
+        # ------------------------------------------------------ #
+        # Recover the previous_response_id for this session, if any
+        previous_response_id = None
+        if llm_request.cache_metadata and llm_request.cache_metadata.cache_name:
+            previous_response_id = llm_request.cache_metadata.cache_name
+        # ------------------------------------------------------ #
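+        # Assumed contract (not shown in this diff): _add_response_id_to_llm_response
+        # stores the Responses API id in custom_metadata["response_id"], and the
+        # caller copies it into cache_metadata.cache_name before the next turn.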
+        completion_args = {
+            "model": self.model,
+            "messages": messages,
+            "tools": tools,
+            "response_format": response_format,
+            "previous_response_id": previous_response_id,
+        }
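+        # previous_response_id presumably rides through **kwargs into
+        # ArkLlmClient.acompletion and on to the Responses API request.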
+        completion_args.update(self._additional_args)
+
+        if generation_params:
+            completion_args.update(generation_params)
+
+        if stream:
+            text = ""
+            # Track function calls by index
+            function_calls = {}  # index -> {name, args, id}
+            completion_args["stream"] = True
+            aggregated_llm_response = None
+            aggregated_llm_response_with_tool_call = None
+            usage_metadata = None
+            fallback_index = 0
+            async for part in await self.llm_client.acompletion(**completion_args):
+                for chunk, finish_reason in _model_response_to_chunk(part):
+                    if isinstance(chunk, FunctionChunk):
+                        index = chunk.index or fallback_index
+                        if index not in function_calls:
+                            function_calls[index] = {"name": "", "args": "", "id": None}
+
+                        if chunk.name:
+                            function_calls[index]["name"] += chunk.name
+                        if chunk.args:
+                            function_calls[index]["args"] += chunk.args
+
+                        # Check whether args is complete yet (workaround for
+                        # improper chunk indexing)
+                        try:
+                            json.loads(function_calls[index]["args"])
+                            fallback_index += 1
+                        except json.JSONDecodeError:
+                            pass
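+                        # Once the accumulated args parse as valid JSON the call is
+                        # treated as complete, so the fallback index advances and
+                        # later index-less chunks open a new entry.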
+
+                        function_calls[index]["id"] = (
+                            chunk.id or function_calls[index]["id"] or str(index)
+                        )
+                    elif isinstance(chunk, TextChunk):
+                        text += chunk.text
+                        yield _message_to_generate_content_response(
+                            ChatCompletionAssistantMessage(
+                                role="assistant",
+                                content=chunk.text,
+                            ),
+                            is_partial=True,
+                        )
+                    elif isinstance(chunk, UsageMetadataChunk):
+                        usage_metadata = types.GenerateContentResponseUsageMetadata(
+                            prompt_token_count=chunk.prompt_tokens,
+                            candidates_token_count=chunk.completion_tokens,
+                            total_token_count=chunk.total_tokens,
+                        )
+
+                    if (
+                        finish_reason == "tool_calls" or finish_reason == "stop"
+                    ) and function_calls:
+                        tool_calls = []
+                        for index, func_data in function_calls.items():
+                            if func_data["id"]:
+                                tool_calls.append(
+                                    ChatCompletionMessageToolCall(
+                                        type="function",
+                                        id=func_data["id"],
+                                        function=Function(
+                                            name=func_data["name"],
+                                            arguments=func_data["args"],
+                                            index=index,
+                                        ),
+                                    )
+                                )
+                        aggregated_llm_response_with_tool_call = (
+                            _message_to_generate_content_response(
+                                ChatCompletionAssistantMessage(
+                                    role="assistant",
+                                    content=text,
+                                    tool_calls=tool_calls,
+                                )
+                            )
+                        )
+                        text = ""
+                        function_calls.clear()
+                    elif finish_reason == "stop" and text:
+                        aggregated_llm_response = _message_to_generate_content_response(
+                            ChatCompletionAssistantMessage(role="assistant", content=text)
+                        )
+                        text = ""
+
+            # Wait until the stream ends before yielding the aggregated responses:
+            # litellm tends to send the usage_metadata chunk after the chunk whose
+            # finish_reason is tool_calls or stop.
+            if aggregated_llm_response:
+                if usage_metadata:
+                    aggregated_llm_response.usage_metadata = usage_metadata
+                    usage_metadata = None
+                yield aggregated_llm_response
+
+            if aggregated_llm_response_with_tool_call:
+                if usage_metadata:
+                    aggregated_llm_response_with_tool_call.usage_metadata = usage_metadata
+                yield aggregated_llm_response_with_tool_call
+
+        else:
+            response = await self.llm_client.acompletion(**completion_args)
+            # ------------------------------------------------------ #
+            # Transport the response id (the stock LiteLlm implementation yields
+            # _model_response_to_generate_content_response(response) directly).
+            llm_response = _model_response_to_generate_content_response(response)
+            yield _add_response_id_to_llm_response(llm_response, response)
+            # ------------------------------------------------------ #
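For context, a minimal usage sketch (not part of this PR): the agent wiring and the "openai/gpt-4o" model id below are illustrative placeholders, assuming ArkLlm drops in wherever a LiteLlm instance is accepted as an ADK model.

    from google.adk.agents import Agent

    # Hypothetical wiring; any litellm-routable model id should work here.
    root_agent = Agent(
        name="ark_agent",
        model=ArkLlm(model="openai/gpt-4o"),
        instruction="You are a helpful assistant.",
    )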