Skip to content

Commit a1b8358

Browse files
committed
feat(kb): add tests for Viking knowledgebase and improve TOS bucket handling
1 parent 57b2da7 commit a1b8358

File tree

3 files changed

+88
-41
lines changed

3 files changed

+88
-41
lines changed

tests/test_knowledgebase.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,13 @@ async def test_knowledgebase():
2929
)
3030

3131
assert isinstance(kb._backend, InMemoryKnowledgeBackend)
32+
33+
34+
@pytest.mark.asyncio
35+
async def test_viking_knowledgebase_add_texts():
36+
app_name = "kb_test_app"
37+
kb = KnowledgeBase(
38+
backend="viking",
39+
app_name=app_name,
40+
)
41+
assert kb.add_from_text(text="test text", tag="tag") is True

veadk/integrations/ve_tos/ve_tos.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,15 @@ def __init__(
3636
ak: str = "",
3737
sk: str = "",
3838
region: str = "cn-beijing",
39-
bucket_name: str = DEFAULT_TOS_BUCKET_NAME,
39+
bucket_name: str = "",
4040
) -> None:
4141
self.ak = ak if ak else os.getenv("VOLCENGINE_ACCESS_KEY", "")
4242
self.sk = sk if sk else os.getenv("VOLCENGINE_SECRET_KEY", "")
4343
self.region = region
44-
self.bucket_name = (
45-
bucket_name if bucket_name else getenv("", DEFAULT_TOS_BUCKET_NAME)
44+
self.bucket_name = bucket_name or getenv(
45+
"DATABASE_TOS_BUCKET", DEFAULT_TOS_BUCKET_NAME
4646
)
47+
4748
self._tos_module = None
4849

4950
try:

veadk/knowledgebase/backends/vikingdb_knowledge_backend.py

Lines changed: 74 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import veadk.config # noqa E401
2525
from veadk.config import getenv
2626
from veadk.configs.database_configs import NormalTOSConfig, TOSConfig
27-
from veadk.consts import DEFAULT_TOS_BUCKET_NAME
2827
from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend
2928
from veadk.knowledgebase.backends.utils import build_vikingdb_knowledgebase_request
3029
from veadk.utils.logger import get_logger
@@ -48,13 +47,6 @@ def _read_file_to_bytes(file_path: str) -> tuple[bytes, str]:
4847
return file_content, file_name
4948

5049

51-
def _extract_tos_attributes(**kwargs) -> tuple[str, str]:
52-
"""Extract TOS attributes from kwargs"""
53-
tos_bucket_name = kwargs.get("tos_bucket_name", DEFAULT_TOS_BUCKET_NAME)
54-
tos_bucket_path = kwargs.get("tos_bucket_path", "knowledgebase")
55-
return tos_bucket_name, tos_bucket_path
56-
57-
5850
def get_files_in_directory(directory: str):
5951
dir_path = Path(directory)
6052
if not dir_path.is_dir():
@@ -109,15 +101,24 @@ def model_post_init(self, __context: Any) -> None:
109101
)
110102

