Skip to content

Commit 7ed9278

Browse files
authored
feat: add db_kwargs for sql connection config (#338)
1 parent e92c133 commit 7ed9278

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

veadk/memory/short_term_memory.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ class ShortTermMemory(BaseModel):
126126

127127
backend_configs: dict = Field(default_factory=dict)
128128

129+
db_kwargs: dict = Field(default_factory=dict)
130+
129131
db_url: str = ""
130132

131133
local_database_path: str = "/tmp/veadk_local_database.db"
@@ -143,7 +145,9 @@ def model_post_init(self, __context: Any) -> None:
143145
"Please encode `username` or `password` with `urllib.parse.quote_plus`. "
144146
"Examples: p@ssword→p%40ssword."
145147
)
146-
self._session_service = DatabaseSessionService(db_url=self.db_url)
148+
self._session_service = DatabaseSessionService(
149+
db_url=self.db_url, **self.db_kwargs
150+
)
147151
else:
148152
if self.backend == "database":
149153
logger.warning(
@@ -155,15 +159,15 @@ def model_post_init(self, __context: Any) -> None:
155159
self._session_service = InMemorySessionService()
156160
case "mysql":
157161
self._session_service = MysqlSTMBackend(
158-
**self.backend_configs
162+
db_kwargs=self.db_kwargs, **self.backend_configs
159163
).session_service
160164
case "sqlite":
161165
self._session_service = SQLiteSTMBackend(
162166
local_path=self.local_database_path
163167
).session_service
164168
case "postgresql":
165169
self._session_service = PostgreSqlSTMBackend(
166-
**self.backend_configs
170+
db_kwargs=self.db_kwargs, **self.backend_configs
167171
).session_service
168172

169173
if self.after_load_memory_callback:

veadk/memory/short_term_memory_backends/mysql_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
class MysqlSTMBackend(BaseShortTermMemoryBackend):
3636
mysql_config: MysqlConfig = Field(default_factory=MysqlConfig)
37+
db_kwargs: dict = Field(default_factory=dict)
3738

3839
def model_post_init(self, context: Any) -> None:
3940
encoded_username = quote_plus(self.mysql_config.user)
@@ -46,4 +47,4 @@ def model_post_init(self, context: Any) -> None:
4647
@cached_property
4748
@override
4849
def session_service(self) -> BaseSessionService:
49-
return DatabaseSessionService(db_url=self._db_url)
50+
return DatabaseSessionService(db_url=self._db_url, **self.db_kwargs)

veadk/memory/short_term_memory_backends/postgresql_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535
class PostgreSqlSTMBackend(BaseShortTermMemoryBackend):
3636
postgresql_config: PostgreSqlConfig = Field(default_factory=PostgreSqlConfig)
37+
db_kwargs: dict = Field(default_factory=dict)
3738

3839
def model_post_init(self, context: Any) -> None:
3940
encoded_username = quote_plus(self.postgresql_config.user)
@@ -46,4 +47,4 @@ def model_post_init(self, context: Any) -> None:
4647
@cached_property
4748
@override
4849
def session_service(self) -> BaseSessionService:
49-
return DatabaseSessionService(db_url=self._db_url)
50+
return DatabaseSessionService(db_url=self._db_url, **self.db_kwargs)

0 commit comments

Comments
 (0)