Skip to content

Commit de6456e

Browse files
committed
add callback mount
1 parent e213841 commit de6456e

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

veadk/memory/short_term_memory.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from functools import wraps
1516
from typing import Any, Callable, Literal
1617

1718
from google.adk.sessions import DatabaseSessionService, InMemorySessionService
@@ -36,6 +37,18 @@
3637
DEFAULT_LOCAL_DATABASE_PATH = "/tmp/veadk_local_database.db"
3738

3839

40+
def wrap_get_session_with_callbacks(obj, callback_fn: Callable):
41+
get_session_fn = getattr(obj, "get_session")
42+
43+
@wraps(get_session_fn)
44+
def wrapper(*args, **kwargs):
45+
result = get_session_fn(*args, **kwargs)
46+
callback_fn(result, *args, **kwargs)
47+
return result
48+
49+
setattr(obj, "get_session", wrapper)
50+
51+
3952
class ShortTermMemory(BaseModel):
4053
backend: Literal["local", "mysql", "sqlite", "redis", "database"] = "local"
4154
"""Short term memory backend. `Local` for in-memory storage, `redis` for redis storage, `mysql` for mysql / PostgreSQL storage. `sqlite` for sqlite storage."""
@@ -46,8 +59,8 @@ class ShortTermMemory(BaseModel):
4659
db_url: str = ""
4760
"""Database connection URL, e.g. `sqlite:///./test.db`. Once set, it will override the `backend` parameter."""
4861

49-
after_load_memory_callbacks: list[Callable] | None = None
50-
"""A list of callbacks to be called after loading memory from the backend. The callback function should accept `Session` as an input."""
62+
after_load_memory_callback: Callable | None = None
63+
"""A callback to be called after loading memory from the backend. The callback function should accept `Session` as an input."""
5164

5265
def model_post_init(self, __context: Any) -> None:
5366
if self.db_url:
@@ -78,6 +91,11 @@ def model_post_init(self, __context: Any) -> None:
7891
**self.backend_configs
7992
).session_service
8093

94+
if self.after_load_memory_callback:
95+
wrap_get_session_with_callbacks(
96+
self.session_service, self.after_load_memory_callback
97+
)
98+
8199
async def create_session(
82100
self,
83101
app_name: str,

0 commit comments

Comments
 (0)