Skip to content

Commit 894d4c3

Browse files
committed
fix: multi-agent bug
1 parent 8aa0ec7 commit 894d4c3

File tree

2 files changed

+48
-46
lines changed

2 files changed

+48
-46
lines changed

veadk/agent.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
from __future__ import annotations
1616

1717
import os
18-
from typing import Optional, Union
18+
from typing import Optional, Union, AsyncGenerator
1919

20-
from google.adk.agents import LlmAgent, RunConfig
20+
from google.adk.agents import LlmAgent, RunConfig, InvocationContext
2121
from google.adk.agents.base_agent import BaseAgent
2222
from google.adk.agents.llm_agent import InstructionProvider, ToolUnion
2323
from google.adk.agents.run_config import StreamingMode
24+
from google.adk.events import Event, EventActions
2425
from google.adk.models.lite_llm import LiteLlm
2526
from google.adk.runners import Runner
2627
from google.genai import types
@@ -36,7 +37,7 @@
3637
from veadk.knowledgebase import KnowledgeBase
3738
from veadk.memory.long_term_memory import LongTermMemory
3839
from veadk.memory.short_term_memory import ShortTermMemory
39-
from veadk.models.ark_llm import add_previous_response_id, add_response_id
40+
from veadk.models.ark_llm import add_previous_response_id
4041
from veadk.processors import BaseRunProcessor, NoOpRunProcessor
4142
from veadk.prompts.agent_default_prompt import DEFAULT_DESCRIPTION, DEFAULT_INSTRUCTION
4243
from veadk.tracing.base_tracer import BaseTracer
@@ -212,11 +213,13 @@ def model_post_init(self, __context: Any) -> None:
212213
if not self.before_model_callback:
213214
self.before_model_callback = add_previous_response_id
214215
else:
215-
self.before_model_callback.append(add_previous_response_id)
216-
if not self.after_tool_callback:
217-
self.after_model_callback = add_response_id
218-
else:
219-
self.after_model_callback.append(add_response_id)
216+
if isinstance(self.before_model_callback, list):
217+
self.before_model_callback.append(add_previous_response_id)
218+
else:
219+
self.before_model_callback = [
220+
self.before_model_callback,
221+
add_previous_response_id,
222+
]
220223
else:
221224
self.model = LiteLlm(
222225
model=f"{self.model_provider}/{self.model_name}",
@@ -259,6 +262,27 @@ def model_post_init(self, __context: Any) -> None:
259262
f"Agent: {self.model_dump(include={'name', 'model_name', 'model_api_base', 'tools'})}"
260263
)
261264

265+
async def _run_async_impl(
266+
self, ctx: InvocationContext
267+
) -> AsyncGenerator[Event, None]:
268+
async for event in super()._run_async_impl(ctx):
269+
agent_name = self.name
270+
if (
271+
self.enable_responses
272+
and event.custom_metadata
273+
and event.custom_metadata.get("response_id")
274+
):
275+
response_id = event.custom_metadata["response_id"]
276+
yield Event(
277+
invocation_id=ctx.invocation_id,
278+
author=self.name,
279+
actions=EventActions(
280+
state_delta={f"agent:{agent_name}:response_id": response_id}
281+
),
282+
branch=ctx.branch,
283+
)
284+
yield event
285+
262286
async def _run(
263287
self,
264288
runner,

veadk/models/ark_llm.py

Lines changed: 16 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -257,45 +257,23 @@ async def generate_content_async(
257257
def add_previous_response_id(
258258
callback_context: CallbackContext, llm_request: LlmRequest
259259
) -> Optional[LlmResponse]:
260-
invocation_context = callback_context._invocation_context
261-
events = invocation_context.session.events
262260
agent_name = callback_context.agent_name
263261
# read response_id
264-
if (
265-
events
266-
and len(events) >= 2
267-
and events[-2].custom_metadata
268-
and "response_id" in events[-2].custom_metadata
269-
):
270-
previous_response_id = callback_context.state.get(
271-
f"agent:{agent_name}:response_id"
262+
previous_response_id = callback_context.state.get(f"agent:{agent_name}:response_id")
263+
if "contents_count" in CacheMetadata.model_fields: # adk >= 1.17
264+
llm_request.cache_metadata = CacheMetadata(
265+
cache_name=previous_response_id,
266+
expire_time=0,
267+
fingerprint="",
268+
invocations_used=0,
269+
contents_count=0,
270+
)
271+
else: # 1.15 <= adk < 1.17
272+
llm_request.cache_metadata = CacheMetadata(
273+
cache_name=previous_response_id,
274+
expire_time=0,
275+
fingerprint="",
276+
invocations_used=0,
277+
cached_contents_count=0,
272278
)
273-
# previous_response_id = events[-2].custom_metadata["response_id"]
274-
if "contents_count" in CacheMetadata.model_fields: # adk >= 1.17
275-
llm_request.cache_metadata = CacheMetadata(
276-
cache_name=previous_response_id,
277-
expire_time=0,
278-
fingerprint="",
279-
invocations_used=0,
280-
contents_count=0,
281-
)
282-
else: # 1.15 <= adk < 1.17
283-
llm_request.cache_metadata = CacheMetadata(
284-
cache_name=previous_response_id,
285-
expire_time=0,
286-
fingerprint="",
287-
invocations_used=0,
288-
cached_contents_count=0,
289-
)
290-
return
291-
292-
293-
# after_model_callback
294-
def add_response_id(
295-
callback_context: CallbackContext, llm_response: LlmResponse
296-
) -> Optional[LlmResponse]:
297-
agent_name = callback_context.agent_name
298-
299-
response_id = llm_response.custom_metadata["response_id"]
300-
callback_context.state[f"agent:{agent_name}:response_id"] = response_id
301279
return

0 commit comments

Comments
 (0)