Skip to content

Commit e213841

Browse files
committed
update short term memory
1 parent 51542c1 commit e213841

File tree

9 files changed

+289
-65
lines changed

9 files changed

+289
-65
lines changed

veadk/configs/database_configs.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,20 @@ class MysqlConfig(BaseSettings):
4747
charset: str = "utf8"
4848

4949

50+
class PostgreSqlConfig(BaseSettings):
51+
model_config = SettingsConfigDict(env_prefix="DATABASE_POSTGRESQL_")
52+
53+
host: str = ""
54+
55+
port: int = 5432
56+
57+
user: str = ""
58+
59+
password: str = ""
60+
61+
database: str = ""
62+
63+
5064
class RedisConfig(BaseSettings):
5165
model_config = SettingsConfigDict(env_prefix="DATABASE_REDIS_")
5266

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

veadk/memory/short_term_memory.py

Lines changed: 50 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -12,86 +12,71 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
16-
from typing import Literal
15+
from typing import Any, Callable, Literal
1716

1817
from google.adk.sessions import DatabaseSessionService, InMemorySessionService
19-
20-
from veadk.config import getenv
18+
from pydantic import BaseModel, Field
19+
20+
from veadk.memory.short_term_memory_backends.mysql_backend import (
21+
MysqlSTMBackend,
22+
)
23+
from veadk.memory.short_term_memory_backends.postgresql_backend import (
24+
PostgreSqlSTMBackend,
25+
)
26+
from veadk.memory.short_term_memory_backends.redis_backend import RedisSTMBackend
27+
from veadk.memory.short_term_memory_backends.sqlite_backend import (
28+
SQLiteSTMBackend,
29+
)
2130
from veadk.utils.logger import get_logger
2231

23-
from .short_term_memory_processor import ShortTermMemoryProcessor
32+
# from .short_term_memory_processor import ShortTermMemoryProcessor
2433

2534
logger = get_logger(__name__)
2635

2736
DEFAULT_LOCAL_DATABASE_PATH = "/tmp/veadk_local_database.db"
2837

2938

30-
class ShortTermMemory:
31-
"""
32-
Short term memory class.
39+
class ShortTermMemory(BaseModel):
40+
backend: Literal["local", "mysql", "sqlite", "redis", "database"] = "local"
41+
"""Short term memory backend. `Local` for in-memory storage, `redis` for redis storage, `mysql` for mysql / PostgreSQL storage. `sqlite` for sqlite storage."""
3342

34-
This class is used to store short term memory.
35-
"""
43+
backend_configs: dict = Field(default_factory=dict)
44+
"""Backend specific configurations."""
3645

37-
def __init__(
38-
self,
39-
backend: Literal["local", "database", "mysql"] = "local",
40-
db_url: str = "",
41-
enable_memory_optimization: bool = False,
42-
):
43-
self.backend = backend
44-
self.db_url = db_url
46+
db_url: str = ""
47+
"""Database connection URL, e.g. `sqlite:///./test.db`. Once set, it will override the `backend` parameter."""
4548

46-
if self.backend == "mysql":
47-
host = getenv("DATABASE_MYSQL_HOST")
48-
user = getenv("DATABASE_MYSQL_USER")
49-
password = getenv("DATABASE_MYSQL_PASSWORD")
50-
database = getenv("DATABASE_MYSQL_DATABASE")
51-
db_url = f"mysql+pymysql://{user}:{password}@{host}/{database}"
49+
after_load_memory_callbacks: list[Callable] | None = None
50+
"""A list of callbacks to be called after loading memory from the backend. The callback function should accept `Session` as an input."""
5251

53-
self.db_url = db_url
54-
self.backend = "database"
55-
56-
if self.backend == "local":
57-
logger.warning(
58-
f"Short term memory backend: {self.backend}, the history will be lost after application shutdown."
59-
)
60-
self.session_service = InMemorySessionService()
61-
elif self.backend == "database":
62-
if self.db_url == "" or self.db_url is None:
63-
logger.warning("The `db_url` is an empty or None string.")
64-
self._use_default_database()
65-
else:
66-
try:
67-
self.session_service = DatabaseSessionService(db_url=self.db_url)
68-
logger.info("Connected to database with db_url.")
69-
except Exception as e:
70-
logger.error(f"Failed to connect to database, error: {e}.")
71-
self._use_default_database()
52+
def model_post_init(self, __context: Any) -> None:
53+
if self.db_url:
54+
logger.info("The `db_url` is set, ignore `backend` option.")
55+
self.session_service = DatabaseSessionService(db_url=self.db_url)
7256
else:
73-
raise ValueError(f"Unknown short term memory backend: {self.backend}")
74-
75-
if enable_memory_optimization and backend == "database":
76-
self.processor = ShortTermMemoryProcessor()
77-
intercept_get_session = self.processor.patch()
78-
self.session_service.get_session = intercept_get_session(
79-
self.session_service.get_session
80-
)
81-
82-
def _use_default_database(self):
83-
self.db_url = DEFAULT_LOCAL_DATABASE_PATH
84-
logger.info(f"Using default local database {self.db_url}")
85-
if not os.path.exists(self.db_url):
86-
self.create_local_sqlite3_db(self.db_url)
87-
self.session_service = DatabaseSessionService(db_url="sqlite:///" + self.db_url)
88-
89-
def create_local_sqlite3_db(self, path: str):
90-
import sqlite3
91-
92-
conn = sqlite3.connect(path)
93-
conn.close()
94-
logger.debug(f"Create local sqlite3 database {path} done.")
57+
if self.backend == "database":
58+
logger.warning(
59+
"Backend `database` is deprecated, use `sqlite` to create short term memory."
60+
)
61+
match self.backend:
62+
case "local":
63+
self.session_service = InMemorySessionService()
64+
case "mysql":
65+
self.session_service = MysqlSTMBackend(
66+
**self.backend_configs
67+
).session_service
68+
case "sqlite":
69+
self.session_service = SQLiteSTMBackend(
70+
local_path=DEFAULT_LOCAL_DATABASE_PATH
71+
).session_service
72+
case "redis":
73+
self.session_service = RedisSTMBackend(
74+
**self.backend_configs
75+
).session_service
76+
case "postgresql":
77+
self.session_service = PostgreSqlSTMBackend(
78+
**self.backend_configs
79+
).session_service
9580

