
Commit b453d48

fix: remove before_model_callback
1 parent 894d4c3 commit b453d48

3 files changed: +116 −65 lines


veadk/agent.py

Lines changed: 14 additions & 27 deletions
@@ -19,9 +19,10 @@
 from google.adk.agents import LlmAgent, RunConfig, InvocationContext
 from google.adk.agents.base_agent import BaseAgent
+from google.adk.agents.context_cache_config import ContextCacheConfig
 from google.adk.agents.llm_agent import InstructionProvider, ToolUnion
 from google.adk.agents.run_config import StreamingMode
-from google.adk.events import Event, EventActions
+from google.adk.events import Event
 from google.adk.models.lite_llm import LiteLlm
 from google.adk.runners import Runner
 from google.genai import types
@@ -37,7 +38,6 @@
 from veadk.knowledgebase import KnowledgeBase
 from veadk.memory.long_term_memory import LongTermMemory
 from veadk.memory.short_term_memory import ShortTermMemory
-from veadk.models.ark_llm import add_previous_response_id
 from veadk.processors import BaseRunProcessor, NoOpRunProcessor
 from veadk.prompts.agent_default_prompt import DEFAULT_DESCRIPTION, DEFAULT_INSTRUCTION
 from veadk.tracing.base_tracer import BaseTracer
@@ -155,6 +155,8 @@ class Agent(LlmAgent):
 
     enable_responses: bool = False
 
+    context_cache_config: Optional[ContextCacheConfig] = None
+
     run_processor: Optional[BaseRunProcessor] = Field(default=None, exclude=True)
     """Optional run processor for intercepting and processing agent execution flows.
@@ -210,16 +212,12 @@ def model_post_init(self, __context: Any) -> None:
                 api_base=self.model_api_base,
                 **self.model_extra_config,
             )
-            if not self.before_model_callback:
-                self.before_model_callback = add_previous_response_id
-            else:
-                if isinstance(self.before_model_callback, list):
-                    self.before_model_callback.append(add_previous_response_id)
-                else:
-                    self.before_model_callback = [
-                        self.before_model_callback,
-                        add_previous_response_id,
-                    ]
+            if not self.context_cache_config:
+                self.context_cache_config = ContextCacheConfig(
+                    cache_intervals=100,  # maximum number
+                    ttl_seconds=315360000,
+                    min_tokens=0,
+                )
         else:
             self.model = LiteLlm(
                 model=f"{self.model_provider}/{self.model_name}",
@@ -265,22 +263,11 @@ def model_post_init(self, __context: Any) -> None:
     async def _run_async_impl(
         self, ctx: InvocationContext
     ) -> AsyncGenerator[Event, None]:
+        if self.enable_responses:
+            if not ctx.context_cache_config:
+                ctx.context_cache_config = self.context_cache_config
+
         async for event in super()._run_async_impl(ctx):
-            agent_name = self.name
-            if (
-                self.enable_responses
-                and event.custom_metadata
-                and event.custom_metadata.get("response_id")
-            ):
-                response_id = event.custom_metadata["response_id"]
-                yield Event(
-                    invocation_id=ctx.invocation_id,
-                    author=self.name,
-                    actions=EventActions(
-                        state_delta={f"agent:{agent_name}:response_id": response_id}
-                    ),
-                    branch=ctx.branch,
-                )
             yield event
 
     async def _run(
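
Note: with the callback removed, agent.py no longer writes agent:<name>:response_id into session state; a Responses-mode agent instead seeds a ContextCacheConfig and hands it to the InvocationContext. Below is a minimal, hedged sketch of that default, using only the field names visible in the diff (ttl_seconds=315360000 is roughly ten years, effectively no expiry; the cache_intervals comment is taken from the diff):

    # Sketch only: mirrors the default introduced above and assumes google-adk
    # exposes ContextCacheConfig at the import path used in this commit.
    from google.adk.agents.context_cache_config import ContextCacheConfig

    default_cache_config = ContextCacheConfig(
        cache_intervals=100,    # "maximum number" per the in-diff comment
        ttl_seconds=315360000,  # ~10 years: keep the cached context indefinitely
        min_tokens=0,           # no minimum token threshold for caching
    )

    # _run_async_impl then propagates it only when Responses mode is enabled
    # and the caller has not already set caching on the context:
    #     if self.enable_responses and not ctx.context_cache_config:
    #         ctx.context_cache_config = self.context_cache_config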

veadk/models/ark_llm.py

Lines changed: 13 additions & 33 deletions
@@ -15,13 +15,11 @@
 # adapted from Google ADK models adk-python/blob/main/src/google/adk/models/lite_llm.py at f1f44675e4a86b75e72cfd838efd8a0399f23e24 · google/adk-python
 
 import json
-from typing import Any, Dict, Union, AsyncGenerator, Optional
+from typing import Any, Dict, Union, AsyncGenerator
 
 import litellm
 import openai
 from openai.types.responses import Response as OpenAITypeResponse, ResponseStreamEvent
-from google.adk.agents.callback_context import CallbackContext
-from google.adk.models.cache_metadata import CacheMetadata
 from google.adk.models import LlmRequest, LlmResponse
 from google.adk.models.lite_llm import (
     LiteLlm,
@@ -41,6 +39,7 @@
 
 from veadk.models.ark_transform import (
     CompletionToResponsesAPIHandler,
+    get_previous_response_id,
 )
 from veadk.utils.logger import get_logger
 
@@ -90,7 +89,7 @@ async def generate_content_async(
         Yields:
            LlmResponse: The model response.
        """
-
+        agent_name = llm_request.config.labels["adk_agent_name"]
         self._maybe_append_user_content(llm_request)
         # logger.debug(_build_request_log(llm_request))
 
@@ -105,7 +104,10 @@ async def generate_content_async(
        # get previous_response_id
        previous_response_id = None
        if llm_request.cache_metadata and llm_request.cache_metadata.cache_name:
-            previous_response_id = llm_request.cache_metadata.cache_name
+            previous_response_id = get_previous_response_id(
+                llm_request.cache_metadata,
+                agent_name,
+            )
        completion_args = {
            "model": self.model,
            "messages": messages,
@@ -210,6 +212,7 @@ async def generate_content_async(
                    )
                )
                self.transform_handler.adapt_responses_api(
+                    llm_request,
                    model_response,
                    aggregated_llm_response_with_tool_call,
                    stream=True,
@@ -223,7 +226,10 @@ async def generate_content_async(
                    )
                )
                self.transform_handler.adapt_responses_api(
-                    model_response, aggregated_llm_response, stream=True
+                    llm_request,
+                    model_response,
+                    aggregated_llm_response,
+                    stream=True,
                )
                text = ""
 
@@ -248,32 +254,6 @@ async def generate_content_async(
            for (
                llm_response
            ) in self.transform_handler.openai_response_to_generate_content_response(
-                raw_response
+                llm_request, raw_response
            ):
                yield llm_response
-
-
-# before_model_callback
-def add_previous_response_id(
-    callback_context: CallbackContext, llm_request: LlmRequest
-) -> Optional[LlmResponse]:
-    agent_name = callback_context.agent_name
-    # read response_id
-    previous_response_id = callback_context.state.get(f"agent:{agent_name}:response_id")
-    if "contents_count" in CacheMetadata.model_fields:  # adk >= 1.17
-        llm_request.cache_metadata = CacheMetadata(
-            cache_name=previous_response_id,
-            expire_time=0,
-            fingerprint="",
-            invocations_used=0,
-            contents_count=0,
-        )
-    else:  # 1.15 <= adk < 1.17
-        llm_request.cache_metadata = CacheMetadata(
-            cache_name=previous_response_id,
-            expire_time=0,
-            fingerprint="",
-            invocations_used=0,
-            cached_contents_count=0,
-        )
-    return
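
Note: on the request side, the previous response id is now resolved per agent from cache_metadata instead of being read from session state by the removed add_previous_response_id callback. A hedged sketch of the new lookup path, using only names from this diff; the example cache_name value is hypothetical:

    # Sketch of the request-side lookup introduced above (not the full method).
    agent_name = llm_request.config.labels["adk_agent_name"]

    previous_response_id = None
    if llm_request.cache_metadata and llm_request.cache_metadata.cache_name:
        # cache_name holds a JSON mapping such as
        # '{"planner": "resp_abc", "writer": "resp_def"}' (hypothetical values),
        # written back by adapt_responses_api in ark_transform on earlier turns.
        previous_response_id = get_previous_response_id(
            llm_request.cache_metadata,
            agent_name,
        )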

veadk/models/ark_transform.py

Lines changed: 89 additions & 5 deletions
@@ -14,11 +14,13 @@
 
 # adapted from Google ADK models adk-python/blob/main/src/google/adk/models/lite_llm.py at f1f44675e4a86b75e72cfd838efd8a0399f23e24 · google/adk-python
 
+import json
 import uuid
 from typing import Any, Dict, Optional, cast, List, Generator, Tuple, Union
 
 import litellm
-from google.adk.models import LlmResponse
+from google.adk.models import LlmResponse, LlmRequest
+from google.adk.models.cache_metadata import CacheMetadata
 from google.adk.models.lite_llm import (
     TextChunk,
     FunctionChunk,
@@ -139,6 +141,74 @@ def ark_field_reorganization(request_data: dict) -> dict:
     return request_data
 
 
+def build_cache_metadata(agent_response_id: dict) -> CacheMetadata:
+    """Create a new CacheMetadata instance for agent response tracking.
+
+    Args:
+        agent_name: Name of the agent
+        response_id: Response ID to track
+
+    Returns:
+        A new CacheMetadata instance with the agent-response mapping
+    """
+    cache_name = json.dumps(agent_response_id)
+    if "contents_count" in CacheMetadata.model_fields:  # adk >= 1.17
+        cache_metadata = CacheMetadata(
+            cache_name=cache_name,
+            expire_time=0,
+            fingerprint="",
+            invocations_used=0,
+            contents_count=0,
+        )
+    else:  # 1.15 <= adk < 1.17
+        cache_metadata = CacheMetadata(
+            cache_name=cache_name,
+            expire_time=0,
+            fingerprint="",
+            invocations_used=0,
+            cached_contents_count=0,
+        )
+    return cache_metadata
+
+
+def update_cache_metadata(
+    cache_metadata: CacheMetadata,
+    agent_name: str,
+    response_id: str,
+) -> CacheMetadata:
+    """Update cache metadata by creating a new instance with updated cache_name.
+
+    Since CacheMetadata is frozen, we cannot modify it directly. Instead,
+    we create a new instance with the updated cache_name field.
+    """
+    try:
+        agent_response_id = json.loads(cache_metadata.cache_name)
+        agent_response_id[agent_name] = response_id
+        updated_cache_name = agent_response_id
+
+        # Create a new CacheMetadata instance with updated cache_name
+        return build_cache_metadata(updated_cache_name)
+    except json.JSONDecodeError as e:
+        logger.warning(
+            f"Failed to update cache metadata. The cache_name is not a valid JSON string., {str(e)}"
+        )
+        return cache_metadata
+
+
+def get_previous_response_id(
+    cache_metadata: CacheMetadata,
+    agent_name: str,
+):
+    try:
+        agent_response_id = json.loads(cache_metadata.cache_name)
+        return agent_response_id.get(agent_name, None)
+    except json.JSONDecodeError as e:
+        logger.warning(
+            f"Failed to get previous response id. The cache_name is not a valid JSON string., {str(e)}"
+        )
+        return None
+
+
 class CompletionToResponsesAPIHandler:
     def __init__(self):
         self.litellm_handler = LiteLLMResponsesTransformationHandler()
@@ -231,7 +301,7 @@ def transform_response(
         return result_list
 
     def openai_response_to_generate_content_response(
-        self, raw_response: OpenAITypeResponse
+        self, llm_request: LlmRequest, raw_response: OpenAITypeResponse
     ) -> list[LlmResponse]:
         """
         OpenAITypeResponse -> litellm.ModelResponse -> LlmResponse
@@ -246,6 +316,7 @@ def openai_response_to_generate_content_response(
         llm_response = _model_response_to_generate_content_response(model_response)
 
         llm_response = self.adapt_responses_api(
+            llm_request,
            model_response,
            llm_response,
        )
@@ -254,6 +325,7 @@
 
     def adapt_responses_api(
         self,
+        llm_request: LlmRequest,
         model_response: ModelResponse,
         llm_response: LlmResponse,
         stream: bool = False,
@@ -262,9 +334,21 @@
        Adapt responses api.
        """
        if not model_response.id.startswith("chatcmpl"):
-            if llm_response.custom_metadata is None:
-                llm_response.custom_metadata = {}
-            llm_response.custom_metadata["response_id"] = model_response["id"]
+            # if llm_response.custom_metadata is None:
+            #     llm_response.custom_metadata = {}
+            # llm_response.custom_metadata["response_id"] = model_response["id"]
+            previous_response_id = model_response["id"]
+            if not llm_request.cache_metadata:
+                llm_response.cache_metadata = build_cache_metadata(
+                    {llm_request.config.labels["adk_agent_name"]: previous_response_id}
+                )
+            else:
+                llm_response.cache_metadata = update_cache_metadata(
+                    llm_request.cache_metadata,
+                    llm_request.config.labels["adk_agent_name"],
+                    previous_response_id,
+                )
+
        # add responses cache data
        if not stream:
            if model_response.get("usage", {}).get("prompt_tokens_details"):
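
Note: together, the new helpers store a per-agent mapping {agent_name: response_id} as a JSON string in CacheMetadata.cache_name, and the field-name check ("contents_count" vs "cached_contents_count") keeps the code compatible with ADK versions on both sides of 1.17. A small round-trip sketch with hypothetical agent and response ids:

    # Hypothetical usage of the helpers added in this file.
    meta = build_cache_metadata({"planner": "resp_123"})
    # meta.cache_name == '{"planner": "resp_123"}'

    meta = update_cache_metadata(meta, "writer", "resp_456")
    # cache_name now carries the latest response id for both agents.

    assert get_previous_response_id(meta, "planner") == "resp_123"
    assert get_previous_response_id(meta, "unknown_agent") is None  # missing agent -> None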
