Skip to content

Commit dd46e44

Browse files
committed
fix: fix long-term-mem and knowledge base
1 parent bfb38f3 commit dd46e44

18 files changed

+266
-79
lines changed

config.yaml.full

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ model:
1717
embedding:
1818
name: doubao-embedding-text-240715
1919
dim: 2560
20-
api_base: https://ark.cn-beijing.volces.com/api/v3/embeddings
20+
api_base: https://ark.cn-beijing.volces.com/api/v3/
2121
api_key:
2222
video:
2323
name: doubao-seedance-1-0-pro-250528

docs/docs/agent.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,12 +129,12 @@ root_agent:
129129
backend: local
130130
knowledgebase:
131131
backend: opensearch
132+
index: test
132133
tools:
133134
- module: demo_tool # tool 所在的模块
134135
func: greeting # tool 的函数名称
135136
- module: tools.tool
136137
func: count
137-
sub_agents:
138138
sub_agents:
139139
- ${sub_agent_1}
140140

docs/docs/installation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ model:
7070
embedding:
7171
name: doubao-embedding-text-240715
7272
dim: 2560
73-
api_base: https://ark.cn-beijing.volces.com/api/v3/embeddings
73+
api_base: https://ark.cn-beijing.volces.com/api/v3/
7474
api_key:
7575

7676
volcengine:

pyproject.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,19 @@ dependencies = [
2828
"cookiecutter>=2.6.0", # For cloud deploy # For OpenSearch database
2929
"opensearch-py==2.8.0",
3030
"omegaconf>=2.3.0", # For agent builder
31+
"llama-index>=0.14.0",
32+
"llama-index-embeddings-openai-like>=0.2.2",
33+
"llama-index-llms-openai-like>=0.5.1",
34+
"llama-index-vector-stores-opensearch>=0.6.1",
35+
"llama-index-vector-stores-redis>=0.6.1",
3136
]
3237

3338
[project.scripts]
3439
veadk = "veadk.cli.cli:veadk"
3540

