|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -import os |
16 | | -from typing import Literal |
| 15 | +from typing import Any, Callable, Literal |
17 | 16 |
|
18 | 17 | from google.adk.sessions import DatabaseSessionService, InMemorySessionService |
19 | | - |
20 | | -from veadk.config import getenv |
| 18 | +from pydantic import BaseModel, Field |
| 19 | + |
| 20 | +from veadk.memory.short_term_memory_backends.mysql_backend import ( |
| 21 | + MysqlSTMBackend, |
| 22 | +) |
| 23 | +from veadk.memory.short_term_memory_backends.postgresql_backend import ( |
| 24 | + PostgreSqlSTMBackend, |
| 25 | +) |
| 26 | +from veadk.memory.short_term_memory_backends.redis_backend import RedisSTMBackend |
| 27 | +from veadk.memory.short_term_memory_backends.sqlite_backend import ( |
| 28 | + SQLiteSTMBackend, |
| 29 | +) |
21 | 30 | from veadk.utils.logger import get_logger |
22 | 31 |
|
23 | | -from .short_term_memory_processor import ShortTermMemoryProcessor |
| 32 | +# from .short_term_memory_processor import ShortTermMemoryProcessor |
24 | 33 |
|
25 | 34 | logger = get_logger(__name__) |
26 | 35 |
|
27 | 36 | DEFAULT_LOCAL_DATABASE_PATH = "/tmp/veadk_local_database.db" |
28 | 37 |
|
29 | 38 |
|
30 | | -class ShortTermMemory: |
31 | | - """ |
32 | | - Short term memory class. |
| 39 | +class ShortTermMemory(BaseModel): |
| 40 | + backend: Literal["local", "mysql", "sqlite", "redis", "database"] = "local" |
| 41 | + """Short term memory backend. `Local` for in-memory storage, `redis` for redis storage, `mysql` for mysql / PostgreSQL storage. `sqlite` for sqlite storage.""" |
33 | 42 |
|
34 | | - This class is used to store short term memory. |
35 | | - """ |
| 43 | + backend_configs: dict = Field(default_factory=dict) |
| 44 | + """Backend specific configurations.""" |
36 | 45 |
|
37 | | - def __init__( |
38 | | - self, |
39 | | - backend: Literal["local", "database", "mysql"] = "local", |
40 | | - db_url: str = "", |
41 | | - enable_memory_optimization: bool = False, |
42 | | - ): |
43 | | - self.backend = backend |
44 | | - self.db_url = db_url |
| 46 | + db_url: str = "" |
| 47 | + """Database connection URL, e.g. `sqlite:///./test.db`. Once set, it will override the `backend` parameter.""" |
45 | 48 |
|
46 | | - if self.backend == "mysql": |
47 | | - host = getenv("DATABASE_MYSQL_HOST") |
48 | | - user = getenv("DATABASE_MYSQL_USER") |
49 | | - password = getenv("DATABASE_MYSQL_PASSWORD") |
50 | | - database = getenv("DATABASE_MYSQL_DATABASE") |
51 | | - db_url = f"mysql+pymysql://{user}:{password}@{host}/{database}" |
| 49 | + after_load_memory_callbacks: list[Callable] | None = None |
| 50 | + """A list of callbacks to be called after loading memory from the backend. The callback function should accept `Session` as an input.""" |
52 | 51 |
|
53 | | - self.db_url = db_url |
54 | | - self.backend = "database" |
55 | | - |
56 | | - if self.backend == "local": |
57 | | - logger.warning( |
58 | | - f"Short term memory backend: {self.backend}, the history will be lost after application shutdown." |
59 | | - ) |
60 | | - self.session_service = InMemorySessionService() |
61 | | - elif self.backend == "database": |
62 | | - if self.db_url == "" or self.db_url is None: |
63 | | - logger.warning("The `db_url` is an empty or None string.") |
64 | | - self._use_default_database() |
65 | | - else: |
66 | | - try: |
67 | | - self.session_service = DatabaseSessionService(db_url=self.db_url) |
68 | | - logger.info("Connected to database with db_url.") |
69 | | - except Exception as e: |
70 | | - logger.error(f"Failed to connect to database, error: {e}.") |
71 | | - self._use_default_database() |
| 52 | + def model_post_init(self, __context: Any) -> None: |
| 53 | + if self.db_url: |
| 54 | + logger.info("The `db_url` is set, ignore `backend` option.") |
| 55 | + self.session_service = DatabaseSessionService(db_url=self.db_url) |
72 | 56 | else: |
73 | | - raise ValueError(f"Unknown short term memory backend: {self.backend}") |
74 | | - |
75 | | - if enable_memory_optimization and backend == "database": |
76 | | - self.processor = ShortTermMemoryProcessor() |
77 | | - intercept_get_session = self.processor.patch() |
78 | | - self.session_service.get_session = intercept_get_session( |
79 | | - self.session_service.get_session |
80 | | - ) |
81 | | - |
82 | | - def _use_default_database(self): |
83 | | - self.db_url = DEFAULT_LOCAL_DATABASE_PATH |
84 | | - logger.info(f"Using default local database {self.db_url}") |
85 | | - if not os.path.exists(self.db_url): |
86 | | - self.create_local_sqlite3_db(self.db_url) |
87 | | - self.session_service = DatabaseSessionService(db_url="sqlite:///" + self.db_url) |
88 | | - |
89 | | - def create_local_sqlite3_db(self, path: str): |
90 | | - import sqlite3 |
91 | | - |
92 | | - conn = sqlite3.connect(path) |
93 | | - conn.close() |
94 | | - logger.debug(f"Create local sqlite3 database {path} done.") |
| 57 | + if self.backend == "database": |
| 58 | + logger.warning( |
| 59 | + "Backend `database` is deprecated, use `sqlite` to create short term memory." |
| 60 | + ) |
| 61 | + match self.backend: |
| 62 | + case "local": |
| 63 | + self.session_service = InMemorySessionService() |
| 64 | + case "mysql": |
| 65 | + self.session_service = MysqlSTMBackend( |
| 66 | + **self.backend_configs |
| 67 | + ).session_service |
| 68 | + case "sqlite": |
| 69 | + self.session_service = SQLiteSTMBackend( |
| 70 | + local_path=DEFAULT_LOCAL_DATABASE_PATH |
| 71 | + ).session_service |
| 72 | + case "redis": |
| 73 | + self.session_service = RedisSTMBackend( |
| 74 | + **self.backend_configs |
| 75 | + ).session_service |
| 76 | + case "postgresql": |
| 77 | + self.session_service = PostgreSqlSTMBackend( |
| 78 | + **self.backend_configs |
| 79 | + ).session_service |
95 | 80 |
|
96 | 81 | async def create_session( |
97 | 82 | self, |
|
0 commit comments