Skip to content

Commit faab412

Browse files
chore(kb): add tos to viking knolwedgebase (#179)
1 parent e8eed5f commit faab412

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

veadk/knowledgebase/backends/vikingdb_knowledge_backend.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import veadk.config # noqa E401
2525
from veadk.config import getenv
26+
from veadk.configs.database_configs import NormalTOSConfig, TOSConfig
2627
from veadk.consts import DEFAULT_TOS_BUCKET_NAME
2728
from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend
2829
from veadk.knowledgebase.backends.utils import build_vikingdb_knowledgebase_request
@@ -62,12 +63,6 @@ def get_files_in_directory(directory: str):
6263
return file_paths
6364

6465

65-
def _upload_bytes_to_tos(content: bytes, tos_bucket_name: str, object_key: str) -> str:
66-
ve_tos = VeTOS(bucket_name=tos_bucket_name)
67-
asyncio.run(ve_tos.upload(object_key=object_key, data=content))
68-
return f"{ve_tos.bucket_name}/{object_key}"
69-
70-
7166
class VikingDBKnowledgeBackend(BaseKnowledgebaseBackend):
7267
volcengine_access_key: str = Field(
7368
default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY")
@@ -83,6 +78,9 @@ class VikingDBKnowledgeBackend(BaseKnowledgebaseBackend):
8378
region: str = "cn-beijing"
8479
"""VikingDB knowledgebase region"""
8580

81+
tos_config: TOSConfig | NormalTOSConfig = Field(default_factory=TOSConfig)
82+
"""TOS config, used to upload files to TOS"""
83+
8684
def precheck_index_naming(self):
8785
if not (
8886
isinstance(self.index, str)
@@ -96,12 +94,20 @@ def precheck_index_naming(self):
9694

9795
def model_post_init(self, __context: Any) -> None:
9896
self.precheck_index_naming()
97+
9998
# check whether collection exist, if not, create it
10099
if not self.collection_status()["existed"]:
101100
logger.warning(
102101
f"VikingDB knowledgebase collection {self.index} does not exist, please create it first..."
103102
)
104103

104+
self._tos_client = VeTOS(
105+
ak=self.volcengine_access_key,
106+
sk=self.volcengine_secret_key,
107+
region=self.tos_config.region,
108+
bucket_name=self.tos_config.bucket,
109+
)
110+
105111
@override
106112
def add_from_directory(self, directory: str, **kwargs) -> bool:
107113
"""
@@ -115,7 +121,7 @@ def add_from_directory(self, directory: str, **kwargs) -> bool:
115121
files = get_files_in_directory(directory=directory)
116122
for _file in files:
117123
content, file_name = _read_file_to_bytes(_file)
118-
tos_url = _upload_bytes_to_tos(
124+
tos_url = self._upload_bytes_to_tos(
119125
content,
120126
tos_bucket_name=tos_bucket_name,
121127
object_key=f"{tos_bucket_path}/{file_name}",
@@ -135,7 +141,7 @@ def add_from_files(self, files: list[str], **kwargs) -> bool:
135141
tos_bucket_name, tos_bucket_path = _extract_tos_attributes(**kwargs)
136142
for _file in files:
137143
content, file_name = _read_file_to_bytes(_file)
138-
tos_url = _upload_bytes_to_tos(
144+
tos_url = self._upload_bytes_to_tos(
139145
content,
140146
tos_bucket_name=tos_bucket_name,
141147
object_key=f"{tos_bucket_path}/{file_name}",
@@ -163,15 +169,17 @@ def add_from_text(self, text: str | list[str], **kwargs) -> bool:
163169
)
164170
for _text, _object_key in zip(text, object_keys):
165171
_content = _text.encode("utf-8")
166-
tos_url = _upload_bytes_to_tos(_content, tos_bucket_name, _object_key)
172+
tos_url = self._upload_bytes_to_tos(
173+
_content, tos_bucket_name, _object_key
174+
)
167175
self._add_doc(tos_url=tos_url)
168176
return True
169177
elif isinstance(text, str):
170178
content = text.encode("utf-8")
171179
object_key = kwargs.get(
172180
"object_key", f"veadk/knowledgebase/{formatted_timestamp()}.txt"
173181
)
174-
tos_url = _upload_bytes_to_tos(content, tos_bucket_name, object_key)
182+
tos_url = self._upload_bytes_to_tos(content, tos_bucket_name, object_key)
175183
self._add_doc(tos_url=tos_url)
176184
else:
177185
raise ValueError("text must be str or list[str]")
@@ -334,6 +342,13 @@ def create_collection(self) -> None:
334342
f"Error during collection creation: {response.get('code')}"
335343
)
336344

345+
def _upload_bytes_to_tos(
346+
self, content: bytes, tos_bucket_name: str, object_key: str
347+
) -> str:
348+
self._tos_client.bucket_name = tos_bucket_name
349+
asyncio.run(self._tos_client.upload(object_key=object_key, data=content))
350+
return f"{self._tos_client.bucket_name}/{object_key}"
351+
337352
def _add_doc(self, tos_url: str) -> Any:
338353
ADD_DOC_PATH = "/api/knowledge/doc/add"
339354

0 commit comments

Comments
 (0)