Commit 15fb480

feat: use callback instead patch
1 parent 3603050 commit 15fb480

File tree

3 files changed: +41, -68 lines

veadk/agent.py

Lines changed: 7 additions & 2 deletions
@@ -36,6 +36,7 @@
 from veadk.knowledgebase import KnowledgeBase
 from veadk.memory.long_term_memory import LongTermMemory
 from veadk.memory.short_term_memory import ShortTermMemory
+from veadk.models.ark_llm import add_previous_response_id
 from veadk.processors import BaseRunProcessor, NoOpRunProcessor
 from veadk.prompts.agent_default_prompt import DEFAULT_DESCRIPTION, DEFAULT_INSTRUCTION
 from veadk.tracing.base_tracer import BaseTracer
@@ -200,16 +201,20 @@ def model_post_init(self, __context: Any) -> None:

         if not self.model:
             if self.enable_responses:
-                from veadk.utils.patches import patch_google_adk_call_llm_async
+                # from veadk.utils.patches import patch_google_adk_call_llm_async
                 from veadk.models.ark_llm import ArkLlm

-                patch_google_adk_call_llm_async()
+                # patch_google_adk_call_llm_async()
                 self.model = ArkLlm(
                     model=f"{self.model_provider}/{self.model_name}",
                     api_key=self.model_api_key,
                     api_base=self.model_api_base,
                     **self.model_extra_config,
                 )
+                if not self.before_model_callback:
+                    self.before_model_callback = add_previous_response_id
+                else:
+                    self.before_model_callback.append(add_previous_response_id)
             else:
                 self.model = LiteLlm(
                     model=f"{self.model_provider}/{self.model_name}",

veadk/models/ark_llm.py

Lines changed: 34 additions & 3 deletions
@@ -1,9 +1,11 @@
 import json
 import uuid
 from datetime import datetime
-from typing import Any, Dict, Union, AsyncGenerator
+from typing import Any, Dict, Union, AsyncGenerator, Optional

 import litellm
+from google.adk.agents.callback_context import CallbackContext
+from google.adk.models.cache_metadata import CacheMetadata
 from openai import OpenAI
 from google.adk.models import LlmRequest, LlmResponse
 from google.adk.models.lite_llm import (
@@ -77,13 +79,20 @@
 ]


-def _add_response_id_to_llm_response(
+def _add_response_data_to_llm_response(
     llm_response: LlmResponse, response: ModelResponse
 ) -> LlmResponse:
+    # add responses id
     if not response.id.startswith("chatcmpl"):
         if llm_response.custom_metadata is None:
             llm_response.custom_metadata = {}
         llm_response.custom_metadata["response_id"] = response["id"]
+    # add responses cache data
+    if response.get("usage", {}).get("prompt_tokens_details"):
+        if llm_response.usage_metadata:
+            llm_response.usage_metadata.cached_content_token_count = (
+                response.get("usage", {}).get("prompt_tokens_details").cached_tokens
+            )
     return llm_response


@@ -423,5 +432,27 @@ async def generate_content_async(
         # Transport response id
         # yield _model_response_to_generate_content_response(response)
         llm_response = _model_response_to_generate_content_response(response)
-        yield _add_response_id_to_llm_response(llm_response, response)
+        yield _add_response_data_to_llm_response(llm_response, response)
         # ------------------------------------------------------ #
+
+
+def add_previous_response_id(
+    callback_context: CallbackContext, llm_request: LlmRequest
+) -> Optional[LlmResponse]:
+    invocation_context = callback_context._invocation_context
+    events = invocation_context.session.events
+    if (
+        events
+        and len(events) >= 2
+        and events[-2].custom_metadata
+        and "response_id" in events[-2].custom_metadata
+    ):
+        previous_response_id = events[-2].custom_metadata["response_id"]
+        llm_request.cache_metadata = CacheMetadata(
+            cache_name=previous_response_id,
+            expire_time=0,
+            fingerprint="",
+            invocations_used=0,
+            cached_contents_count=0,
+        )
+    return
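The renamed helper now records two pieces of data on each LlmResponse: the Responses API id (whenever the upstream id is not a "chatcmpl" completion id) and the cached prompt-token count from prompt_tokens_details. On the next turn, add_previous_response_id reads that id back out of the session and carries it toward the model layer in CacheMetadata.cache_name (the other CacheMetadata fields are zeroed placeholders). A self-contained sketch of that event lookup, with SimpleNamespace stand-ins in place of real google-adk session events:

    # Sketch of the lookup inside add_previous_response_id; the namespaces below
    # stand in for google-adk events, and only custom_metadata matters here.
    from types import SimpleNamespace

    events = [
        SimpleNamespace(custom_metadata=None),                         # earlier user turn
        SimpleNamespace(custom_metadata={"response_id": "resp_123"}),  # previous model turn
        SimpleNamespace(custom_metadata=None),                         # current user turn
    ]

    previous_response_id = None
    if (
        events
        and len(events) >= 2
        and events[-2].custom_metadata
        and "response_id" in events[-2].custom_metadata
    ):
        previous_response_id = events[-2].custom_metadata["response_id"]

    print(previous_response_id)  # resp_123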

veadk/utils/patches.py

Lines changed: 0 additions & 63 deletions
@@ -16,9 +16,6 @@
 import sys
 from typing import Callable

-from google.adk.agents import InvocationContext
-from google.adk.models import LlmRequest
-from google.adk.models.cache_metadata import CacheMetadata

 from veadk.tracing.telemetry.telemetry import (
     trace_call_llm,
@@ -82,63 +79,3 @@ def patch_google_adk_telemetry() -> None:
         logger.debug(
             f"Patch {mod_name} {var_name} with {trace_functions[var_name]}"
         )
-
-
-#
-# BaseLlmFlow._call_llm_async patch hook
-#
-def patch_google_adk_call_llm_async() -> None:
-    """Patch google.adk BaseLlmFlow._call_llm_async with a delegating wrapper.
-
-    Current behavior: simply calls the original implementation and yields its results.
-    This provides a stable hook for later custom business logic without changing behavior now.
-    """
-    # Prevent duplicate patches
-    if hasattr(patch_google_adk_call_llm_async, "_patched"):
-        logger.debug("BaseLlmFlow._call_llm_async already patched, skipping")
-        return
-
-    try:
-        from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow
-
-        original_call_llm_async = BaseLlmFlow._call_llm_async
-
-        async def patched_call_llm_async(
-            self,
-            invocation_context: InvocationContext,
-            llm_request: LlmRequest,
-            model_response_event,
-        ):
-            logger.debug(
-                "Patched BaseLlmFlow._call_llm_async invoked; delegating to original"
-            )
-            events = invocation_context.session.events
-            if (
-                events
-                and len(events) >= 2
-                and events[-2].custom_metadata
-                and "response_id" in events[-2].custom_metadata
-            ):
-                previous_response_id = events[-2].custom_metadata["response_id"]
-                llm_request.cache_metadata = CacheMetadata(
-                    cache_name=previous_response_id,
-                    expire_time=0,
-                    fingerprint="",
-                    invocations_used=0,
-                    cached_contents_count=0,
-                )
-
-            async for llm_response in original_call_llm_async(
-                self, invocation_context, llm_request, model_response_event
-            ):
-                # Currently, just pass through the original responses
-                yield llm_response
-
-        BaseLlmFlow._call_llm_async = patched_call_llm_async
-
-        # Marked as patched to prevent duplicate application
-        patch_google_adk_call_llm_async._patched = True
-        logger.info("Successfully patched BaseLlmFlow._call_llm_async")
-
-    except ImportError as e:
-        logger.warning(f"Failed to patch BaseLlmFlow._call_llm_async: {e}")
