Skip to content

Commit f643eb4

Browse files
committed
feat: enable agent to automatically save session
1 parent a8e4830 commit f643eb4

File tree

2 files changed

+184
-0
lines changed

2 files changed

+184
-0
lines changed

veadk/agent.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ class Agent(LlmAgent):
8585
short_term_memory (Optional[ShortTermMemory]): Session-based memory for temporary context.
8686
long_term_memory (Optional[LongTermMemory]): Cross-session memory for persistent user context.
8787
tracers (list[BaseTracer]): List of tracers used for telemetry and monitoring.
88+
enable_authz (bool): Whether to enable agent authorization checks.
89+
auto_save_session (bool): Whether to automatically save sessions to long-term memory.
8890
"""
8991

9092
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
@@ -140,6 +142,8 @@ class Agent(LlmAgent):
140142

141143
enable_authz: bool = False
142144

145+
auto_save_session: bool = False
146+
143147
def model_post_init(self, __context: Any) -> None:
144148
super().model_post_init(None) # for sub_agents init
145149

@@ -258,6 +262,27 @@ def model_post_init(self, __context: Any) -> None:
258262
if self.prompt_manager:
259263
self.instruction = self.prompt_manager.get_prompt
260264

265+
if self.auto_save_session:
266+
if self.long_term_memory is None:
267+
logger.warning(
268+
"auto_save_session is enabled, but long_term_memory is not initialized."
269+
)
270+
else:
271+
from veadk.tools.builtin_tools.save_session import (
272+
save_session_to_memory,
273+
)
274+
275+
if self.after_agent_callback:
276+
if isinstance(self.after_agent_callback, list):
277+
self.after_agent_callback.append(save_session_to_memory)
278+
else:
279+
self.after_agent_callback = [
280+
self.after_agent_callback,
281+
save_session_to_memory,
282+
]
283+
else:
284+
self.after_agent_callback = save_session_to_memory
285+
261286
logger.info(f"VeADK version: {VERSION}")
262287

263288
logger.info(f"{self.__class__.__name__} `{self.name}` init done.")
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import time
16+
from google.adk.agents.callback_context import CallbackContext
17+
from veadk.config import getenv
18+
from veadk.utils.logger import get_logger
19+
20+
logger = get_logger(__name__)
21+
22+
# Session-level cache for tracking save state
23+
# Format: {(app_name, user_id, session_id): {'last_save_time': float, 'last_event_count': int}}
24+
_session_save_cache: dict = {}
25+
26+
# Track active session per user to detect session switches
27+
# Format: {(app_name, user_id): session_id}
28+
_active_sessions: dict = {}
29+
30+
# Configurable thresholds
31+
MIN_MESSAGES_THRESHOLD = getenv(
32+
"MIN_MESSAGES_THRESHOLD", 10
33+
) # Minimum number of new messages before saving
34+
MIN_TIME_THRESHOLD = getenv(
35+
"MIN_TIME_THRESHOLD", 60
36+
) # Minimum seconds between saves (1 minute)
37+
38+
39+
async def save_session_to_memory(
40+
callback_context: CallbackContext,
41+
) -> None:
42+
"""Save the current session to long-term memory.
43+
44+
Args:
45+
callback_context: The callback context containing invocation information.
46+
47+
Returns:
48+
None
49+
"""
50+
try:
51+
agent = callback_context._invocation_context.agent
52+
53+
long_term_memory = getattr(agent, "long_term_memory", None)
54+
if not long_term_memory:
55+
logger.error(
56+
"Long-term memory is not initialized in agent, cannot save session to memory."
57+
)
58+
return None
59+
60+
app_name = callback_context._invocation_context.app_name
61+
user_id = callback_context._invocation_context.user_id
62+
session_id = callback_context._invocation_context.session.id
63+
session_service = callback_context._invocation_context.session_service
64+
65+
current_time = time.time()
66+
67+
# Detect session switch and force save previous session
68+
user_key = (app_name, user_id)
69+
previous_session_id = _active_sessions.get(user_key)
70+
71+
if previous_session_id and previous_session_id != session_id:
72+
logger.info(
73+
f"Session switch detected for user {user_id}: "
74+
f"{previous_session_id} -> {session_id}. "
75+
f"Force saving previous session."
76+
)
77+
old_session = await session_service.get_session(
78+
app_name=app_name,
79+
user_id=user_id,
80+
session_id=previous_session_id,
81+
)
82+
if old_session:
83+
old_events = getattr(old_session, "events", [])
84+
old_event_count = len(old_events)
85+
await long_term_memory.add_session_to_memory(old_session)
86+
old_cache_key = (app_name, user_id, previous_session_id)
87+
88+
_session_save_cache[old_cache_key] = {
89+
"last_save_time": current_time,
90+
"last_event_count": old_event_count,
91+
}
92+
logger.info(
93+
f"Previous session `{old_session.id}` saved to long term memory due to session switch."
94+
)
95+
96+
# Update active session
97+
_active_sessions[user_key] = session_id
98+
99+
session = await session_service.get_session(
100+
app_name=app_name,
101+
user_id=user_id,
102+
session_id=session_id,
103+
)
104+
105+
if not session:
106+
logger.error(
107+
f"Session {session_id} (app_name={app_name}, user_id={user_id}) not found in session service, cannot save to long-term memory."
108+
)
109+
return None
110+
111+
current_events = getattr(session, "events", [])
112+
current_event_count = len(current_events)
113+
# logger.debug(f"Current event count: {current_event_count}")
114+
115+
# Create cache key
116+
cache_key = (app_name, user_id, session_id)
117+
118+
cache_info = _session_save_cache.get(cache_key)
119+
120+
if cache_info:
121+
last_save_time = cache_info.get("last_save_time", 0)
122+
last_event_count = cache_info.get("last_event_count", 0)
123+
124+
time_elapsed = current_time - last_save_time
125+
new_events_count = current_event_count - last_event_count
126+
127+
# Check if we should skip save
128+
if (
129+
time_elapsed < MIN_TIME_THRESHOLD
130+
and new_events_count < MIN_MESSAGES_THRESHOLD
131+
):
132+
logger.info(
133+
f"Skipping save for session {session_id}: "
134+
f"only {new_events_count} new events (need {MIN_MESSAGES_THRESHOLD}) "
135+
f"and {time_elapsed:.1f}s elapsed (need {MIN_TIME_THRESHOLD}s)"
136+
)
137+
return None
138+
else:
139+
logger.info(f"First save for session {session_id}.")
140+
141+
# Save to long-term memory
142+
await long_term_memory.add_session_to_memory(session)
143+
144+
# Update cache
145+
_session_save_cache[cache_key] = {
146+
"last_save_time": current_time,
147+
"last_event_count": current_event_count,
148+
}
149+
150+
logger.info(f"Add session `{session.id}` to long term memory.")
151+
152+
return None
153+
154+
except AttributeError as e:
155+
logger.error(f"AttributeError while saving session to memory: {e}")
156+
return None
157+
except Exception as e:
158+
logger.error(f"Unexpected error while saving session to memory: {e}")
159+
return None

0 commit comments

Comments
 (0)