
Commit 92b170a

feat: use the patch method to pass the previous response id
1 parent 098a973 commit 92b170a

File tree: 3 files changed (+258, -16 lines)


veadk/agent.py

Lines changed: 2 additions & 0 deletions
@@ -200,8 +200,10 @@ def model_post_init(self, __context: Any) -> None:
 
         if not self.model:
             if self.enable_responses:
+                from veadk.utils.patches import patch_google_adk_call_llm_async
                 from veadk.models.ark_llm import ArkLlm
 
+                patch_google_adk_call_llm_async()
                 self.model = ArkLlm(
                     model=f"{self.model_provider}/{self.model_name}",
                     api_key=self.model_api_key,
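
The two added lines install the ADK hook lazily: only when `enable_responses` is set, and before the `ArkLlm` model is built, so the patched `_call_llm_async` is already in place for the first call. A minimal sketch of that ordering outside of `model_post_init` (the model string and API key below are placeholders, not values from this repository):

from veadk.models.ark_llm import ArkLlm
from veadk.utils.patches import patch_google_adk_call_llm_async

# 1. Install the BaseLlmFlow._call_llm_async hook (idempotent, see patches.py below).
patch_google_adk_call_llm_async()

# 2. Build the Responses-API-backed model, mirroring model_post_init above.
model = ArkLlm(
    model="openai/<model-name>",  # placeholder for f"{model_provider}/{model_name}"
    api_key="<ARK_API_KEY>",      # placeholder for model_api_key
)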

veadk/models/ark_llm.py

Lines changed: 192 additions & 16 deletions
@@ -1,22 +1,36 @@
+import json
 import uuid
 from datetime import datetime
-from typing import Any, Dict, Union
+from typing import Any, Dict, Union, AsyncGenerator
 
 import litellm
+from google.adk.models import LlmRequest, LlmResponse
 from google.adk.models.lite_llm import (
     LiteLlm,
     LiteLLMClient,
+    _get_completion_inputs,
+    _model_response_to_chunk,
+    FunctionChunk,
+    TextChunk,
+    _message_to_generate_content_response,
+    UsageMetadataChunk,
+    _model_response_to_generate_content_response,
 )
-from litellm import Logging
-from litellm import aresponses
+from google.genai import types
+from litellm import Logging, aresponses, ChatCompletionAssistantMessage
 from litellm.completion_extras.litellm_responses_transformation.transformation import (
     LiteLLMResponsesTransformationHandler,
 )
 from litellm.litellm_core_utils.get_litellm_params import get_litellm_params
 from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
 from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
 from litellm.types.llms.openai import ResponsesAPIResponse
-from litellm.types.utils import ModelResponse, LlmProviders
+from litellm.types.utils import (
+    ModelResponse,
+    LlmProviders,
+    ChatCompletionMessageToolCall,
+    Function,
+)
 from litellm.utils import get_optional_params, ProviderConfigManager
 from pydantic import Field
 
@@ -28,6 +42,16 @@
 logger = get_logger(__name__)
 
 
+def _add_response_id_to_llm_response(
+    llm_response: LlmResponse, response: ModelResponse
+) -> LlmResponse:
+    if not response.id.startswith("chatcmpl"):
+        if llm_response.custom_metadata is None:
+            llm_response.custom_metadata = {}
+        llm_response.custom_metadata["response_id"] = response["id"]
+    return llm_response
+
+
 class ArkLlmClient(LiteLLMClient):
     def __init__(self):
         super().__init__()
@@ -37,7 +61,7 @@ async def acompletion(
         self, model, messages, tools, **kwargs
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         # 1.1. Get optional_params using get_optional_params function
-        optional_params = get_optional_params(model=model, **kwargs)
+        optional_params = get_optional_params(model=model, tools=tools, **kwargs)
 
         # 1.2. Get litellm_params using get_litellm_params function
         litellm_params = get_litellm_params(**kwargs)
@@ -74,8 +98,10 @@ async def acompletion(
         messages = provider_config.translate_developer_role_to_system_role(
             messages=messages
         )
-
-        # 1.6 Transform request to responses api format
+        # 1.6 Trim the conversation history before the Responses API call:
+        #     keep the header system prompt and the user's latest message
+        messages = messages[:1] + messages[-1:]
+        # 1.7 Transform request to responses api format
         request_data = self.transformation_handler.transform_request(
             model=model,
             messages=messages,
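
The trimming in this hunk sends only the first message (the system prompt) and the last message to the Responses API; judging by the `previous_response_id` plumbing elsewhere in this commit, the intent is that earlier turns are recovered from the previous response. A quick worked example with made-up messages (it assumes the list holds at least a system prompt plus one later message, since a single-element list would be duplicated):

# Worked example of `messages[:1] + messages[-1:]` with illustrative messages.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris."},
    {"role": "user", "content": "And of Italy?"},
]
trimmed = messages[:1] + messages[-1:]
assert trimmed == [messages[0], messages[-1]]  # system prompt + latest turn only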
@@ -87,7 +113,7 @@ async def acompletion(
         )
 
         # 2. Call litellm.aresponses with the transformed request data
-        result = await aresponses(
+        raw_response = await aresponses(
             **request_data,
         )
 
@@ -96,10 +122,10 @@ async def acompletion(
         setattr(model_response, "usage", litellm.Usage())
 
         # 3.2 Transform ResponsesAPIResponse to ModelResponses
-        if isinstance(result, ResponsesAPIResponse):
-            return self.transformation_handler.transform_response(
+        if isinstance(raw_response, ResponsesAPIResponse):
+            response = self.transformation_handler.transform_response(
                 model=model,
-                raw_response=result,
+                raw_response=raw_response,
                 model_response=model_response,
                 logging_obj=logging_obj,
                 request_data=request_data,
@@ -110,9 +136,14 @@ async def acompletion(
                 api_key=kwargs.get("api_key"),
                 json_mode=kwargs.get("json_mode"),
             )
+            # 3.2.1 Modify ModelResponse id
+            if raw_response and hasattr(raw_response, "id"):
+                response.id = raw_response.id
+            return response
+
         else:
             completion_stream = self.transformation_handler.get_model_response_iterator(
-                streaming_response=result,  # type: ignore
+                streaming_response=raw_response,  # type: ignore
                 sync_stream=True,
                 json_mode=kwargs.get("json_mode"),
             )
@@ -132,7 +163,152 @@ class ArkLlm(LiteLlm):
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
-    # async def generate_content_async(
-    #     self, llm_request: LlmRequest, stream: bool = False
-    # ) -> AsyncGenerator[LlmResponse, None]:
-    #     pass
+    async def generate_content_async(
+        self, llm_request: LlmRequest, stream: bool = False
+    ) -> AsyncGenerator[LlmResponse, None]:
+        """Generates content asynchronously.
+
+        Args:
+          llm_request: LlmRequest, the request to send to the LiteLlm model.
+          stream: bool = False, whether to do streaming call.
+
+        Yields:
+          LlmResponse: The model response.
+        """
+
+        self._maybe_append_user_content(llm_request)
+        # logger.debug(_build_request_log(llm_request))
+
+        messages, tools, response_format, generation_params = _get_completion_inputs(
+            llm_request
+        )
+
+        if "functions" in self._additional_args:
+            # LiteLLM does not support both tools and functions together.
+            tools = None
+        # ------------------------------------------------------ #
+        # get previous_response_id
+        previous_response_id = None
+        if llm_request.cache_metadata and llm_request.cache_metadata.cache_name:
+            previous_response_id = llm_request.cache_metadata.cache_name
+        # ------------------------------------------------------ #
+        completion_args = {
+            "model": self.model,
+            "messages": messages,
+            "tools": tools,
+            "response_format": response_format,
+            "previous_response_id": previous_response_id,  # supply previous_response_id
+        }
+        completion_args.update(self._additional_args)
+
+        if generation_params:
+            completion_args.update(generation_params)
+
+        if stream:
+            text = ""
+            # Track function calls by index
+            function_calls = {}  # index -> {name, args, id}
+            completion_args["stream"] = True
+            aggregated_llm_response = None
+            aggregated_llm_response_with_tool_call = None
+            usage_metadata = None
+            fallback_index = 0
+            async for part in await self.llm_client.acompletion(**completion_args):
+                for chunk, finish_reason in _model_response_to_chunk(part):
+                    if isinstance(chunk, FunctionChunk):
+                        index = chunk.index or fallback_index
+                        if index not in function_calls:
+                            function_calls[index] = {"name": "", "args": "", "id": None}
+
+                        if chunk.name:
+                            function_calls[index]["name"] += chunk.name
+                        if chunk.args:
+                            function_calls[index]["args"] += chunk.args
+
+                            # check if args is completed (workaround for improper chunk
+                            # indexing)
+                            try:
+                                json.loads(function_calls[index]["args"])
+                                fallback_index += 1
+                            except json.JSONDecodeError:
+                                pass
+
+                        function_calls[index]["id"] = (
+                            chunk.id or function_calls[index]["id"] or str(index)
+                        )
+                    elif isinstance(chunk, TextChunk):
+                        text += chunk.text
+                        yield _message_to_generate_content_response(
+                            ChatCompletionAssistantMessage(
+                                role="assistant",
+                                content=chunk.text,
+                            ),
+                            is_partial=True,
+                        )
+                    elif isinstance(chunk, UsageMetadataChunk):
+                        usage_metadata = types.GenerateContentResponseUsageMetadata(
+                            prompt_token_count=chunk.prompt_tokens,
+                            candidates_token_count=chunk.completion_tokens,
+                            total_token_count=chunk.total_tokens,
+                        )
+
+                    if (
+                        finish_reason == "tool_calls" or finish_reason == "stop"
+                    ) and function_calls:
+                        tool_calls = []
+                        for index, func_data in function_calls.items():
+                            if func_data["id"]:
+                                tool_calls.append(
+                                    ChatCompletionMessageToolCall(
+                                        type="function",
+                                        id=func_data["id"],
+                                        function=Function(
+                                            name=func_data["name"],
+                                            arguments=func_data["args"],
+                                            index=index,
+                                        ),
+                                    )
+                                )
+                        aggregated_llm_response_with_tool_call = (
+                            _message_to_generate_content_response(
+                                ChatCompletionAssistantMessage(
+                                    role="assistant",
+                                    content=text,
+                                    tool_calls=tool_calls,
+                                )
+                            )
+                        )
+                        text = ""
+                        function_calls.clear()
+                    elif finish_reason == "stop" and text:
+                        aggregated_llm_response = _message_to_generate_content_response(
+                            ChatCompletionAssistantMessage(
+                                role="assistant", content=text
+                            )
+                        )
+                        text = ""
+
+            # waiting until streaming ends to yield the llm_response as litellm tends
+            # to send chunk that contains usage_metadata after the chunk with
+            # finish_reason set to tool_calls or stop.
+            if aggregated_llm_response:
+                if usage_metadata:
+                    aggregated_llm_response.usage_metadata = usage_metadata
+                    usage_metadata = None
+                yield aggregated_llm_response
+
+            if aggregated_llm_response_with_tool_call:
+                if usage_metadata:
+                    aggregated_llm_response_with_tool_call.usage_metadata = (
+                        usage_metadata
+                    )
+                yield aggregated_llm_response_with_tool_call
+
+        else:
+            response = await self.llm_client.acompletion(**completion_args)
+            # ------------------------------------------------------ #
+            # Transport response id
+            # yield _model_response_to_generate_content_response(response)
+            llm_response = _model_response_to_generate_content_response(response)
+            yield _add_response_id_to_llm_response(llm_response, response)
+            # ------------------------------------------------------ #
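
In the non-streaming branch above, `_add_response_id_to_llm_response` is what exposes the Responses API id to the rest of the flow: any id that does not start with `chatcmpl` is stashed in `custom_metadata["response_id"]`. A self-contained sketch of that behaviour using simple stand-ins for `LlmResponse` and litellm's `ModelResponse` (the stub classes are illustrative, not the real types):

# Stand-ins for LlmResponse / ModelResponse, only to exercise the id hand-off.
class StubLlmResponse:
    custom_metadata = None

class StubModelResponse(dict):
    def __init__(self, response_id: str):
        super().__init__(id=response_id)
        self.id = response_id  # readable both as attribute and as dict key

def add_response_id(llm_response, response):
    # Mirrors _add_response_id_to_llm_response from the diff above.
    if not response.id.startswith("chatcmpl"):
        if llm_response.custom_metadata is None:
            llm_response.custom_metadata = {}
        llm_response.custom_metadata["response_id"] = response["id"]
    return llm_response

# A Responses-API style id is recorded ...
out = add_response_id(StubLlmResponse(), StubModelResponse("resp_0123456789"))
assert out.custom_metadata == {"response_id": "resp_0123456789"}

# ... while a regular chat-completion id is left alone.
out = add_response_id(StubLlmResponse(), StubModelResponse("chatcmpl-abc"))
assert out.custom_metadata is None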

veadk/utils/patches.py

Lines changed: 64 additions & 0 deletions
@@ -16,6 +16,10 @@
 import sys
 from typing import Callable
 
+from google.adk.agents import InvocationContext
+from google.adk.models import LlmRequest
+from google.adk.models.cache_metadata import CacheMetadata
+
 from veadk.tracing.telemetry.telemetry import (
     trace_call_llm,
     trace_send_data,
@@ -78,3 +82,63 @@ def patch_google_adk_telemetry() -> None:
             logger.debug(
                 f"Patch {mod_name} {var_name} with {trace_functions[var_name]}"
             )
+
+
+#
+# BaseLlmFlow._call_llm_async patch hook
+#
+def patch_google_adk_call_llm_async() -> None:
+    """Patch google.adk BaseLlmFlow._call_llm_async with a delegating wrapper.
+
+    When the second-to-last session event carries a ``response_id`` in its custom_metadata,
+    it is forwarded as ``previous_response_id`` via llm_request.cache_metadata; otherwise
+    the original behavior is unchanged.
+    """
+    # Prevent duplicate patches
+    if hasattr(patch_google_adk_call_llm_async, "_patched"):
+        logger.debug("BaseLlmFlow._call_llm_async already patched, skipping")
+        return
+
+    try:
+        from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
+
+        original_call_llm_async = BaseLlmFlow._call_llm_async
+
+        async def patched_call_llm_async(
+            self,
+            invocation_context: InvocationContext,
+            llm_request: LlmRequest,
+            model_response_event,
+        ):
+            logger.debug(
+                "Patched BaseLlmFlow._call_llm_async invoked; delegating to original"
+            )
+            events = invocation_context.session.events
+            if (
+                events
+                and len(events) >= 2
+                and events[-2].custom_metadata
+                and "response_id" in events[-2].custom_metadata
+            ):
+                previous_response_id = events[-2].custom_metadata["response_id"]
+                llm_request.cache_metadata = CacheMetadata(
+                    cache_name=previous_response_id,
+                    expire_time=0,
+                    fingerprint="",
+                    invocations_used=0,
+                    cached_contents_count=0,
+                )
+
+            async for llm_response in original_call_llm_async(
+                self, invocation_context, llm_request, model_response_event
+            ):
+                # Pass the original responses through unchanged
+                yield llm_response
+
+        BaseLlmFlow._call_llm_async = patched_call_llm_async
+
+        # Mark as patched to prevent duplicate application
+        patch_google_adk_call_llm_async._patched = True
+        logger.info("Successfully patched BaseLlmFlow._call_llm_async")
+
+    except ImportError as e:
+        logger.warning(f"Failed to patch BaseLlmFlow._call_llm_async: {e}")
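
Since the `_patched` attribute guards against re-application, calling the patch repeatedly is safe, which presumably matters when `model_post_init` runs for more than one agent in the same process. A minimal usage sketch:

from veadk.utils.patches import patch_google_adk_call_llm_async

patch_google_adk_call_llm_async()  # wraps BaseLlmFlow._call_llm_async once
patch_google_adk_call_llm_async()  # no-op: logs "already patched, skipping" and returns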
