Skip to content

Commit a39bdfb

Browse files
committed
chore(ltm): optimize memory structure
1 parent c1e5984 commit a39bdfb

File tree

10 files changed

+177
-244
lines changed

10 files changed

+177
-244
lines changed

docs/content/6.memory/3.long-term-memory.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ navigation:
77

88
## 使用方法
99

10-
VeADK 的长期记忆通常存储在数据库中,通过如下方式定义一个长期记忆:
10+
VeADK 的长期记忆通常存储在数据库中,你需要在初始化长期记忆时定义 `index` 来指定后端索引名称。通过如下方式定义一个长期记忆:
1111

1212
```python
1313
from veadk.memory.long_term_memory import LongTermMemory
1414

15-
# 由于长期记忆需要构建索引,因此你必须在初始化长期记忆时定义 `app_name` 以及 `user_id`
16-
long_term_memory = LongTermMemory(app_name="my_app_name", user_id="user_id")
15+
#
16+
long_term_memory = LongTermMemory(index="my_index")
1717
```
1818

1919
通过如下例子说明长期记忆:
@@ -32,7 +32,7 @@ user_id = "temp_user"
3232
teaching_session_id = "teaching_session"
3333
student_session_id = "student_session"
3434

35-
long_term_memory = LongTermMemory(backend="local", app_name=app_name, user_id=user_id)
35+
long_term_memory = LongTermMemory(backend="local", index=app_name)
3636

3737
agent = Agent(long_term_memory=long_term_memory)
3838

@@ -90,8 +90,4 @@ print(response)
9090
::field{name="app_name" type="string"}
9191
Agent 应用名称,用于多应用区分。默认空字符串。
9292
::
93-
94-
::field{name="user_id" type="string"}
95-
Agent 用户 ID,用于区分不同用户的长期记忆。默认空字符串。
96-
::
9793
::

docs/content/7.knowledgebase/1.knowledgebase.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ navigation:
99

1010
VeADK 基于 Llama-index 作为知识库的主要处理入口。开发者可上传文本、文件、目录,我们会为您进行自动切片。
1111

12-
创建知识库时,您必须要提供您的 `app_name`(将会用来自动构建索引名称),或指定一个知识库的索引
12+
创建知识库时,您必须要提供您的知识库后端索引名称 `index`,或指定 `app_name` 来作为索引名称
1313

