Skip to content

Commit 57c90dc

Browse files
chore(ltm): optimize memory structure (#235)
* chore(ltm): optimize memory structure
* rebase knowledgebase tool
* add file header
* fix test
* fix backend config
* fix create session bug
* fix memory index bug
* fix ltm init log
* update viking memory logs
* fix log bugs
1 parent c20ee49 commit 57c90dc

File tree

18 files changed

+345
-288
lines changed

18 files changed

+345
-288
lines changed

.gitleaks.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,4 @@ description = "Empty environment variables with KEY pattern"
7373
regex = '''os\.environ\[".*?KEY"\]\s*=\s*".+"'''
7474

7575
[allowlist]
76-
paths = ["requirements.txt"]
76+
paths = ["requirements.txt", "tests"]

docs/content/6.memory/3.long-term-memory.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ navigation:
77

88
## 使用方法
99

10-
VeADK 的长期记忆通常存储在数据库中,通过如下方式定义一个长期记忆:
10+
VeADK 的长期记忆通常存储在数据库中,你需要在初始化长期记忆时定义 `index` 来指定后端索引名称。通过如下方式定义一个长期记忆:
1111

1212
```python
1313
from veadk.memory.long_term_memory import LongTermMemory
1414

15-
# 由于长期记忆需要构建索引,因此你必须在初始化长期记忆时定义 `app_name` 以及 `user_id`
16-
long_term_memory = LongTermMemory(app_name="my_app_name", user_id="user_id")
15+
#
16+
long_term_memory = LongTermMemory(index="my_index")
1717
```
1818

1919
通过如下例子说明长期记忆:
@@ -32,7 +32,7 @@ user_id = "temp_user"
3232
teaching_session_id = "teaching_session"
3333
student_session_id = "student_session"
3434

35-
long_term_memory = LongTermMemory(backend="local", app_name=app_name, user_id=user_id)
35+
long_term_memory = LongTermMemory(backend="local", index=app_name)
3636

3737
agent = Agent(long_term_memory=long_term_memory)
3838

@@ -90,8 +90,4 @@ print(response)
9090
::field{name="app_name" type="string"}
9191
Agent 应用名称,用于多应用区分。默认空字符串。
9292
::
93-
94-
::field{name="user_id" type="string"}
95-
Agent 用户 ID,用于区分不同用户的长期记忆。默认空字符串。
96-
::
9793
::

docs/content/7.knowledgebase/1.knowledgebase.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ navigation:
99

1010
VeADK 基于 Llama-index 作为知识库的主要处理入口。开发者可上传文本、文件、目录,我们会为您进行自动切片。
1111

12-
创建知识库时,您必须要提供您的 `app_name`(将会用来自动构建索引名称),或指定一个知识库的索引
12+
创建知识库时,您必须要提供您的知识库后端索引名称 `index`,或指定 `app_name` 来作为索引名称
1313

1414
```python
1515
from veadk.knowledgebase import KnowledgeBase

tests/test_agent.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
from unittest.mock import Mock, patch
1617

1718
from google.adk.agents.llm_agent import LlmAgent
@@ -33,11 +34,10 @@
3334

3435

3536
def test_agent():
36-
knowledgebase = KnowledgeBase(
37-
index="test_index",
38-
backend="local",
39-
backend_config={"embedding_config": {"api_key": "test"}},
40-
)
37+
os.environ["MODEL_EMBEDDING_API_KEY"] = "mocked_api_key"
38+
39+
knowledgebase = KnowledgeBase(index="test_index", backend="local")
40+
4141
long_term_memory = LongTermMemory(backend="local")
4242
tracer = OpentelemetryTracer()
4343

@@ -69,8 +69,6 @@ def test_agent():
6969

7070
assert agent.knowledgebase == knowledgebase
7171
assert agent.knowledgebase.backend == "local" # type: ignore
72-
assert load_knowledgebase_tool.knowledgebase == agent.knowledgebase
73-
assert load_knowledgebase_tool.load_knowledgebase_tool in agent.tools
7472

7573
assert agent.long_term_memory.backend == "local" # type: ignore
7674
assert load_memory in agent.tools

tests/test_knowledgebase.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516

1617
import pytest
1718

@@ -21,11 +22,9 @@
2122

2223
@pytest.mark.asyncio
2324
async def test_knowledgebase():
25+
os.environ["MODEL_EMBEDDING_API_KEY"] = "mocked_api_key"
26+
2427
app_name = "kb_test_app"
25-
kb = KnowledgeBase(
26-
backend="local",
27-
app_name=app_name,
28-
backend_config={"embedding_config": {"api_key": "test"}},
29-
)
28+
kb = KnowledgeBase(backend="local", app_name=app_name)
3029

3130
assert isinstance(kb._backend, InMemoryKnowledgeBackend)

tests/test_long_term_memory.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,20 @@
1313
# limitations under the License.
1414

1515

16+
import os
17+
1618
import pytest
1719
from google.adk.tools import load_memory
1820

1921
from veadk.agent import Agent
2022
from veadk.memory.long_term_memory import LongTermMemory
2123

22-
app_name = "test_ltm"
23-
user_id = "test_user"
24-
2524

2625
@pytest.mark.asyncio
2726
async def test_long_term_memory():
28-
long_term_memory = LongTermMemory(
29-
backend="local",
30-
# app_name=app_name,
31-
# user_id=user_id,
32-
)
27+
os.environ["MODEL_EMBEDDING_API_KEY"] = "mocked_api_key"
28+
long_term_memory = LongTermMemory(backend="local")
29+
3330
agent = Agent(
3431
name="all_name",
3532
model_name="test_model_name",
@@ -43,7 +40,8 @@ async def test_long_term_memory():
4340

4441
assert load_memory in agent.tools, "load_memory tool not found in agent tools"
4542

46-
assert not agent.long_term_memory._backend
43+
assert agent.long_term_memory
44+
assert agent.long_term_memory._backend
4745

4846
# assert agent.long_term_memory._backend.index == build_long_term_memory_index(
4947
# app_name, user_id

tests/test_runner.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
16+
1517
from google.genai import types
1618

1719
from veadk.agent import Agent
1820
from veadk.memory.long_term_memory import LongTermMemory
1921
from veadk.memory.short_term_memory import ShortTermMemory
20-
from veadk.runner import Runner
21-
2222

2323
# Import the standalone function instead of accessing as class method
24-
from veadk.runner import _convert_messages
24+
from veadk.runner import Runner, _convert_messages
2525

2626

2727
def _test_convert_messages(runner):
@@ -67,6 +67,8 @@ def _test_convert_messages(runner):
6767

6868
def test_runner():
6969
"""Test Runner class initialization and core properties"""
70+
os.environ["MODEL_EMBEDDING_API_KEY"] = "mocked_api_key"
71+
7072
short_term_memory = ShortTermMemory()
7173
long_term_memory = LongTermMemory(backend="local")
7274
agent = Agent(

veadk/agent.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,14 @@ def model_post_init(self, __context: Any) -> None:
133133
)
134134

135135
if self.knowledgebase:
136-
from veadk.tools import load_knowledgebase_tool
136+
from veadk.tools.builtin_tools.load_knowledgebase import (
137+
LoadKnowledgebaseTool,
138+
)
137139

138-
load_knowledgebase_tool.knowledgebase = self.knowledgebase
139-
self.tools.append(load_knowledgebase_tool.load_knowledgebase_tool)
140+
load_knowledgebase_tool = LoadKnowledgebaseTool(
141+
knowledgebase=self.knowledgebase
142+
)
143+
self.tools.append(load_knowledgebase_tool)
140144

141145
if self.long_term_memory is not None:
142146
from google.adk.tools import load_memory

veadk/knowledgebase/knowledgebase.py

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import Any, Callable, Literal
15+
from __future__ import annotations
16+
17+
from typing import Any, Callable, Literal, Union
1618

1719
from pydantic import BaseModel, Field
18-
from typing_extensions import Union
1920

2021
from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend
2122
from veadk.knowledgebase.entry import KnowledgebaseEntry
@@ -54,11 +55,11 @@ def _get_backend_cls(backend: str) -> type[BaseKnowledgebaseBackend]:
5455
raise ValueError(f"Unsupported knowledgebase backend: {backend}")
5556

5657

57-
def build_knowledgebase_index(app_name: str):
58-
return f"veadk_kb_{app_name}"
58+
class KnowledgeBase(BaseModel):
59+
name: str = "user_knowledgebase"
5960

61+
description: str = "This knowledgebase stores some user-related information."
6062

61-
class KnowledgeBase(BaseModel):
6263
backend: Union[
6364
Literal["local", "opensearch", "viking", "redis"], BaseKnowledgebaseBackend
6465
] = "local"
@@ -73,9 +74,7 @@ class KnowledgeBase(BaseModel):
7374
"""Configuration for the backend"""
7475

7576
top_k: int = 10
76-
"""Number of top similar documents to retrieve during search.
77-
78-
Default is 10."""
77+
"""Number of top similar documents to retrieve during search"""
7978

8079
app_name: str = ""
8180

@@ -85,38 +84,27 @@ class KnowledgeBase(BaseModel):
8584
def model_post_init(self, __context: Any) -> None:
8685
if isinstance(self.backend, BaseKnowledgebaseBackend):
8786
self._backend = self.backend
87+
self.index = self._backend.index
8888
logger.info(
8989
f"Initialized knowledgebase with provided backend instance {self._backend.__class__.__name__}"
9090
)
9191
return
9292

93-
# must provide at least one of them
94-
if not self.app_name and not self.index:
95-
raise ValueError(
96-
"Either `app_name` or `index` must be provided one of them."
97-
)
98-
99-
# priority use index
100-
if self.app_name and self.index:
101-
logger.warning(
102-
"`app_name` and `index` are both provided, using `index` as the knowledgebase index name."
103-
)
93+
# Once the user defines a backend config, use it directly
94+
if self.backend_config:
95+
self._backend = _get_backend_cls(self.backend)(**self.backend_config)
96+
return
10497

105-
# generate index name if `index` not provided but `app_name` is provided
106-
if self.app_name and not self.index:
107-
self.index = build_knowledgebase_index(self.app_name)
108-
logger.info(
109-
f"Knowledgebase index is set to {self.index} (generated by the app_name: {self.app_name})."
110-
)
98+
self.index = self.index or self.app_name
99+
if not self.index:
100+
raise ValueError("Either `index` or `app_name` must be provided.")
111101

112102
logger.info(
113-
f"Initializing knowledgebase: backend={self.backend} top_k={self.top_k}"
114-
)
115-
self._backend = _get_backend_cls(self.backend)(
116-
index=self.index, **self.backend_config if self.backend_config else {}
103+
f"Initializing knowledgebase: backend={self.backend} index={self.index} top_k={self.top_k}"
117104
)
105+
self._backend = _get_backend_cls(self.backend)(index=self.index)
118106
logger.info(
119-
f"Initialized knowledgebase with backend {self._backend.__class__.__name__}"
107+
f"Initialized knowledgebase with backend {self.backend.__class__.__name__}"
120108
)
121109

122110
def add_from_directory(self, directory: str, **kwargs) -> bool:
@@ -133,8 +121,7 @@ def add_from_text(self, text: str | list[str], **kwargs) -> bool:
133121

134122
def search(self, query: str, top_k: int = 0, **kwargs) -> list[KnowledgebaseEntry]:
135123
"""Search knowledge from knowledgebase"""
136-
if top_k == 0:
137-
top_k = self.top_k
124+
top_k = top_k if top_k != 0 else self.top_k
138125

139126
_entries = self._backend.search(query=query, top_k=top_k, **kwargs)
140127

0 commit comments

Comments
 (0)