Skip to content

Commit 76e99ce

Browse files
committed
fix: knowledgebase add method
1 parent d3bef4f commit 76e99ce

File tree

2 files changed

+180
-35
lines changed

2 files changed

+180
-35
lines changed

veadk/database/viking/viking_database.py

Lines changed: 165 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,25 @@ def _upload_to_tos(
136136
self,
137137
data: str | list[str] | TextIO | BinaryIO | bytes,
138138
**kwargs: Any,
139-
):
140-
file_ext = kwargs.get(
141-
"file_ext", ".pdf"
142-
) # when bytes data, file_ext is required
139+
) -> tuple[int, str]:
140+
"""
141+
Upload data to TOS (Torch Object Storage, Volcengine's object storage service).
143142
143+
Args:
144+
data: The data to be uploaded. Can be one of the following types:
145+
- str: File path or string data
146+
- list[str]: List of strings
147+
- TextIO: File object (text)
148+
- BinaryIO: File object (binary)
149+
- bytes: Binary data
150+
**kwargs: Additional keyword arguments.
151+
- file_name (str): The file name (including suffix).
152+
153+
Returns:
154+
tuple: A tuple containing the status code and TOS URL.
155+
- status_code (int): HTTP status code
156+
- tos_url (str): The URL of the uploaded file in TOS
157+
"""
144158
ak = self.config.volcengine_ak
145159
sk = self.config.volcengine_sk
146160

