Skip to content

Commit ed2ff9f

Browse files
authored
feat: add kwargs for long-term-memory viking (#340)
1 parent 7ed9278 commit ed2ff9f

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

veadk/memory/long_term_memory.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ def _filter_and_convert_events(self, events: list[Event]) -> list[str]:
249249
async def add_session_to_memory(
250250
self,
251251
session: Session,
252+
**kwargs,
252253
):
253254
"""Add a chat session's events to the long-term memory backend.
254255
@@ -283,7 +284,12 @@ async def add_session_to_memory(
283284
logger.info(
284285
f"Adding {len(event_strings)} events to long term memory: index={self.index}"
285286
)
286-
self._backend.save_memory(user_id=user_id, event_strings=event_strings)
287+
if self.backend == "viking":
288+
self._backend.save_memory(
289+
user_id=user_id, event_strings=event_strings, **kwargs
290+
)
291+
else:
292+
self._backend.save_memory(user_id=user_id, event_strings=event_strings)
287293
logger.info(
288294
f"Added {len(event_strings)} events to long term memory: index={self.index}, user_id={user_id}"
289295
)

veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def _get_client(self) -> VikingDBMemoryClient:
120120

121121
@override
122122
def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
123+
assistant_id = kwargs.get("assistant_id", "assistant")
123124
session_id = str(uuid.uuid1())
124125
messages = []
125126
for raw_events in event_strings:
@@ -131,7 +132,7 @@ def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
131132
messages.append({"role": role, "content": content})
132133
metadata = {
133134
"default_user_id": user_id,
134-
"default_assistant_id": "assistant",
135+
"default_assistant_id": assistant_id,
135136
"time": int(time.time() * 1000),
136137
}
137138

veadk/runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ async def save_eval_set(self, session_id: str, eval_set_id: str = "default") ->
674674
return eval_set_path
675675

676676
async def save_session_to_long_term_memory(
677-
self, session_id: str, user_id: str = "", app_name: str = ""
677+
self, session_id: str, user_id: str = "", app_name: str = "", **kwargs
678678
) -> None:
679679
"""Save the specified session to long-term memory.
680680
@@ -730,5 +730,5 @@ async def save_session_to_long_term_memory(
730730
)
731731
return
732732

733-
await self.long_term_memory.add_session_to_memory(session)
733+
await self.long_term_memory.add_session_to_memory(session, kwargs=kwargs)
734734
logger.info(f"Add session `{session.id}` to long term memory.")

0 commit comments

Comments
 (0)