Skip to content

Commit d3bef4f

Browse files
committed
fix: knowledgebase add method & list_chunks
1 parent a918d4d commit d3bef4f

File tree

3 files changed

+79
-12
lines changed

3 files changed

+79
-12
lines changed

veadk/database/database_adapter.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __init__(self, client):
2828

2929
self.client: RedisDatabase = client
3030

31-
def add(self, data: list[str], index: str):
31+
def add(self, data: list[str], index: str, **kwargs):
3232
logger.debug(f"Adding documents to Redis database: index={index}")
3333

3434
try:
@@ -78,7 +78,7 @@ def delete_doc(self, index: str, id: str) -> bool:
7878
)
7979
return False
8080

81-
def list_docs(self, index: str, offset: int = 0, limit: int = 100) -> list[dict]:
81+
def list_chunks(self, index: str, offset: int = 0, limit: int = 100) -> list[dict]:
8282
logger.debug(f"Listing documents from Redis database: index={index}")
8383
try:
8484
# Get all documents from Redis
@@ -111,7 +111,7 @@ def create_table(self, table_name: str):
111111
"""
112112
self.client.add(sql)
113113

114-
def add(self, data: list[str], index: str):
114+
def add(self, data: list[str], index: str, **kwargs):
115115
logger.debug(
116116
f"Adding documents to SQL database: table_name={index} data_len={len(data)}"
117117
)
@@ -203,7 +203,7 @@ def _validate_index(self, index: str):
203203
"The index name does not conform to the naming rules of OpenSearch"
204204
)
205205

206-
def add(self, data: list[str], index: str):
206+
def add(self, data: list[str], index: str, **kwargs):
207207
self._validate_index(index)
208208

209209
logger.debug(
@@ -247,7 +247,7 @@ def delete_doc(self, index: str, id: str) -> bool:
247247
)
248248
return False
249249

250-
def list_docs(self, index: str, offset: int = 0, limit: int = 1000) -> list[dict]:
250+
def list_chunks(self, index: str, offset: int = 0, limit: int = 1000) -> list[dict]:
251251
self._validate_index(index)
252252
logger.debug(f"Listing documents from vector database: index={index}")
253253
return self.client.list_docs(collection_name=index, offset=offset, limit=limit)
@@ -322,6 +322,13 @@ def delete_doc(self, index: str, id: str) -> bool:
322322
logger.debug(f"Deleting documents from vector database: index={index} id={id}")
323323
return self.client.delete_by_id(collection_name=index, id=id)
324324

325+
def list_chunks(self, index: str, offset: int, limit: int) -> list[dict]:
326+
self._validate_index(index)
327+
logger.debug(f"Listing documents from vector database: index={index}")
328+
return self.client.list_chunks(
329+
collection_name=index, offset=offset, limit=limit
330+
)
331+
325332
def list_docs(self, index: str, offset: int, limit: int) -> list[dict]:
326333
self._validate_index(index)
327334
logger.debug(f"Listing documents from vector database: index={index}")
@@ -371,7 +378,7 @@ def delete(self, index: str) -> bool:
371378
def delete_docs(self, index: str, ids: list[int]):
372379
raise NotImplementedError("VikingMemoryDatabase does not support delete_docs")
373380

374-
def list_docs(self, index: str):
381+
def list_chunks(self, index: str):
375382
raise NotImplementedError("VikingMemoryDatabase does not support list_docs")
376383

377384

@@ -393,7 +400,7 @@ def delete(self, index: str) -> bool:
393400
def delete_doc(self, index: str, id: str) -> bool:
394401
return self.client.delete_doc(id)
395402

396-
def list_docs(self, index: str, offset: int = 0, limit: int = 100) -> list[dict]:
403+
def list_chunks(self, index: str, offset: int = 0, limit: int = 100) -> list[dict]:
397404
return self.client.list_docs(offset=offset, limit=limit)
398405

399406

veadk/database/viking/viking_database.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def collection_exists(self, collection_name: str) -> bool:
403403
else:
404404
return False
405405

406-
def list_docs(
406+
def list_chunks(
407407
self, collection_name: str, offset: int = 0, limit: int = -1
408408
) -> list[dict]:
409409
request_params = {
@@ -431,6 +431,9 @@ def list_docs(
431431
logger.error(f"Error in list_docs: {result['message']}")
432432
raise ValueError(f"Error in list_docs: {result['message']}")
433433

434+
if not result["data"]["point_list"]:
435+
return []
436+
434437
data = [
435438
{
436439
"id": res["point_id"],

veadk/knowledgebase/knowledgebase.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
import io
15+
import os.path
1516
from typing import Any, BinaryIO, Literal, TextIO
1617

1718
from pydantic import BaseModel
1819

1920
from veadk.database.database_adapter import get_knowledgebase_database_adapter
2021
from veadk.database.database_factory import DatabaseFactory
22+
from veadk.utils.misc import formatted_timestamp
2123
from veadk.utils.logger import get_logger
2224

2325
logger = get_logger(__name__)
@@ -66,10 +68,65 @@ def add(
6668
)
6769

6870
index = build_knowledgebase_index(app_name)
69-
7071
logger.info(f"Adding documents to knowledgebase: index={index}")
7172

72-
self._adapter.add(data=data, index=index)
73+
if self.backend == "viking":
74+
# Case 1: Handling a single file path (str)
75+
if isinstance(data, str) and os.path.isfile(data):
76+
# Single file path: call client.add directly
77+
# Derive the file name (including its extension)
78+
if "file_name" not in kwargs or not kwargs["file_name"]:
79+
kwargs["file_name"] = os.path.basename(data)
80+
return self._adapter.add(data=data, index=index, **kwargs)
81+
# Case 2: Handling a list in which every item is an existing file path (list[str])
82+
if isinstance(data, list):
83+
if all(isinstance(item, str) for item in data):
84+
all_paths = all(os.path.isfile(item) for item in data)
85+
all_not_paths = all(not os.path.isfile(item) for item in data)
86+
if all_paths:
87+
if "file_name" not in kwargs or not kwargs["file_name"]:
88+
kwargs["file_name"] = [
89+
os.path.basename(item) for item in data
90+
]
91+
return self._adapter.add(data=data, index=index, **kwargs)
92+
elif (
93+
not all_not_paths
94+
): # Prevent the occurrence of non-existent paths
95+
# There is a mixture of paths and non-paths
96+
raise ValueError(
97+
"Mixed file paths and content strings in list are not allowed"
98+
)
99+
# Case 3: Handling raw text content, either a single string or a list of strings (str or list[str])
100+
if isinstance(data, str) or (
101+
isinstance(data, list) and all(isinstance(item, str) for item in data)
102+
):
103+
if "file_name" not in kwargs or not kwargs["file_name"]:
104+
if isinstance(data, str):
105+
kwargs["file_name"] = f"{formatted_timestamp()}.txt"
106+
else: # list[str] without file_names
107+
prefix_file_name = formatted_timestamp()
108+
kwargs["file_name"] = [
109+
f"{prefix_file_name}_{i}.txt" for i in range(len(data))
110+
]
111+
return self._adapter.add(data=data, index=index, **kwargs)
112+
113+
# Case 4: Handling binary data (bytes)
114+
if isinstance(data, bytes):
115+
# The caller must provide file_name for binary data
116+
if "file_name" not in kwargs:
117+
raise ValueError("file_name must be provided for binary data")
118+
return self._adapter.add(data=data, index=index, **kwargs)
119+
120+
# Case 5: Handling file objects TextIO or BinaryIO
121+
if isinstance(data, (io.TextIOWrapper, io.BufferedReader)):
122+
if not kwargs.get("file_name") and hasattr(data, "name"):
123+
kwargs["file_name"] = os.path.basename(data.name)
124+
return self._adapter.add(data=data, index=index, **kwargs)
125+
# Case 6: Unsupported data type
126+
raise TypeError(f"Unsupported data type: {type(data)}")
127+
128+
# Any backend other than viking: forward the data to the adapter unchanged
129+
return self._adapter.add(data=data, index=index, **kwargs)
73130

74131
def search(self, query: str, app_name: str, top_k: int | None = None) -> list[str]:
75132
top_k = self.top_k if top_k is None else top_k
@@ -93,4 +150,4 @@ def delete_doc(self, app_name: str, id: str) -> bool:
93150

94151
def list_docs(self, app_name: str, offset: int = 0, limit: int = 100) -> list[dict]:
95152
index = build_knowledgebase_index(app_name)
96-
return self._adapter.list_docs(index=index, offset=offset, limit=limit)
153+
return self._adapter.list_chunks(index=index, offset=offset, limit=limit)

0 commit comments

Comments
 (0)