1414
```python
1515
from veadk.knowledgebase import KnowledgeBase

veadk/knowledgebase/knowledgebase.py

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,7 @@ class KnowledgeBase(BaseModel):
7373
"""Configuration for the backend"""
7474

7575
top_k: int = 10
76-
"""Number of top similar documents to retrieve during search.
77-
78-
Default is 10."""
76+
"""Number of top similar documents to retrieve during search"""
7977

8078
app_name: str = ""
8179

@@ -90,33 +88,18 @@ def model_post_init(self, __context: Any) -> None:
9088
)
9189
return
9290

93-
# must provide at least one of them
94-
if not self.app_name and not self.index:
95-
raise ValueError(
96-
"Either `app_name` or `index` must be provided one of them."
97-
)
98-
99-
# priority use index
100-
if self.app_name and self.index:
101-
logger.warning(
102-
"`app_name` and `index` are both provided, using `index` as the knowledgebase index name."
103-
)
104-
105-
# generate index name if `index` not provided but `app_name` is provided
106-
if self.app_name and not self.index:
107-
self.index = build_knowledgebase_index(self.app_name)
108-
logger.info(
109-
f"Knowledgebase index is set to {self.index} (generated by the app_name: {self.app_name})."
110-
)
91+
self.index = self.index or self.app_name
92+
if not self.index:
93+
raise ValueError("Either `index` or `app_name` must be provided.")
11194

11295
logger.info(
113-
f"Initializing knowledgebase: backend={self.backend} top_k={self.top_k}"
96+
f"Initializing knowledgebase: backend={self.backend} index={self.index} top_k={self.top_k}"
11497
)
11598
self._backend = _get_backend_cls(self.backend)(
11699
index=self.index, **self.backend_config if self.backend_config else {}
117100
)
118101
logger.info(
119-
f"Initialized knowledgebase with backend {self._backend.__class__.__name__}"
102+
f"Initialized knowledgebase with backend {self.backend.__class__.__name__}"
120103
)
121104

122105
def add_from_directory(self, directory: str, **kwargs) -> bool:
@@ -133,8 +116,7 @@ def add_from_text(self, text: str | list[str], **kwargs) -> bool:
133116

134117
def search(self, query: str, top_k: int = 0, **kwargs) -> list[KnowledgebaseEntry]:
135118
"""Search knowledge from knowledgebase"""
136-
if top_k == 0:
137-
top_k = self.top_k
119+
top_k = top_k if top_k != 0 else self.top_k
138120

139121
_entries = self._backend.search(query=query, top_k=top_k, **kwargs)
140122

veadk/memory/long_term_memory.py

Lines changed: 32 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,6 @@ def _get_backend_cls(backend: str) -> type[BaseLongTermMemoryBackend]:
7272
raise ValueError(f"Unsupported long term memory backend: {backend}")
7373

7474

75-
def build_long_term_memory_index(app_name: str, user_id: str):
76-
return f"{app_name}_{user_id}"
77-
78-
7975
class LongTermMemory(BaseMemoryService, BaseModel):
8076
backend: Union[
8177
Literal["local", "opensearch", "redis", "viking", "viking_mem", "mem0"],
@@ -89,19 +85,14 @@ class LongTermMemory(BaseMemoryService, BaseModel):
8985
top_k: int = 5
9086
"""Number of top similar documents to retrieve during search."""
9187

88+
index: str = ""
89+
9290
app_name: str = ""
9391

9492
user_id: str = ""
93+
"""Deprecated attribute"""
9594

9695
def model_post_init(self, __context: Any) -> None:
97-
if self.backend == "viking_mem":
98-
logger.warning(
99-
"The `viking_mem` backend is deprecated, please use `viking` instead."
100-
)
101-
self.backend = "viking"
102-
103-
self._backend = None
104-
10596
# Once user define a backend instance, use it directly
10697
if isinstance(self.backend, BaseLongTermMemoryBackend):
10798
self._backend = self.backend
@@ -110,33 +101,23 @@ def model_post_init(self, __context: Any) -> None:
110101
)
111102
return
112103

113-
if self.backend_config:
114-
logger.warning(
115-
f"Initialized long term memory backend {self.backend} with config. We will ignore `app_name` and `user_id` if provided."
104+
# Check index
105+
self.index = self.index or self.app_name
106+
if not self.index:
107+
raise ValueError(
108+
"Attribute `index` or `app_name` must be provided one of both."
116109
)
117-
self._backend = _get_backend_cls(self.backend)(**self.backend_config)
118-
_index = self.backend_config.get("index", None)
119-
if _index:
120-
self._index = _index
121-
logger.info(f"Long term memory index set to {self._index}.")
122-
else:
123-
logger.warning(
124-
"Cannot find index via backend_config, please set `index` parameter."
125-
)
126-
return
127110

128-
if self.app_name and self.user_id:
129-
self._index = build_long_term_memory_index(
130-
app_name=self.app_name, user_id=self.user_id
131-
)
132-
logger.info(f"Long term memory index set to {self._index}.")
133-
self._backend = _get_backend_cls(self.backend)(
134-
index=self._index, **self.backend_config if self.backend_config else {}
135-
)
136-
else:
111+
# Forward compliance
112+
if self.backend == "viking_mem":
137113
logger.warning(
138-
"Neither `backend_instance`, `backend_config`, nor (`app_name`/`user_id`) is provided, the long term memory storage will initialize when adding a session."
114+
"The `viking_mem` backend is deprecated, change to `viking` instead."
139115
)
116+
self.backend = "viking"
117+
118+
self._backend = _get_backend_cls(self.backend)(
119+
index=self.index, **self.backend_config if self.backend_config else {}
120+
)
140121

141122
def _filter_and_convert_events(self, events: list[Event]) -> list[str]:
142123
final_events = []
@@ -164,75 +145,32 @@ async def add_session_to_memory(
164145
self,
165146
session: Session,
166147
):
167-
app_name = session.app_name
168148
user_id = session.user_id
169-
170-
if not self._backend and isinstance(self.backend, str):
171-
self._index = build_long_term_memory_index(app_name, user_id)
172-
self._backend = _get_backend_cls(self.backend)(
173-
index=self._index, **self.backend_config if self.backend_config else {}
174-
)
175-
logger.info(
176-
f"Initialize long term memory backend now, index is {self._index}"
177-
)
178-
179-
if not self._index and self._index != build_long_term_memory_index(
180-
app_name, user_id
181-
):
182-
logger.warning(
183-
f"The `app_name` or `user_id` is different from the initialized one, skip add session to memory. Initialized index: {self._index}, current built index: {build_long_term_memory_index(app_name, user_id)}"
184-
)
185-
return
186149
event_strings = self._filter_and_convert_events(session.events)
187150

188151
logger.info(
189-
f"Adding {len(event_strings)} events to long term memory: index={self._index}"
152+
f"Adding {len(event_strings)} events to long term memory: index={self.index}"
153+
)
154+
self._backend.save_memory(user_id=user_id, event_strings=event_strings)
155+
logger.info(
156+
f"Added {len(event_strings)} events to long term memory: index={self.index}, user_id={user_id}"
190157
)
191-
192-
if self._backend:
193-
self._backend.save_memory(event_strings=event_strings, user_id=user_id)
194-
195-
logger.info(
196-
f"Added {len(event_strings)} events to long term memory: index={self._index}"
197-
)
198-
else:
199-
logger.error(
200-
"Long term memory backend initialize failed, cannot add session to memory."
201-
)
202158

203159
@override
204-
async def search_memory(self, *, app_name: str, user_id: str, query: str):
205-
# prevent model invoke `load_memory` before add session to this memory
206-
if not self._backend and isinstance(self.backend, str):
207-
self._index = build_long_term_memory_index(app_name, user_id)
208-
self._backend = _get_backend_cls(self.backend)(
209-
index=self._index, **self.backend_config if self.backend_config else {}
210-
)
211-
logger.info(
212-
f"Initialize long term memory backend now, index is {self._index}"
213-
)
160+
async def search_memory(
161+
self, *, app_name: str, user_id: str, query: str
162+
) -> SearchMemoryResponse:
163+
logger.info(f"Search memory with query={query}")
214164

215-
if not self._index and self._index != build_long_term_memory_index(
216-
app_name, user_id
217-
):
218-
logger.warning(
219-
f"The `app_name` or `user_id` is different from the initialized one. Initialized index: {self._index}, current built index: {build_long_term_memory_index(app_name, user_id)}. Search memory return empty list."
165+
memory_chunks = []
166+
try:
167+
memory_chunks = self._backend.search_memory(
168+
query=query, top_k=self.top_k, user_id=user_id
220169
)
221-
return SearchMemoryResponse(memories=[])
222-
223-
if not self._backend:
170+
except Exception as e:
224171
logger.error(
225-
"Long term memory backend is not initialized, cannot search memory."
172+
f"Exception orrcus during memory search: {e}. Return empty memory chunks"
226173
)
227-
return SearchMemoryResponse(memories=[])
228-
229-
logger.info(
230-
f"Searching long term memory: query={query} index={self._index} top_k={self.top_k}"
231-
)
232-
233-
memory_chunks = self._backend.search_memory(
234-
query=query, top_k=self.top_k, user_id=user_id
235-
)
236174

237175
memory_events = []
238176
for memory in memory_chunks:
@@ -260,6 +198,6 @@ async def search_memory(self, *, app_name: str, user_id: str, query: str):
260198
)
261199

262200
logger.info(
263-
f"Return {len(memory_events)} memory events for query: {query} index={self._index}"
201+
f"Return {len(memory_events)} memory events for query: {query} index={self.index} user_id={user_id}"
264202
)
265203
return SearchMemoryResponse(memories=memory_events)

veadk/memory/long_term_memory_backends/base_backend.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@ def precheck_index_naming(self):
2525
"""Check the index name is valid or not"""
2626

2727
@abstractmethod
28-
def save_memory(self, event_strings: list[str], **kwargs) -> bool:
28+
def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
2929
"""Save memory to long term memory backend"""
3030

3131
@abstractmethod
32-
def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
32+
def search_memory(
33+
self, user_id: str, query: str, top_k: int, **kwargs
34+
) -> list[str]:
3335
"""Retrieve memory from long term memory backend"""

veadk/memory/long_term_memory_backends/in_memory_backend.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,6 @@ class InMemoryLTMBackend(BaseLongTermMemoryBackend):
2929
embedding_config: EmbeddingModelConfig = Field(default_factory=EmbeddingModelConfig)
3030
"""Embedding model configs"""
3131

32-
def precheck_index_naming(self):
33-
# no checking
34-
pass
35-
3632
def model_post_init(self, __context: Any) -> None:
3733
self._embed_model = OpenAILikeEmbedding(
3834
model_name=self.embedding_config.name,
@@ -41,16 +37,22 @@ def model_post_init(self, __context: Any) -> None:
4137
)
4238
self._vector_index = VectorStoreIndex([], embed_model=self._embed_model)
4339

40+
def precheck_index_naming(self):
41+
# no checking
42+
pass
43+
4444
@override
45-
def save_memory(self, event_strings: list[str], **kwargs) -> bool:
45+
def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
4646
for event_string in event_strings:
4747
document = Document(text=event_string)
4848
nodes = self._split_documents([document])
4949
self._vector_index.insert_nodes(nodes)
5050
return True
5151

5252
@override
53-
def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
53+
def search_memory(
54+
self, user_id: str, query: str, top_k: int, **kwargs
55+
) -> list[str]:
5456
_retriever = self._vector_index.as_retriever(similarity_top_k=top_k)
5557
retrieved_nodes = _retriever.retrieve(query)
5658
return [node.text for node in retrieved_nodes]

veadk/memory/long_term_memory_backends/mem0_backend.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,11 @@
1313
# limitations under the License.
1414

1515
from typing import Any
16-
from typing_extensions import override
16+
1717
from pydantic import Field
18+
from typing_extensions import override
1819

1920
from veadk.configs.database_configs import Mem0Config
20-
21-
2221
from veadk.memory.long_term_memory_backends.base_backend import (
2322
BaseLongTermMemoryBackend,
2423
)
@@ -66,7 +65,9 @@ def precheck_index_naming(self):
6665
pass
6766

6867
@override
69-
def save_memory(self, event_strings: list[str], **kwargs) -> bool:
68+
def save_memory(
69+
self, event_strings: list[str], user_id: str = "default_user", **kwargs
70+
) -> bool:
7071
"""Save memory to Mem0
7172
7273
Args:
@@ -76,8 +77,6 @@ def save_memory(self, event_strings: list[str], **kwargs) -> bool:
7677
Returns:
7778
bool: True if saved successfully, False otherwise
7879
"""
79-
user_id = kwargs.get("user_id", "default_user")
80-
8180
try:
8281
logger.info(
8382
f"Saving {len(event_strings)} events to Mem0 for user: {user_id}"
@@ -100,7 +99,9 @@ def save_memory(self, event_strings: list[str], **kwargs) -> bool:
10099
return False
101100

102101
@override
103-
def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
102+
def search_memory(
103+
self, query: str, top_k: int, user_id: str = "default_user", **kwargs
104+
) -> list[str]:
104105
"""Search memory from Mem0
105106
106107
Args:
@@ -111,7 +112,6 @@ def search_memory(self, query: str, top_k: int, **kwargs) -> list[str]:
111112
Returns:
112113
list[str]: List of memory strings
113114
"""
114-
user_id = kwargs.get("user_id", "default_user")
115115

116116
try:
117117
logger.info(

0 commit comments

Comments
 (0)