3641
[project.optional-dependencies]
3742
database = [
38-
"redis>=6.2.0", # For Redis database
43+
"redis>=5.0", # For Redis database
3944
"pymysql>=1.1.1", # For MySQL database
4045
"volcengine>=1.0.193", # For Viking DB
4146
"tos>=2.8.4", # For TOS storage and Viking DB

tests/test_agent.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434

3535
def test_agent():
36-
knowledgebase = KnowledgeBase()
36+
knowledgebase = KnowledgeBase(index="test_index", backend="local")
3737
long_term_memory = LongTermMemory(backend="local")
3838
tracer = OpentelemetryTracer()
3939

tests/test_knowledgebase.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@
2121
async def test_knowledgebase():
2222
app_name = "kb_test_app"
2323
key = "Supercalifragilisticexpialidocious"
24-
kb = KnowledgeBase(backend="local")
24+
kb = KnowledgeBase(backend="local", app_name=app_name)
2525
# Attempt to delete any existing data for the app_name before adding new data
26-
kb.add(
27-
data=[f"knowledgebase_id is {key}"],
28-
app_name=app_name,
26+
kb.add_from_text(
27+
text=[f"knowledgebase_id is {key}"],
2928
)
3029
res_list = kb.search(
3130
query="knowledgebase_id",
32-
app_name=app_name,
31+
top_k=1,
3332
)
3433
res = "".join(res_list)
3534
assert key in res, f"Test failed for backend local res is {res}"

tests/test_long_term_memory.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727

2828
@pytest.mark.asyncio
2929
async def test_long_term_memory():
30-
long_term_memory = LongTermMemory(backend="local")
30+
long_term_memory = LongTermMemory(
31+
backend="local", app_name=app_name, user_id=user_id
32+
)
3133
agent = Agent(
3234
name="all_name",
3335
model_name="test_model_name",

veadk/knowledgebase/backends/base_backend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ class BaseKnowledgebaseBackend(ABC, BaseModel):
2121
index: str
2222
"""Index or collection name of the vector storage."""
2323

24+
@abstractmethod
25+
def precheck_index_naming(self):
26+
"""Check the index name is valid or not"""
27+
2428
@abstractmethod
2529
def add_from_directory(self, directory: str, **kwargs) -> bool:
2630
"""Add knowledge from file path to knowledgebase"""

veadk/knowledgebase/backends/in_memory_backend.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,17 @@ class InMemoryKnowledgeBackend(BaseKnowledgebaseBackend):
2727
embedding_config: EmbeddingModelConfig = Field(default_factory=EmbeddingModelConfig)
2828
"""Embedding model configs"""
2929

30+
def precheck_index_naming(self):
31+
# no checking
32+
pass
33+
3034
def model_post_init(self, __context: Any) -> None:
3135
self._embed_model = OpenAILikeEmbedding(
3236
model_name=self.embedding_config.name,
3337
api_key=self.embedding_config.api_key,
3438
api_base=self.embedding_config.api_base,
3539
)
3640
self._vector_index = VectorStoreIndex([], embed_model=self._embed_model)
37-
self._retriever = self._vector_index.as_retriever()
3841

3942
@override
4043
def add_from_directory(self, directory: str) -> bool:
@@ -62,7 +65,8 @@ def add_from_text(self, text: str | list[str]) -> bool:
6265

6366
@override
6467
def search(self, query: str, top_k: int = 5) -> list[str]:
65-
retrieved_nodes = self._retriever.retrieve(query, top_k=top_k)
68+
_retriever = self._vector_index.as_retriever(similarity_top_k=top_k)
69+
retrieved_nodes = _retriever.retrieve(query)
6670
return [node.text for node in retrieved_nodes]
6771

6872
def _split_documents(self, documents: list[Document]) -> list[BaseNode]:

veadk/knowledgebase/backends/opensearch_backend.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import re
1415

1516
from llama_index.core import (
1617
Document,
@@ -40,7 +41,19 @@ class OpensearchKnowledgeBackend(BaseKnowledgebaseBackend):
4041
embedding_config: EmbeddingModelConfig = Field(default_factory=EmbeddingModelConfig)
4142
"""Embedding model configs"""
4243

44+
def precheck_index_naming(self):
45+
if not (
46+
isinstance(self.index, str)
47+
and not self.index.startswith(("_", "-"))
48+
and self.index.islower()
49+
and re.match(r"^[a-z0-9_\-.]+$", self.index)
50+
):
51+
raise ValueError(
52+
"The index name does not conform to the naming rules of OpenSearch"
53+
)
54+
4355
def model_post_init(self, __context: Any) -> None:
56+
self.precheck_index_naming()
4457
self._opensearch_client = OpensearchVectorClient(
4558
endpoint=self.opensearch_config.host,
4659
port=self.opensearch_config.port,
@@ -71,7 +84,6 @@ def model_post_init(self, __context: Any) -> None:
7184
storage_context=self._storage_context,
7285
embed_model=self._embed_model,
7386
)
74-
self._retriever = self._vector_index.as_retriever()
7587

7688
@override
7789
def add_from_directory(self, directory: str) -> bool:
@@ -99,12 +111,8 @@ def add_from_text(self, text: str | list[str]) -> bool:
99111

100112
@override
101113
def search(self, query: str, top_k: int = 5) -> list[str]:
102-
_original_top_k = self._retriever.similarity_top_k # type: ignore
103-
self._retriever.similarity_top_k = top_k # type: ignore
104-
105-
retrieved_nodes = self._retriever.retrieve(query)
106-
107-
self._retriever.similarity_top_k = _original_top_k # type: ignore
114+
_retriever = self._vector_index.as_retriever(similarity_top_k=top_k)
115+
retrieved_nodes = _retriever.retrieve(query)
108116
return [node.text for node in retrieved_nodes]
109117

110118
def _split_documents(self, documents: list[Document]) -> list[BaseNode]:

0 commit comments

Comments
 (0)