Skip to content

Commit 05a0b57

Browse files
refine(kb): reconstruct knowledgebase (#165)
* refine(kb): reconstruct knowledgebase
* add ltm
* fix ltm checks
* update short term memory
* add callback mount
* add db auth
* remove database dir
* modify default local path
* fix bugs
* fix: fix long-term-mem and knowledge base
* fix: short-term-mem test
* fix: short-term-mem postgresql_backend
* fix: add list_docs and list_chunks part
* fix: typechecking and extensions
* fix import issues
* fix tests
* add configs
* fix: rebase agent.md
* fix: rebase agent.md sub_agents
* fix: config

---------

Co-authored-by: hanzhi.421 <[email protected]>
1 parent db3f460 commit 05a0b57

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+2090
-2831
lines changed

config.yaml.full

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ model:
1717
embedding:
1818
name: doubao-embedding-text-240715
1919
dim: 2560
20-
api_base: https://ark.cn-beijing.volces.com/api/v3/embeddings
20+
api_base: https://ark.cn-beijing.volces.com/api/v3/
2121
api_key:
2222
video:
2323
name: doubao-seedance-1-0-pro-250528

pyproject.toml

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,33 @@ dependencies = [
2222
"opentelemetry-instrumentation-logging>=0.56b0",
2323
"wrapt>=1.17.2", # For patching built-in functions
2424
"openai<1.100", # For fix https://github.com/BerriAI/litellm/issues/13710
25-
"volcengine-python-sdk==4.0.3", # For Volcengine API
25+
"volcengine-python-sdk>=4.0.3", # For Volcengine API
26+
"volcengine>=1.0.193", # For Volcengine sign
2627
"agent-pilot-sdk>=0.0.9", # Prompt optimization by Volcengine AgentPilot/PromptPilot toolkits
2728
"fastmcp>=2.11.3", # For running MCP
28-
"cookiecutter>=2.6.0", # For cloud deploy # For OpenSearch database
29-
"opensearch-py==2.8.0",
29+
"cookiecutter>=2.6.0", # For cloud deploy
3030
"omegaconf>=2.3.0", # For agent builder
31+
"llama-index>=0.14.0",
32+
"llama-index-embeddings-openai-like>=0.2.2",
33+
"llama-index-llms-openai-like>=0.5.1",
34+
"llama-index-vector-stores-opensearch>=0.6.1",
35+
"psycopg2-binary>=2.9.10", # For PostgreSQL database (short term memory)
36+
"pymysql>=1.1.1", # For MySQL database (short term memory)
37+
"opensearch-py==2.8.0",
3138
]
3239

3340
[project.scripts]
3441
veadk = "veadk.cli.cli:veadk"
3542

3643
[project.optional-dependencies]
44+
extensions = [
45+
"redis>=5.0", # For Redis database
46+
"tos>=2.8.4", # For TOS storage and Viking DB
47+
"llama-index-vector-stores-redis>=0.6.1",
48+
"mcp-server-vikingdb-memory",
49+
]
3750
database = [
38-
"redis>=6.2.0", # For Redis database
51+
"redis>=5.0", # For Redis database
3952
"pymysql>=1.1.1", # For MySQL database
4053
"volcengine>=1.0.193", # For Viking DB
4154
"tos>=2.8.4", # For TOS storage and Viking DB
@@ -78,3 +91,6 @@ exclude = [
7891
"veadk/integrations/ve_faas/template/*",
7992
"veadk/integrations/ve_faas/web_template/*"
8093
]
94+
95+
[tool.uv.sources]
96+
mcp-server-vikingdb-memory = { git = "https://github.com/volcengine/mcp-server", subdirectory = "server/mcp_server_vikingdb_memory" }

tests/test_agent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@
3333

3434

3535
def test_agent():
36-
knowledgebase = KnowledgeBase()
36+
knowledgebase = KnowledgeBase(
37+
index="test_index",
38+
backend="local",
39+
backend_config={"embedding_config": {"api_key": "test"}},
40+
)
3741
long_term_memory = LongTermMemory(backend="local")
3842
tracer = OpentelemetryTracer()
3943

tests/test_knowledgebase.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
1516
import pytest
1617

1718
from veadk.knowledgebase import KnowledgeBase
19+
from veadk.knowledgebase.backends.in_memory_backend import InMemoryKnowledgeBackend
1820

1921

2022
@pytest.mark.asyncio
2123
async def test_knowledgebase():
2224
app_name = "kb_test_app"
23-
key = "Supercalifragilisticexpialidocious"
24-
kb = KnowledgeBase(backend="local")
25-
# Attempt to delete any existing data for the app_name before adding new data
26-
kb.add(
27-
data=[f"knowledgebase_id is {key}"],
28-
app_name=app_name,
29-
)
30-
res_list = kb.search(
31-
query="knowledgebase_id",
25+
kb = KnowledgeBase(
26+
backend="local",
3227
app_name=app_name,
28+
backend_config={"embedding_config": {"api_key": "test"}},
3329
)
34-
res = "".join(res_list)
35-
assert key in res, f"Test failed for backend local res is {res}"
30+
31+
assert isinstance(kb._backend, InMemoryKnowledgeBackend)

tests/test_long_term_memory.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
1516
import pytest
16-
from google.adk.events import Event
17-
from google.adk.sessions import Session
1817
from google.adk.tools import load_memory
19-
from google.genai import types
2018

2119
from veadk.agent import Agent
2220
from veadk.memory.long_term_memory import LongTermMemory
@@ -27,7 +25,11 @@
2725

2826
@pytest.mark.asyncio
2927
async def test_long_term_memory():
30-
long_term_memory = LongTermMemory(backend="local")
28+
long_term_memory = LongTermMemory(
29+
backend="local",
30+
# app_name=app_name,
31+
# user_id=user_id,
32+
)
3133
agent = Agent(
3234
name="all_name",
3335
model_name="test_model_name",
@@ -41,31 +43,8 @@ async def test_long_term_memory():
4143

4244
assert load_memory in agent.tools, "load_memory tool not found in agent tools"
4345

44-
# mock session
45-
session = Session(
46-
id="test_session_id",
47-
app_name=app_name,
48-
user_id=user_id,
49-
events=[
50-
Event(
51-
invocation_id="test_invocation_id",
52-
author="user",
53-
branch=None,
54-
content=types.Content(
55-
parts=[types.Part(text="My name is Alice.")],
56-
role="user",
57-
),
58-
)
59-
],
60-
)
61-
62-
await long_term_memory.add_session_to_memory(session)
46+
assert not agent.long_term_memory._backend
6347

64-
memories = await long_term_memory.search_memory(
65-
app_name=app_name,
66-
user_id=user_id,
67-
query="Alice",
68-
)
69-
assert (
70-
"Alice" in memories.model_dump()["memories"][0]["content"]["parts"][0]["text"]
71-
)
48+
# assert agent.long_term_memory._backend.index == build_long_term_memory_index(
49+
# app_name, user_id
50+
# )

tests/test_short_term_memory.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import asyncio
1616
import os
1717

18-
import veadk.memory.short_term_memory
1918
from veadk.memory.short_term_memory import ShortTermMemory
2019
from veadk.utils.misc import formatted_timestamp
2120

@@ -35,11 +34,11 @@ def test_short_term_memory():
3534
)
3635
assert session is not None
3736

38-
# database - local
39-
veadk.memory.short_term_memory.DEFAULT_LOCAL_DATABASE_PATH = (
40-
f"/tmp/tmp_for_test_{formatted_timestamp()}.db"
37+
# sqlite
38+
memory = ShortTermMemory(
39+
backend="sqlite",
40+
local_database_path=f"/tmp/tmp_for_test_{formatted_timestamp()}.db",
4141
)
42-
memory = ShortTermMemory(backend="database")
4342
asyncio.run(
4443
memory.session_service.create_session(
4544
app_name="app", user_id="user", session_id="session"
@@ -51,5 +50,5 @@ def test_short_term_memory():
5150
)
5251
)
5352
assert session is not None
54-
assert os.path.exists(veadk.memory.short_term_memory.DEFAULT_LOCAL_DATABASE_PATH)
55-
os.remove(veadk.memory.short_term_memory.DEFAULT_LOCAL_DATABASE_PATH)
53+
assert os.path.exists(memory.local_database_path)
54+
os.remove(memory.local_database_path)

veadk/agent.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,6 @@ async def run(
196196
collect_runtime_data: bool = False,
197197
eval_set_id: str = "",
198198
save_session_to_memory: bool = False,
199-
enable_memory_optimization: bool = False,
200199
):
201200
"""Running the agent. The runner and session service will be created automatically.
202201
@@ -226,7 +225,6 @@ async def run(
226225
# memory service
227226
short_term_memory = ShortTermMemory(
228227
backend="database" if load_history_sessions_from_db else "local",
229-
enable_memory_optimization=enable_memory_optimization,
230228
db_url=db_url,
231229
)
232230
session_service = short_term_memory.session_service
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
16+
#
17+
# Licensed under the Apache License, Version 2.0 (the "License");
18+
# you may not use this file except in compliance with the License.
19+
# You may obtain a copy of the License at
20+
#
21+
# http://www.apache.org/licenses/LICENSE-2.0
22+
#
23+
# Unless required by applicable law or agreed to in writing, software
24+
# distributed under the License is distributed on an "AS IS" BASIS,
25+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26+
# See the License for the specific language governing permissions and
27+
# limitations under the License.
28+
29+
import os
30+
31+
from typing_extensions import override
32+
33+
from veadk.auth.veauth.base_veauth import BaseVeAuth
34+
from veadk.utils.logger import get_logger
35+
36+
# from veadk.utils.volcengine_sign import ve_request
37+
38+
logger = get_logger(__name__)
39+
40+
41+
class OpensearchVeAuth(BaseVeAuth):
42+
def __init__(
43+
self,
44+
access_key: str = os.getenv("VOLCENGINE_ACCESS_KEY", ""),
45+
secret_key: str = os.getenv("VOLCENGINE_SECRET_KEY", ""),
46+
) -> None:
47+
super().__init__(access_key, secret_key)
48+
49+
self._token: str = ""
50+
51+
@override
52+
def _fetch_token(self) -> None:
53+
logger.info("Fetching Opensearch STS token...")
54+
55+
# res = ve_request(
56+
# request_body={},
57+
# action="GetOrCreatePromptPilotAPIKeys",
58+
# ak=self.access_key,
59+
# sk=self.secret_key,
60+
# service="ark",
61+
# version="2024-01-01",
62+
# region="cn-beijing",
63+
# host="open.volcengineapi.com",
64+
# )
65+
# try:
66+
# self._token = res["Result"]["APIKeys"][0]["APIKey"]
67+
# except KeyError:
68+
# raise ValueError(f"Failed to get Prompt Pilot token: {res}")
69+
70+
@property
71+
def token(self) -> str:
72+
if self._token:
73+
return self._token
74+
self._fetch_token()
75+
return self._token
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
16+
#
17+
# Licensed under the Apache License, Version 2.0 (the "License");
18+
# you may not use this file except in compliance with the License.
19+
# You may obtain a copy of the License at
20+
#
21+
# http://www.apache.org/licenses/LICENSE-2.0
22+
#
23+
# Unless required by applicable law or agreed to in writing, software
24+
# distributed under the License is distributed on an "AS IS" BASIS,
25+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26+
# See the License for the specific language governing permissions and
27+
# limitations under the License.
28+
29+
import os
30+
31+
from typing_extensions import override
32+
33+
from veadk.auth.veauth.base_veauth import BaseVeAuth
34+
from veadk.utils.logger import get_logger
35+
36+
# from veadk.utils.volcengine_sign import ve_request
37+
38+
logger = get_logger(__name__)
39+
40+
41+
class PostgreSqlVeAuth(BaseVeAuth):
42+
def __init__(
43+
self,
44+
access_key: str = os.getenv("VOLCENGINE_ACCESS_KEY", ""),
45+
secret_key: str = os.getenv("VOLCENGINE_SECRET_KEY", ""),
46+
) -> None:
47+
super().__init__(access_key, secret_key)
48+
49+
self._token: str = ""
50+
51+
@override
52+
def _fetch_token(self) -> None:
53+
logger.info("Fetching PostgreSQL STS token...")
54+
55+
# res = ve_request(
56+
# request_body={},
57+
# action="GetOrCreatePromptPilotAPIKeys",
58+
# ak=self.access_key,
59+
# sk=self.secret_key,
60+
# service="ark",
61+
# version="2024-01-01",
62+
# region="cn-beijing",
63+
# host="open.volcengineapi.com",
64+
# )
65+
# try:
66+
# self._token = res["Result"]["APIKeys"][0]["APIKey"]
67+
# except KeyError:
68+
# raise ValueError(f"Failed to get Prompt Pilot token: {res}")
69+
70+
@property
71+
def token(self) -> str:
72+
if self._token:
73+
return self._token
74+
self._fetch_token()
75+
return self._token

0 commit comments

Comments
 (0)