2323
2424import veadk .config # noqa E401
2525from veadk .config import getenv
26+ from veadk .configs .database_configs import NormalTOSConfig , TOSConfig
2627from veadk .consts import DEFAULT_TOS_BUCKET_NAME
2728from veadk .knowledgebase .backends .base_backend import BaseKnowledgebaseBackend
2829from veadk .knowledgebase .backends .utils import build_vikingdb_knowledgebase_request
@@ -62,12 +63,6 @@ def get_files_in_directory(directory: str):
6263 return file_paths
6364
6465
65- def _upload_bytes_to_tos (content : bytes , tos_bucket_name : str , object_key : str ) -> str :
66- ve_tos = VeTOS (bucket_name = tos_bucket_name )
67- asyncio .run (ve_tos .upload (object_key = object_key , data = content ))
68- return f"{ ve_tos .bucket_name } /{ object_key } "
69-
70-
7166class VikingDBKnowledgeBackend (BaseKnowledgebaseBackend ):
7267 volcengine_access_key : str = Field (
7368 default_factory = lambda : getenv ("VOLCENGINE_ACCESS_KEY" )
@@ -83,6 +78,9 @@ class VikingDBKnowledgeBackend(BaseKnowledgebaseBackend):
8378 region : str = "cn-beijing"
8479 """VikingDB knowledgebase region"""
8580
81+ tos_config : TOSConfig | NormalTOSConfig = Field (default_factory = TOSConfig )
82+ """TOS config, used to upload files to TOS"""
83+
8684 def precheck_index_naming (self ):
8785 if not (
8886 isinstance (self .index , str )
@@ -96,12 +94,20 @@ def precheck_index_naming(self):
9694
9795 def model_post_init (self , __context : Any ) -> None :
9896 self .precheck_index_naming ()
97+
9998 # check whether collection exist, if not, create it
10099 if not self .collection_status ()["existed" ]:
101100 logger .warning (
102101 f"VikingDB knowledgebase collection { self .index } does not exist, please create it first..."
103102 )
104103
104+ self ._tos_client = VeTOS (
105+ ak = self .volcengine_access_key ,
106+ sk = self .volcengine_secret_key ,
107+ region = self .tos_config .region ,
108+ bucket_name = self .tos_config .bucket ,
109+ )
110+
105111 @override
106112 def add_from_directory (self , directory : str , ** kwargs ) -> bool :
107113 """
@@ -115,7 +121,7 @@ def add_from_directory(self, directory: str, **kwargs) -> bool:
115121 files = get_files_in_directory (directory = directory )
116122 for _file in files :
117123 content , file_name = _read_file_to_bytes (_file )
118- tos_url = _upload_bytes_to_tos (
124+ tos_url = self . _upload_bytes_to_tos (
119125 content ,
120126 tos_bucket_name = tos_bucket_name ,
121127 object_key = f"{ tos_bucket_path } /{ file_name } " ,
@@ -135,7 +141,7 @@ def add_from_files(self, files: list[str], **kwargs) -> bool:
135141 tos_bucket_name , tos_bucket_path = _extract_tos_attributes (** kwargs )
136142 for _file in files :
137143 content , file_name = _read_file_to_bytes (_file )
138- tos_url = _upload_bytes_to_tos (
144+ tos_url = self . _upload_bytes_to_tos (
139145 content ,
140146 tos_bucket_name = tos_bucket_name ,
141147 object_key = f"{ tos_bucket_path } /{ file_name } " ,
@@ -163,15 +169,17 @@ def add_from_text(self, text: str | list[str], **kwargs) -> bool:
163169 )
164170 for _text , _object_key in zip (text , object_keys ):
165171 _content = _text .encode ("utf-8" )
166- tos_url = _upload_bytes_to_tos (_content , tos_bucket_name , _object_key )
172+ tos_url = self ._upload_bytes_to_tos (
173+ _content , tos_bucket_name , _object_key
174+ )
167175 self ._add_doc (tos_url = tos_url )
168176 return True
169177 elif isinstance (text , str ):
170178 content = text .encode ("utf-8" )
171179 object_key = kwargs .get (
172180 "object_key" , f"veadk/knowledgebase/{ formatted_timestamp ()} .txt"
173181 )
174- tos_url = _upload_bytes_to_tos (content , tos_bucket_name , object_key )
182+ tos_url = self . _upload_bytes_to_tos (content , tos_bucket_name , object_key )
175183 self ._add_doc (tos_url = tos_url )
176184 else :
177185 raise ValueError ("text must be str or list[str]" )
@@ -334,6 +342,13 @@ def create_collection(self) -> None:
334342 f"Error during collection creation: { response .get ('code' )} "
335343 )
336344
345+ def _upload_bytes_to_tos (
346+ self , content : bytes , tos_bucket_name : str , object_key : str
347+ ) -> str :
348+ self ._tos_client .bucket_name = tos_bucket_name
349+ asyncio .run (self ._tos_client .upload (object_key = object_key , data = content ))
350+ return f"{ self ._tos_client .bucket_name } /{ object_key } "
351+
337352 def _add_doc (self , tos_url : str ) -> Any :
338353 ADD_DOC_PATH = "/api/knowledge/doc/add"
339354
0 commit comments