111103
@override
112-
def add_from_directory(self, directory: str, **kwargs) -> bool:
113-
"""
104+
def add_from_directory(
105+
self,
106+
directory: str,
107+
tos_bucket_name: str | None = None,
108+
tos_bucket_path: str = "knowledgebase",
109+
**kwargs,
110+
) -> bool:
111+
"""Add knowledge from a directory to the knowledgebase.
112+
114113
Args:
115-
directory: str, the directory to add to knowledgebase
116-
**kwargs:
117-
- tos_bucket_name: str, the bucket name of TOS
118-
- tos_bucket_path: str, the path of TOS bucket
114+
directory (str): The directory to add to knowledgebase.
115+
tos_bucket_name (str | None, optional): The bucket name of TOS. Defaults to None.
116+
tos_bucket_path (str, optional): The path of TOS bucket. Defaults to "knowledgebase".
117+
118+
Returns:
119+
bool: True if successful, False otherwise.
119120
"""
120-
tos_bucket_name, tos_bucket_path = _extract_tos_attributes(**kwargs)
121+
tos_bucket_name = tos_bucket_name or self.tos_config.bucket
121122
files = get_files_in_directory(directory=directory)
122123
for _file in files:
123124
content, file_name = _read_file_to_bytes(_file)
@@ -130,15 +131,24 @@ def add_from_directory(self, directory: str, **kwargs) -> bool:
130131
return True
131132

132133
@override
133-
def add_from_files(self, files: list[str], **kwargs) -> bool:
134-
"""
134+
def add_from_files(
135+
self,
136+
files: list[str],
137+
tos_bucket_name: str | None = None,
138+
tos_bucket_path: str = "knowledgebase",
139+
**kwargs,
140+
) -> bool:
141+
"""Add knowledge from files to the knowledgebase.
142+
135143
Args:
136-
files: list[str], the files to add to knowledgebase
137-
**kwargs:
138-
- tos_bucket_name: str, the bucket name of TOS
139-
- tos_bucket_path: str, the path of TOS bucket
144+
files (list[str]): The files to add to knowledgebase.
145+
tos_bucket_name (str | None, optional): The bucket name of TOS. Defaults to None.
146+
tos_bucket_path (str, optional): The path of TOS bucket. Defaults to "knowledgebase".
147+
148+
Returns:
149+
bool: True if successful, False otherwise.
140150
"""
141-
tos_bucket_name, tos_bucket_path = _extract_tos_attributes(**kwargs)
151+
tos_bucket_name = tos_bucket_name or self.tos_config.bucket
142152
for _file in files:
143153
content, file_name = _read_file_to_bytes(_file)
144154
tos_url = self._upload_bytes_to_tos(
@@ -150,15 +160,24 @@ def add_from_files(self, files: list[str], **kwargs) -> bool:
150160
return True
151161

152162
@override
153-
def add_from_text(self, text: str | list[str], **kwargs) -> bool:
154-
"""
163+
def add_from_text(
164+
self,
165+
text: str | list[str],
166+
tos_bucket_name: str | None = None,
167+
tos_bucket_path: str = "knowledgebase",
168+
**kwargs,
169+
) -> bool:
170+
"""Add knowledge from text to the knowledgebase.
171+
155172
Args:
156-
text: str or list[str], the text to add to knowledgebase
157-
**kwargs:
158-
- tos_bucket_name: str, the bucket name of TOS
159-
- tos_bucket_path: str, the path of TOS bucket
173+
text (str | list[str]): The text to add to knowledgebase.
174+
tos_bucket_name (str | None, optional): The bucket name of TOS. Defaults to None.
175+
tos_bucket_path (str, optional): The path of TOS bucket. Defaults to "knowledgebase".
176+
177+
Returns:
178+
bool: True if successful, False otherwise.
160179
"""
161-
tos_bucket_name, tos_bucket_path = _extract_tos_attributes(**kwargs)
180+
tos_bucket_name = tos_bucket_name or self.tos_config.bucket
162181
if isinstance(text, list):
163182
object_keys = kwargs.get(
164183
"tos_object_keys",
@@ -185,16 +204,26 @@ def add_from_text(self, text: str | list[str], **kwargs) -> bool:
185204
raise ValueError("text must be str or list[str]")
186205
return True
187206

188-
def add_from_bytes(self, content: bytes, file_name: str, **kwargs) -> bool:
189-
"""
207+
def add_from_bytes(
208+
self,
209+
content: bytes,
210+
file_name: str,
211+
tos_bucket_name: str | None = None,
212+
tos_bucket_path: str = "knowledgebase",
213+
**kwargs,
214+
) -> bool:
215+
"""Add knowledge from bytes to the knowledgebase.
216+
190217
Args:
191-
content: bytes, the content to add to knowledgebase, bytes
192-
file_name: str, the file name of the content
193-
**kwargs:
194-
- tos_bucket_name: str, the bucket name of TOS
195-
- tos_bucket_path: str, the path of TOS bucket
218+
content (bytes): The content to add to knowledgebase.
219+
file_name (str): The file name of the content.
220+
tos_bucket_name (str | None, optional): The bucket name of TOS. Defaults to None.
221+
tos_bucket_path (str, optional): The path of TOS bucket. Defaults to "knowledgebase".
222+
223+
Returns:
224+
bool: True if successful, False otherwise.
196225
"""
197-
tos_bucket_name, tos_bucket_path = _extract_tos_attributes(**kwargs)
226+
tos_bucket_name = tos_bucket_name or self.tos_config.bucket
198227
tos_url = self._upload_bytes_to_tos(
199228
content,
200229
tos_bucket_name=tos_bucket_name,
@@ -346,7 +375,14 @@ def _upload_bytes_to_tos(
346375
self, content: bytes, tos_bucket_name: str, object_key: str
347376
) -> str:
348377
self._tos_client.bucket_name = tos_bucket_name
349-
asyncio.run(self._tos_client.upload(object_key=object_key, data=content))
378+
coro = self._tos_client.upload(object_key=object_key, data=content)
379+
try:
380+
loop = asyncio.get_running_loop()
381+
loop.run_until_complete(
382+
coro
383+
) if not loop.is_running() else asyncio.ensure_future(coro)
384+
except RuntimeError:
385+
asyncio.run(coro)
350386
return f"{self._tos_client.bucket_name}/{object_key}"
351387

352388
def _add_doc(self, tos_url: str) -> Any:

0 commit comments

Comments
 (0)