Skip to content

Commit 5045497

Browse files
committed
feat: enable persistent short-term memory to pass response-id
1 parent c86cdbb commit 5045497

File tree

2 files changed

+36
-3
lines changed

2 files changed

+36
-3
lines changed

veadk/agent.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from google.adk.agents.context_cache_config import ContextCacheConfig
3030
from google.adk.agents.llm_agent import InstructionProvider, ToolUnion
3131
from google.adk.agents.run_config import StreamingMode
32-
from google.adk.events import Event
32+
from google.adk.events import Event, EventActions
3333
from google.adk.models.lite_llm import LiteLlm
3434
from google.adk.runners import Runner
3535
from google.genai import types
@@ -240,6 +240,18 @@ async def _run_async_impl(
240240

241241
async for event in super()._run_async_impl(ctx):
242242
yield event
243+
if self.enable_responses and event.cache_metadata:
244+
# for persistent short-term memory with response api
245+
session_state_event = Event(
246+
invocation_id=event.invocation_id,
247+
author=event.author,
248+
actions=EventActions(
249+
state_delta={
250+
"response_id": event.cache_metadata.cache_name,
251+
}
252+
),
253+
)
254+
yield session_state_event
243255

244256
async def _run(
245257
self,

veadk/memory/short_term_memory.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from veadk.memory.short_term_memory_backends.sqlite_backend import (
3333
SQLiteSTMBackend,
3434
)
35+
from veadk.models.ark_transform import build_cache_metadata
3536
from veadk.utils.logger import get_logger
3637

3738
logger = get_logger(__name__)
@@ -41,14 +42,29 @@ def wrap_get_session_with_callbacks(obj, callback_fn: Callable):
4142
get_session_fn = getattr(obj, "get_session")
4243

4344
@wraps(get_session_fn)
44-
def wrapper(*args, **kwargs):
45-
result = get_session_fn(*args, **kwargs)
45+
async def wrapper(*args, **kwargs):
46+
result = await get_session_fn(*args, **kwargs)
4647
callback_fn(result, *args, **kwargs)
4748
return result
4849

4950
setattr(obj, "get_session", wrapper)
5051

5152

53+
def enable_responses_api_for_session_service(result, *args, **kwargs):
54+
if result and isinstance(result, Session):
55+
if result.events:
56+
for event in result.events:
57+
if (
58+
event.actions
59+
and event.actions.state_delta
60+
and not event.cache_metadata
61+
and "response_id" in event.actions.state_delta
62+
):
63+
event.cache_metadata = build_cache_metadata(
64+
response_id=event.actions.state_delta.get("response_id"),
65+
)
66+
67+
5268
class ShortTermMemory(BaseModel):
5369
"""Short term memory for agent execution.
5470
@@ -170,6 +186,11 @@ def model_post_init(self, __context: Any) -> None:
170186
db_kwargs=self.db_kwargs, **self.backend_configs
171187
).session_service
172188

189+
if self.backend != "local":
190+
wrap_get_session_with_callbacks(
191+
self._session_service, enable_responses_api_for_session_service
192+
)
193+
173194
if self.after_load_memory_callback:
174195
wrap_get_session_with_callbacks(
175196
self._session_service, self.after_load_memory_callback

0 commit comments

Comments
 (0)