Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions tests/test_knowledgebase.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,3 @@ async def test_knowledgebase():
)
res = "".join(res_list)
assert key in res, f"Test failed for backend local res is {res}"
assert key in res, f"Test failed for backend local res is {res}"
assert key in res, f"Test failed for backend local res is {res}"
2 changes: 1 addition & 1 deletion tests/test_long_term_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ async def test_long_term_memory():
events=[
Event(
invocation_id="test_invocation_id",
author=agent.name,
author="user",
branch=None,
content=types.Content(
parts=[types.Part(text="My name is Alice.")],
Expand Down
279 changes: 279 additions & 0 deletions veadk/database/database_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import time
from typing import BinaryIO, TextIO
from veadk.database.base_database import BaseDatabase

from veadk.utils.logger import get_logger

logger = get_logger(__name__)


class KVDatabaseAdapter:
    """Adapter that exposes a Redis key-value client through ``add``/``query``.

    The ``index`` argument of both methods is passed to the client as the key.
    """

    def __init__(self, client):
        """Store the wrapped client; the import only serves the annotation."""
        from veadk.database.kv.redis_database import RedisDatabase

        self.client: RedisDatabase = client

    def add(self, data: list[str], index: str):
        """Add every text in ``data`` to the client under the key ``index``.

        Args:
            data: Texts to store; each one is sent with its own client call.
            index: Key under which the texts are stored.

        Raises:
            Exception: Any client error is logged and then re-raised.
        """
        logger.debug(f"Adding documents to Redis database: index={index}")

        try:
            for text in data:
                self.client.add(key=index, value=text)
            logger.debug(f"Added {len(data)} texts to Redis database: index={index}")
        except Exception as e:
            logger.error(
                f"Failed to add data to Redis database: index={index} error={e}"
            )
            raise

    def query(self, query: str, index: str, top_k: int = 0) -> list[str]:
        """Query the client under the key ``index``.

        Args:
            query: Query text forwarded to the client.
            index: Key to search under.
            top_k: Accepted for interface parity with the other adapters but
                ignored, as KV search only returns one result.

        Returns:
            Whatever the client's ``query`` call returns.

        Raises:
            Exception: Any client error is logged and then re-raised.
        """
        logger.debug(f"Querying Redis database: index={index} query={query}")

        try:
            return self.client.query(key=index, query=query)
        except Exception as e:
            logger.error(f"Failed to search from Redis: index={index} error={e}")
            raise


class RelationalDatabaseAdapter:
    """Adapter that stores texts as rows of a MySQL table named by ``index``.

    Each table has an auto-increment ``id``, a ``data`` text column and a
    ``created_at`` timestamp; queries return the most recently created rows.
    """

    def __init__(self, client):
        """Store the wrapped client; the import only serves the annotation."""
        from veadk.database.relational.mysql_database import MysqlDatabase

        self.client: MysqlDatabase = client

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a backtick-quoted MySQL identifier.

        Identifiers cannot be bound as statement parameters, so embedded
        backticks are doubled to keep the name from terminating the quoting
        and injecting SQL.
        """
        return "`" + str(name).replace("`", "``") + "`"

    def create_table(self, table_name: str):
        """Create the storage table ``table_name``.

        Args:
            table_name: Name of the table to create.
        """
        logger.debug(f"Creating table for SQL database: table_name={table_name}")

        sql = f"""
            CREATE TABLE {self._quote_identifier(table_name)} (
                `id` BIGINT AUTO_INCREMENT PRIMARY KEY,
                `data` TEXT NOT NULL,
                `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            ) ENGINE=InnoDB DEFAULT CHARSET={self.client.config.charset};
        """
        self.client.add(sql)

    def add(self, data: list[str], index: str):
        """Insert every text in ``data`` into table ``index``.

        The table is created first when it does not exist yet.

        Args:
            data: Texts to insert, one row per text.
            index: Name of the target table.
        """
        logger.debug(
            f"Adding documents to SQL database: table_name={index} data_len={len(data)}"
        )

        if not self.client.table_exists(index):
            logger.warning(f"Table {index} does not exist, creating a new table.")
            self.create_table(index)

        # The statement is identical for every row, so build it once; the row
        # value itself is always passed as a bound parameter.
        sql = f"""
            INSERT INTO {self._quote_identifier(index)} (`data`)
            VALUES (%s);
        """
        for _data in data:
            self.client.add(sql, params=(_data,))
        logger.debug(f"Added {len(data)} texts to table {index}.")

    def query(self, query: str, index: str, top_k: int) -> list[str]:
        """Return the ``top_k`` most recently created texts of table ``index``.

        Args:
            query: Only logged; the SQL statement does not filter on it.
            index: Name of the table to read.
            top_k: Maximum number of rows to return.

        Returns:
            The ``data`` values of the selected rows, newest first, or an
            empty list when the table does not exist.

        Raises:
            TypeError, ValueError: If ``top_k`` cannot be converted to ``int``.
        """
        logger.debug(
            f"Querying SQL database: table_name={index} query={query} top_k={top_k}"
        )

        if not self.client.table_exists(index):
            logger.warning(
                f"Querying SQL database, but table `{index}` does not exist, returning empty list."
            )
            return []

        # Coerce to int so that nothing but a number can reach the LIMIT clause.
        limit = int(top_k)
        sql = f"""
            SELECT `data` FROM {self._quote_identifier(index)} ORDER BY `created_at` DESC LIMIT {limit};
        """
        results = self.client.query(sql)

        return [item["data"] for item in results]


class VectorDatabaseAdapter:
    """Adapter that exposes an OpenSearch vector client through ``add``/``query``.

    The ``index`` argument is forwarded to the client as ``collection_name``.
    """

    def __init__(self, client):
        """Store the wrapped client; the import only serves the annotation."""
        from veadk.database.vector.opensearch_vector_database import (
            OpenSearchVectorDatabase,
        )

        self.client: OpenSearchVectorDatabase = client

    def _validate_index(self, index: str):
        """
        Verify whether the string conforms to the naming rules of index_name in OpenSearch.
        https://docs.opensearch.org/2.8/api-reference/index-apis/create-index/

        Accepted names are non-empty, do not start with ``_`` or ``-`` and
        consist only of lowercase letters, digits, ``_``, ``-`` and ``.``.

        Raises:
            ValueError: If ``index`` is not a string or breaks the rules above.
        """
        # ``fullmatch`` is required: ``match`` with ``$`` would accept a name
        # that ends in a newline. The character class already enforces
        # lowercase, so no ``islower()`` check is needed (it would wrongly
        # reject digit-only names such as "2024").
        if not (
            isinstance(index, str)
            and not index.startswith(("_", "-"))
            and re.fullmatch(r"[a-z0-9_\-.]+", index)
        ):
            raise ValueError(
                "The index name does not conform to the naming rules of OpenSearch"
            )

    def add(self, data: list[str], index: str):
        """Validate ``index`` and add ``data`` to that collection.

        Args:
            data: Texts to add.
            index: OpenSearch index name, forwarded as ``collection_name``.

        Raises:
            ValueError: If ``index`` is not a valid OpenSearch index name.
        """
        self._validate_index(index)

        logger.debug(
            f"Adding documents to vector database: index={index} data_len={len(data)}"
        )

        self.client.add(data, collection_name=index)

    def query(self, query: str, index: str, top_k: int) -> list[str]:
        """Query collection ``index`` and return the client's result.

        Args:
            query: Query text forwarded to the client.
            index: OpenSearch index name, forwarded as ``collection_name``.
            top_k: Number of results requested from the client.
        """
        logger.debug(
            f"Querying vector database: collection_name={index} query={query} top_k={top_k}"
        )

        return self.client.query(
            query=query,
            collection_name=index,
            top_k=top_k,
        )


class VikingDatabaseAdapter:
def __init__(self, client):
from veadk.database.viking.viking_database import VikingDatabase

self.client: VikingDatabase = client

def _validate_index(self, index: str):
"""
Only English letters, numbers, and underscores (_) are allowed.
It must start with an English letter and cannot be empty. Length requirement: [1, 128].
For details, please see: https://www.volcengine.com/docs/84313/1254542?lang=zh
"""
if not (
isinstance(index, str)
and 0 < len(index) <= 128
and re.fullmatch(r"^[a-zA-Z][a-zA-Z0-9_]*$", index)
):
raise ValueError(
"The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128."
)

def get_or_create_collection(self, collection_name: str):
if not self.client.collection_exists(collection_name):
logger.warning(
f"Collection {collection_name} does not exist, creating a new collection."
)
self.client.create_collection(collection_name)

# After creation, it is necessary to wait for a while.
count = 0
while not self.client.collection_exists(collection_name):
print("here")
time.sleep(1)
count += 1
if count > 60:
raise TimeoutError(
f"Collection {collection_name} not created after 50 seconds"
)

def add(
self, data: str | list[str] | TextIO | BinaryIO | bytes, index: str, **kwargs
):
self._validate_index(index)

logger.debug(f"Adding documents to Viking database: collection_name={index}")

self.get_or_create_collection(index)
self.client.add(data, collection_name=index, **kwargs)

def query(self, query: str, index: str, top_k: int) -> list[str]:
self._validate_index(index)

logger.debug(f"Querying Viking database: collection_name={index} query={query}")

if not self.client.collection_exists(index):
return []

return self.client.query(query, collection_name=index, top_k=top_k)


class VikingMemoryDatabaseAdapter:
    """Adapter that exposes a Viking memory client through ``add``/``query``.

    The ``index`` argument is forwarded to the client as ``collection_name``.
    """

    def __init__(self, client):
        """Store the wrapped client; the import only serves the annotation."""
        from veadk.database.viking.viking_memory_db import VikingMemoryDatabase

        self.client: VikingMemoryDatabase = client

    def _validate_index(self, index: str):
        """Raise ``ValueError`` unless ``index`` is a valid collection name.

        A valid name is a string of 1 to 128 characters that starts with an
        English letter and contains only letters, digits and underscores.
        """
        is_valid = (
            isinstance(index, str)
            and 1 <= len(index) <= 128
            and re.fullmatch(r"^[a-zA-Z][a-zA-Z0-9_]*$", index)
        )
        if is_valid:
            return

        raise ValueError(
            "The index name does not conform to the rules: it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128."
        )

    def add(self, data: list[str], index: str, **kwargs):
        """Validate ``index`` and add ``data`` to that collection.

        Args:
            data: Texts to add.
            index: Collection name.
            **kwargs: Extra options forwarded to the client's ``add``.

        Raises:
            ValueError: If ``index`` is not a valid collection name.
        """
        self._validate_index(index)

        logger.debug(
            f"Adding documents to Viking database memory: collection_name={index} data_len={len(data)}"
        )

        self.client.add(data, collection_name=index, **kwargs)

    def query(self, query: str, index: str, top_k: int, **kwargs):
        """Validate ``index`` and return the client's query result.

        Args:
            query: Query text forwarded to the client.
            index: Collection name.
            top_k: Number of results requested from the client.
            **kwargs: Extra options forwarded to the client's ``query``.

        Raises:
            ValueError: If ``index`` is not a valid collection name.
        """
        self._validate_index(index)

        logger.debug(
            f"Querying Viking database memory: collection_name={index} query={query} top_k={top_k}"
        )

        return self.client.query(query, collection_name=index, top_k=top_k, **kwargs)


class LocalDatabaseAdapter:
    """Adapter that forwards ``add``/``query`` calls to a local database client."""

    def __init__(self, client):
        """Store the wrapped client; the import only serves the annotation."""
        from veadk.database.local_database import LocalDataBase

        self.client: LocalDataBase = client

    def add(self, data: list[str], **kwargs):
        """Add ``data`` to the client.

        Args:
            data: Texts to add.
            **kwargs: Accepted for interface parity with the other adapters;
                not forwarded to the client.
        """
        self.client.add(data)

    def query(self, query: str, **kwargs):
        """Return the client's result for ``query``.

        Args:
            query: Query text forwarded to the client.
            **kwargs: Extra options forwarded to the client's ``query``.
        """
        matches = self.client.query(query, **kwargs)
        return matches


# Maps the class name of a database client to the adapter class wrapping it.
MAPPING = {
    "RedisDatabase": KVDatabaseAdapter,
    "MysqlDatabase": RelationalDatabaseAdapter,
    "LocalDataBase": LocalDatabaseAdapter,
    "VikingDatabase": VikingDatabaseAdapter,
    "OpenSearchVectorDatabase": VectorDatabaseAdapter,
    "VikingMemoryDatabase": VikingMemoryDatabaseAdapter,
}


def _resolve_adapter(database_client: BaseDatabase):
    """Instantiate the adapter registered in ``MAPPING`` for ``database_client``.

    The client's class hierarchy is searched in method-resolution order, so a
    subclass of a supported client resolves to its parent's adapter.

    Raises:
        KeyError: If neither the client's class nor any of its base classes
            is registered in ``MAPPING``.
    """
    for klass in type(database_client).__mro__:
        adapter_cls = MAPPING.get(klass.__name__)
        if adapter_cls is not None:
            return adapter_cls(client=database_client)

    raise KeyError(
        f"No database adapter registered for client type "
        f"{type(database_client).__name__!r}; supported types: {sorted(MAPPING)}"
    )


def get_knowledgebase_database_adapter(database_client: BaseDatabase):
    """Return the adapter wrapping ``database_client`` for knowledge-base use.

    Raises:
        KeyError: If the client type has no registered adapter.
    """
    return _resolve_adapter(database_client)


def get_long_term_memory_database_adapter(database_client: BaseDatabase):
    """Return the adapter wrapping ``database_client`` for long-term-memory use.

    Raises:
        KeyError: If the client type has no registered adapter.
    """
    return _resolve_adapter(database_client)
6 changes: 3 additions & 3 deletions veadk/database/database_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,12 @@ def create(backend: str, config=None) -> BaseDatabase:
return VikingDatabase() if config is None else VikingDatabase(config=config)

if backend == DatabaseBackend.VIKING_MEM:
from .viking.viking_memory_db import VikingDatabaseMemory
from .viking.viking_memory_db import VikingMemoryDatabase

return (
VikingDatabaseMemory()
VikingMemoryDatabase()
if config is None
else VikingDatabaseMemory(config=config)
else VikingMemoryDatabase(config=config)
)
else:
raise ValueError(f"Unsupported database type: {backend}")
30 changes: 24 additions & 6 deletions veadk/database/viking/viking_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
# Request paths of the knowledge-base HTTP API.
get_collections_path = "/api/knowledge/collection/info"
doc_add_path = "/api/knowledge/doc/add"
doc_info_path = "/api/knowledge/doc/info"
# NOTE(review): the name suggests a document-delete endpoint, but the path
# reads as a collection drop — confirm which operation is intended.
doc_del_path = "/api/collection/drop"


class VolcengineTOSConfig(BaseModel):
Expand Down Expand Up @@ -215,7 +216,12 @@ def _add_doc(self, collection_name: str, tos_url: str, doc_id: str, **kwargs: An

return doc_id

def add(self, data: str | list[str] | TextIO | BinaryIO | bytes, **kwargs: Any):
def add(
self,
data: str | list[str] | TextIO | BinaryIO | bytes,
collection_name: str,
**kwargs,
):
"""
Args:
data: str, file path or file stream: Both file or file.read() are acceptable.
Expand All @@ -226,8 +232,6 @@ def add(self, data: str | list[str] | TextIO | BinaryIO | bytes, **kwargs: Any):
"doc_id": "<doc_id>",
}
"""
collection_name = kwargs.get("collection_name")
assert collection_name is not None, "collection_name is required"

status, tos_url = self._upload_to_tos(data=data, **kwargs)
if status != 200:
Expand All @@ -243,9 +247,23 @@ def add(self, data: str | list[str] | TextIO | BinaryIO | bytes, **kwargs: Any):
}

def delete(self, **kwargs: Any):
    """Send a delete request to the knowledge-base service.

    Args:
        **kwargs: ``collection_name`` and ``resource_id`` are read (each is
            ``None`` when absent) and sent as the request payload.

    Returns:
        An empty dict when the service answers with ``code == 0``;
        otherwise ``{"error": <service message>}``.
    """
    collection_name = kwargs.get("collection_name")
    resource_id = kwargs.get("resource_id")
    request_param = {"collection_name": collection_name, "resource_id": resource_id}
    doc_del_req = prepare_request(
        method="POST", path=doc_del_path, config=self.config, data=request_param
    )
    rsp = requests.request(
        method=doc_del_req.method,
        url="http://{}{}".format(g_knowledge_base_domain, doc_del_req.path),
        headers=doc_del_req.headers,
        data=doc_del_req.body,
    )
    result = rsp.json()
    if result["code"] != 0:
        # Name the failing operation correctly; the message previously
        # blamed add_doc.
        logger.error(f"Error in delete: {result['message']}")
        return {"error": result["message"]}
    return {}

def query(self, query: str, **kwargs: Any) -> list[str]:
"""
Expand Down
Loading