@@ -151,21 +165,31 @@ def _upload_to_tos(
151165

152166
client = tos.TosClientV2(ak, sk, tos_endpoint, tos_region, max_connections=1024)
153167

168+
# Read file_name (including the extension) from kwargs; callers are expected to supply it, but nothing here validates it, so a missing value produces an object key ending in "None"
169+
file_names = kwargs.get("file_name")
170+
154171
if isinstance(data, str) and os.path.isfile(data): # Process file path
155-
file_ext = os.path.splitext(data)[1]
156-
new_key = f"{tos_key}/{str(uuid.uuid4())}{file_ext}"
172+
# Use provided file_name which includes the extension
173+
new_key = f"{tos_key}/{file_names}"
157174
with open(data, "rb") as f:
158175
upload_data = f.read()
159176

177+
elif (
178+
isinstance(data, list)
179+
and all(isinstance(item, str) for item in data)
180+
and all(os.path.isfile(item) for item in data)
181+
):
182+
# Process list of file paths - this should be handled at a higher level
183+
raise ValueError(
184+
"Uploading multiple files through a list of file paths is not supported in _upload_to_tos directly. Please call this function for each file individually."
185+
)
186+
160187
elif isinstance(
161188
data,
162189
(io.TextIOWrapper, io.BufferedReader), # file type: TextIO | BinaryIO
163190
): # Process file stream
164-
# Try to get the file extension from the file name, and use the default value if there is none
165-
file_ext = ".unknown"
166-
if hasattr(data, "name"):
167-
_, file_ext = os.path.splitext(data.name)
168-
new_key = f"{tos_key}/{str(uuid.uuid4())}{file_ext}"
191+
# Use provided file_name which includes the extension
192+
new_key = f"{tos_key}/{file_names}"
169193
if isinstance(data, TextIO):
170194
# Encode the text stream content into bytes
171195
upload_data = data.read().encode("utf-8")
@@ -174,16 +198,19 @@ def _upload_to_tos(
174198
upload_data = data.read()
175199

176200
elif isinstance(data, str): # Process ordinary strings
177-
new_key = f"{tos_key}/{str(uuid.uuid4())}.txt"
201+
# Use provided file_name which includes the extension
202+
new_key = f"{tos_key}/{file_names}"
178203
upload_data = data.encode("utf-8") # Encode as byte type
179204

180205
elif isinstance(data, list): # Process list of strings
181-
new_key = f"{tos_key}/{str(uuid.uuid4())}.txt"
206+
# Use provided file_name which includes the extension
207+
new_key = f"{tos_key}/{file_names}"
182208
# Join the strings in the list with newlines and encode as byte type
183209
upload_data = "\n".join(data).encode("utf-8")
184210

185211
elif isinstance(data, bytes): # Process bytes data
186-
new_key = f"{tos_key}/{str(uuid.uuid4())}{file_ext}"
212+
# Use provided file_name which includes the extension
213+
new_key = f"{tos_key}/{file_names}"
187214
upload_data = data
188215

189216
else:
@@ -231,28 +258,136 @@ def add(
231258
**kwargs,
232259
):
233260
"""
261+
Add documents to the Viking database.
234262
Args:
235-
data: str, file path or file stream: Both file or file.read() are acceptable.
236-
**kwargs: collection_name(required)
263+
data: The data to be added. Can be one of the following types:
264+
- str: File path or string data
265+
- list[str]: List of file paths or list of strings
266+
- TextIO: File object (text)
267+
- BinaryIO: File object (binary)
268+
- bytes: Binary data
269+
collection_name: The name of the collection to add documents to.
270+
**kwargs: Additional keyword arguments.
271+
- file_name (str | list[str]): The file name or a list of file names (including suffix).
272+
- doc_id (str): The document ID. If not provided, a UUID is generated. For list input, a provided doc_id is reused for every item in the list.
237273
Returns:
238-
{
274+
dict or list: A dictionary containing the TOS URL and document ID, or a list of such dictionaries for multiple file uploads.
275+
Format: {
239276
"tos_url": "tos://<bucket>/<key>",
240277
"doc_id": "<doc_id>",
241278
}
242279
"""
243-
244-
status, tos_url = self._upload_to_tos(data=data, **kwargs)
245-
if status != 200:
246-
raise ValueError(f"Error in upload_to_tos: {status}")
247-
doc_id = self._add_doc(
248-
collection_name=collection_name,
249-
tos_url=tos_url,
250-
doc_id=str(uuid.uuid4()),
251-
)
252-
return {
253-
"tos_url": f"tos://{tos_url}",
254-
"doc_id": doc_id,
255-
}
280+
# Handle list of file paths (multiple file upload)
281+
if (
282+
isinstance(data, list)
283+
and all(isinstance(item, str) for item in data)
284+
and all(os.path.isfile(item) for item in data)
285+
):
286+
# Handle multiple file upload
287+
file_names = kwargs.get("file_name")
288+
if (
289+
not file_names
290+
or not isinstance(file_names, list)
291+
or len(file_names) != len(data)
292+
):
293+
raise ValueError(
294+
"For multiple file upload, file_name must be provided as a list with the same length as data"
295+
)
296+
297+
results = []
298+
for i, file_path in enumerate(data):
299+
# Create kwargs for this specific file
300+
single_kwargs = kwargs.copy()
301+
single_kwargs["file_name"] = file_names[i]
302+
303+
# Generate or use provided doc_id for this file
304+
doc_id = single_kwargs.get("doc_id")
305+
if not doc_id:
306+
doc_id = str(uuid.uuid4())
307+
single_kwargs["doc_id"] = doc_id
308+
309+
status, tos_url = self._upload_to_tos(data=file_path, **single_kwargs)
310+
if status != 200:
311+
raise ValueError(
312+
f"Error in upload_to_tos for file {file_path}: {status}"
313+
)
314+
315+
doc_id = self._add_doc(
316+
collection_name=collection_name,
317+
tos_url=tos_url,
318+
doc_id=doc_id,
319+
)
320+
321+
results.append(
322+
{
323+
"tos_url": f"tos://{tos_url}",
324+
"doc_id": doc_id,
325+
}
326+
)
327+
328+
return results
329+
330+
# Handle list of strings (multiple string upload)
331+
elif isinstance(data, list) and all(isinstance(item, str) for item in data):
332+
# Handle multiple string upload
333+
file_names = kwargs.get("file_name")
334+
if (
335+
not file_names
336+
or not isinstance(file_names, list)
337+
or len(file_names) != len(data)
338+
):
339+
raise ValueError(
340+
"For multiple string upload, file_name must be provided as a list with the same length as data"
341+
)
342+
343+
results = []
344+
for i, content in enumerate(data):
345+
# Create kwargs for this specific string
346+
single_kwargs = kwargs.copy()
347+
single_kwargs["file_name"] = file_names[i]
348+
349+
# Generate or use provided doc_id for this string
350+
doc_id = single_kwargs.get("doc_id")
351+
if not doc_id:
352+
doc_id = str(uuid.uuid4())
353+
single_kwargs["doc_id"] = doc_id
354+
355+
status, tos_url = self._upload_to_tos(data=content, **single_kwargs)
356+
if status != 200:
357+
raise ValueError(f"Error in upload_to_tos for string {i}: {status}")
358+
359+
doc_id = self._add_doc(
360+
collection_name=collection_name,
361+
tos_url=tos_url,
362+
doc_id=doc_id,
363+
)
364+
365+
results.append(
366+
{
367+
"tos_url": f"tos://{tos_url}",
368+
"doc_id": doc_id,
369+
}
370+
)
371+
372+
return results
373+
374+
# Handle single file upload or other data types
375+
else:
376+
# Handle doc_id from kwargs or generate a new one
377+
doc_id = kwargs.get("doc_id", str(uuid.uuid4()))
378+
379+
status, tos_url = self._upload_to_tos(data=data, **kwargs)
380+
if status != 200:
381+
raise ValueError(f"Error in upload_to_tos: {status}")
382+
doc_id = self._add_doc(
383+
collection_name=collection_name,
384+
tos_url=tos_url,
385+
doc_id=doc_id,
386+
)
387+
return {
388+
"tos_url": f"tos://{tos_url}",
389+
"doc_id": doc_id,
390+
}
256391

257392
def delete(self, **kwargs: Any):
258393
name = kwargs.get("name")

veadk/knowledgebase/knowledgebase.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,16 @@ def add(
5656
):
5757
"""
5858
Add documents to the vector database.
59-
You can only upload files or file characters when the adapter type used is vikingdb.
60-
In addition, if you upload data of the bytes type,
61-
for example, if you read the file stream of a pdf, then you need to pass an additional parameter file_ext = '.pdf'.
59+
Args:
60+
data (str | list[str] | TextIO | BinaryIO | bytes): The data to be added.
61+
- str: A single file path. (viking only)
62+
- list[str]: A list of file paths.
63+
- TextIO: A text-mode file object. (viking only)
64+
- BinaryIO: A binary-mode file object. (viking only)
65+
- bytes: Raw binary data, e.g. the result of f.read() on a file opened in binary mode. (viking only)
66+
app_name: index name
67+
**kwargs: Additional keyword arguments.
68+
- file_name (str | list[str]): The file name or a list of file names (including suffix). (viking only)
6269
"""
6370
if self.backend != "viking" and not (
6471
isinstance(data, str) or isinstance(data, list)
@@ -73,8 +80,7 @@ def add(
7380
if self.backend == "viking":
7481
# Case 1: Handling file paths or lists of file paths (str)
7582
if isinstance(data, str) and os.path.isfile(data):
76-
# 单个文件路径,直接调用client.add
77-
# 获取文件名(包括后缀名)
83+
# Get the file name (including the suffix)
7884
if "file_name" not in kwargs or not kwargs["file_name"]:
7985
kwargs["file_name"] = os.path.basename(data)
8086
return self._adapter.add(data=data, index=index, **kwargs)
@@ -125,6 +131,10 @@ def add(
125131
# Case6: Unsupported data type
126132
raise TypeError(f"Unsupported data type: {type(data)}")
127133

134+
if isinstance(data, list):
135+
raise TypeError(
136+
f"Unsupported data type: {type(data)}, Only viking support file_path and file bytes"
137+
)
128138
# not viking
129139
return self._adapter.add(data=data, index=index, **kwargs)
130140

0 commit comments

Comments
 (0)