Skip to content

Commit 8aa0ec7

Browse files
committed
fix: transport response_id by session state
1 parent bd6658d commit 8aa0ec7

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

veadk/agent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from veadk.knowledgebase import KnowledgeBase
3737
from veadk.memory.long_term_memory import LongTermMemory
3838
from veadk.memory.short_term_memory import ShortTermMemory
39-
from veadk.models.ark_llm import add_previous_response_id
39+
from veadk.models.ark_llm import add_previous_response_id, add_response_id
4040
from veadk.processors import BaseRunProcessor, NoOpRunProcessor
4141
from veadk.prompts.agent_default_prompt import DEFAULT_DESCRIPTION, DEFAULT_INSTRUCTION
4242
from veadk.tracing.base_tracer import BaseTracer
@@ -213,6 +213,10 @@ def model_post_init(self, __context: Any) -> None:
213213
self.before_model_callback = add_previous_response_id
214214
else:
215215
self.before_model_callback.append(add_previous_response_id)
216+
if not self.after_tool_callback:
217+
self.after_model_callback = add_response_id
218+
else:
219+
self.after_model_callback.append(add_response_id)
216220
else:
217221
self.model = LiteLlm(
218222
model=f"{self.model_provider}/{self.model_name}",

veadk/models/ark_llm.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,13 +259,18 @@ def add_previous_response_id(
259259
) -> Optional[LlmResponse]:
260260
invocation_context = callback_context._invocation_context
261261
events = invocation_context.session.events
262+
agent_name = callback_context.agent_name
263+
# read response_id
262264
if (
263265
events
264266
and len(events) >= 2
265267
and events[-2].custom_metadata
266268
and "response_id" in events[-2].custom_metadata
267269
):
268-
previous_response_id = events[-2].custom_metadata["response_id"]
270+
previous_response_id = callback_context.state.get(
271+
f"agent:{agent_name}:response_id"
272+
)
273+
# previous_response_id = events[-2].custom_metadata["response_id"]
269274
if "contents_count" in CacheMetadata.model_fields: # adk >= 1.17
270275
llm_request.cache_metadata = CacheMetadata(
271276
cache_name=previous_response_id,
@@ -283,3 +288,14 @@ def add_previous_response_id(
283288
cached_contents_count=0,
284289
)
285290
return
291+
292+
293+
# after_model_callback
294+
def add_response_id(
295+
callback_context: CallbackContext, llm_response: LlmResponse
296+
) -> Optional[LlmResponse]:
297+
agent_name = callback_context.agent_name
298+
299+
response_id = llm_response.custom_metadata["response_id"]
300+
callback_context.state[f"agent:{agent_name}:response_id"] = response_id
301+
return

0 commit comments

Comments
 (0)