|
15 | 15 | from __future__ import annotations |
16 | 16 |
|
17 | 17 | import os |
18 | | -from typing import Optional, Union |
| 18 | +from typing import Optional, Union, AsyncGenerator |
19 | 19 |
|
20 | | -from google.adk.agents import LlmAgent, RunConfig |
| 20 | +from google.adk.agents import LlmAgent, RunConfig, InvocationContext |
21 | 21 | from google.adk.agents.base_agent import BaseAgent |
22 | 22 | from google.adk.agents.llm_agent import InstructionProvider, ToolUnion |
23 | 23 | from google.adk.agents.run_config import StreamingMode |
| 24 | +from google.adk.events import Event, EventActions |
24 | 25 | from google.adk.models.lite_llm import LiteLlm |
25 | 26 | from google.adk.runners import Runner |
26 | 27 | from google.genai import types |
|
36 | 37 | from veadk.knowledgebase import KnowledgeBase |
37 | 38 | from veadk.memory.long_term_memory import LongTermMemory |
38 | 39 | from veadk.memory.short_term_memory import ShortTermMemory |
39 | | -from veadk.models.ark_llm import add_previous_response_id, add_response_id |
| 40 | +from veadk.models.ark_llm import add_previous_response_id |
40 | 41 | from veadk.processors import BaseRunProcessor, NoOpRunProcessor |
41 | 42 | from veadk.prompts.agent_default_prompt import DEFAULT_DESCRIPTION, DEFAULT_INSTRUCTION |
42 | 43 | from veadk.tracing.base_tracer import BaseTracer |
@@ -212,11 +213,13 @@ def model_post_init(self, __context: Any) -> None: |
212 | 213 | if not self.before_model_callback: |
213 | 214 | self.before_model_callback = add_previous_response_id |
214 | 215 | else: |
215 | | - self.before_model_callback.append(add_previous_response_id) |
216 | | - if not self.after_tool_callback: |
217 | | - self.after_model_callback = add_response_id |
218 | | - else: |
219 | | - self.after_model_callback.append(add_response_id) |
| 216 | + if isinstance(self.before_model_callback, list): |
| 217 | + self.before_model_callback.append(add_previous_response_id) |
| 218 | + else: |
| 219 | + self.before_model_callback = [ |
| 220 | + self.before_model_callback, |
| 221 | + add_previous_response_id, |
| 222 | + ] |
220 | 223 | else: |
221 | 224 | self.model = LiteLlm( |
222 | 225 | model=f"{self.model_provider}/{self.model_name}", |
@@ -259,6 +262,27 @@ def model_post_init(self, __context: Any) -> None: |
259 | 262 | f"Agent: {self.model_dump(include={'name', 'model_name', 'model_api_base', 'tools'})}" |
260 | 263 | ) |
261 | 264 |
|
| 265 | + async def _run_async_impl( |
| 266 | + self, ctx: InvocationContext |
| 267 | + ) -> AsyncGenerator[Event, None]: |
| 268 | + async for event in super()._run_async_impl(ctx): |
| 269 | + agent_name = self.name |
| 270 | + if ( |
| 271 | + self.enable_responses |
| 272 | + and event.custom_metadata |
| 273 | + and event.custom_metadata.get("response_id") |
| 274 | + ): |
| 275 | + response_id = event.custom_metadata["response_id"] |
| 276 | + yield Event( |
| 277 | + invocation_id=ctx.invocation_id, |
| 278 | + author=self.name, |
| 279 | + actions=EventActions( |
| 280 | + state_delta={f"agent:{agent_name}:response_id": response_id} |
| 281 | + ), |
| 282 | + branch=ctx.branch, |
| 283 | + ) |
| 284 | + yield event |
| 285 | + |
262 | 286 | async def _run( |
263 | 287 | self, |
264 | 288 | runner, |
|
0 commit comments