9681
async def create_session(
9782
self,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from abc import ABC, abstractmethod
17+
from functools import cached_property
18+
19+
from google.adk.sessions import BaseSessionService
20+
from pydantic import BaseModel
21+
22+
23+
class BaseShortTermMemoryBackend(ABC, BaseModel):
24+
"""
25+
Base class for short term memory backend.
26+
"""
27+
28+
@cached_property
29+
@abstractmethod
30+
def session_service(self) -> BaseSessionService:
31+
"""Return the session service instance."""
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from functools import cached_property
16+
from typing import Any
17+
18+
from google.adk.sessions import (
19+
BaseSessionService,
20+
DatabaseSessionService,
21+
)
22+
from pydantic import Field
23+
from typing_extensions import override
24+
25+
from veadk.configs.database_configs import MysqlConfig
26+
from veadk.memory.short_term_memory_backends.base_backend import (
27+
BaseShortTermMemoryBackend,
28+
)
29+
30+
31+
class MysqlSTMBackend(BaseShortTermMemoryBackend):
32+
mysql_config: MysqlConfig = Field(default_factory=MysqlConfig)
33+
34+
def model_post_init(self, context: Any) -> None:
35+
self._db_url = f"mysql+pymysql://{self.mysql_config.user}:{self.mysql_config.password}@{self.mysql_config.host}/{self.mysql_config.database}"
36+
37+
@cached_property
38+
@override
39+
def session_service(self) -> BaseSessionService:
40+
return DatabaseSessionService(db_url=self._db_url)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from functools import cached_property
16+
from typing import Any
17+
18+
from google.adk.sessions import (
19+
BaseSessionService,
20+
DatabaseSessionService,
21+
)
22+
from pydantic import Field
23+
from typing_extensions import override
24+
25+
from veadk.configs.database_configs import PostgreSqlConfig
26+
from veadk.memory.short_term_memory_backends.base_backend import (
27+
BaseShortTermMemoryBackend,
28+
)
29+
30+
31+
class PostgreSqlSTMBackend(BaseShortTermMemoryBackend):
32+
postgresql_config: PostgreSqlConfig = Field(default_factory=PostgreSqlConfig)
33+
34+
def model_post_init(self, context: Any) -> None:
35+
self._db_url = f"postgresql://{self.postgresql_config.user}:{self.postgresql_config.password}@{self.postgresql_config.host}:{self.postgresql_config.port}/{self.postgresql_config.database}"
36+
37+
@cached_property
38+
@override
39+
def session_service(self) -> BaseSessionService:
40+
return DatabaseSessionService(db_url=self._db_url)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from functools import cached_property
16+
from typing import Any
17+
18+
from google.adk.sessions import (
19+
BaseSessionService,
20+
DatabaseSessionService,
21+
)
22+
from pydantic import Field
23+
from typing_extensions import override
24+
25+
from veadk.configs.database_configs import RedisConfig
26+
from veadk.memory.short_term_memory_backends.base_backend import (
27+
BaseShortTermMemoryBackend,
28+
)
29+
30+
31+
class RedisSTMBackend(BaseShortTermMemoryBackend):
32+
redis_config: RedisConfig = Field(default_factory=RedisConfig)
33+
34+
def model_post_init(self, context: Any) -> None:
35+
self._db_url = f"redis://{self.redis_config.host}:{self.redis_config.port}/{self.redis_config.db}"
36+
37+
@cached_property
38+
@override
39+
def session_service(self) -> BaseSessionService:
40+
return DatabaseSessionService(db_url=self._db_url)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import sqlite3
17+
from functools import cached_property
18+
from typing import Any
19+
20+
from google.adk.sessions import (
21+
BaseSessionService,
22+
DatabaseSessionService,
23+
)
24+
from typing_extensions import override
25+
26+
from veadk.memory.short_term_memory_backends.base_backend import (
27+
BaseShortTermMemoryBackend,
28+
)
29+
30+
31+
class SQLiteSTMBackend(BaseShortTermMemoryBackend):
32+
local_path: str
33+
34+
def model_post_init(self, context: Any) -> None:
35+
# if the DB file not exists, create it
36+
if not self._db_exists():
37+
conn = sqlite3.connect(self.local_path)
38+
conn.close()
39+
40+
self._db_url = f"sqlite:///{self.local_path}"
41+
42+
@cached_property
43+
@override
44+
def session_service(self) -> BaseSessionService:
45+
return DatabaseSessionService(db_url=self._db_url)
46+
47+
def _db_exists(self) -> bool:
48+
return os.path.exists(self.local_path)

0 commit comments

Comments
 (0)