2424from typing_extensions import Any , override
2525
2626import veadk .config # noqa E401
27+ from veadk .auth .veauth .utils import get_credential_from_vefaas_iam
2728from veadk .configs .database_configs import TOSVectorConfig
2829from veadk .configs .model_configs import EmbeddingModelConfig , NormalEmbeddingModelConfig
30+ from veadk .integrations .ve_tos .ve_tos import VeTOS
2931from veadk .knowledgebase .backends .base_backend import BaseKnowledgebaseBackend
3032from veadk .knowledgebase .backends .utils import get_llama_index_splitter
3133
@@ -56,20 +58,25 @@ class TosVectorKnowledgeBackend(BaseKnowledgebaseBackend):
5658 default_factory = lambda : os .getenv ("DATABASE_TOS_VECTOR_ACCOUNT_ID" )
5759 )
5860 tos_vector_config : TOSVectorConfig = Field (default_factory = TOSVectorConfig )
61+
62+ session_token : str = ""
63+
5964 embedding_config : EmbeddingModelConfig | NormalEmbeddingModelConfig = Field (
6065 default_factory = EmbeddingModelConfig
6166 )
6267
6368 def model_post_init (self , __context : Any ) -> None :
6469 self .precheck_index_naming ()
65- self ._tos_client = VectorClient (
70+ self ._tos_vector_client = VectorClient (
6671 ak = self .volcengine_access_key ,
6772 sk = self .volcengine_secret_key ,
6873 ** self .tos_vector_config .model_dump (),
6974 )
7075 # create_bucket and index if not exist
7176 self ._create_index ()
7277
78+ self ._tos_client = self ._get_tos_client ()
79+
7380 self ._embed_model = OpenAILikeEmbedding (
7481 model_name = self .embedding_config .name ,
7582 api_key = self .embedding_config .api_key ,
@@ -78,7 +85,7 @@ def model_post_init(self, __context: Any) -> None:
7885
7986 def _bucket_exists (self ) -> bool :
8087 try :
81- bucket_exist = self ._tos_client .get_vector_bucket (
88+ bucket_exist = self ._tos_vector_client .get_vector_bucket (
8289 vector_bucket_name = self .tos_vector_bucket_name ,
8390 account_id = self .tos_vector_account_id ,
8491 )
@@ -91,7 +98,7 @@ def _bucket_exists(self) -> bool:
9198
9299 def _index_exists (self ) -> bool :
93100 try :
94- index_exist = self ._tos_client .get_index (
101+ index_exist = self ._tos_vector_client .get_index (
95102 vector_bucket_name = self .tos_vector_bucket_name ,
96103 account_id = self .tos_vector_account_id ,
97104 index_name = self .index ,
@@ -114,11 +121,11 @@ def _split_documents(self, documents: list[Document]) -> list[BaseNode]:
114121
115122 def _create_index (self ):
116123 if not self ._bucket_exists ():
117- self ._tos_client .create_vector_bucket (
124+ self ._tos_vector_client .create_vector_bucket (
118125 vector_bucket_name = self .tos_vector_bucket_name ,
119126 )
120127 if not self ._index_exists ():
121- self ._tos_client .create_index (
128+ self ._tos_vector_client .create_index (
122129 vector_bucket_name = self .tos_vector_bucket_name ,
123130 account_id = self .tos_vector_account_id ,
124131 index_name = self .index ,
@@ -127,6 +134,25 @@ def _create_index(self):
127134 distance_metric = DistanceMetricType .DistanceMetricCosine ,
128135 )
129136
137+ def _get_tos_client (self ) -> VeTOS :
138+ volcengine_access_key = self .volcengine_access_key
139+ volcengine_secret_key = self .volcengine_secret_key
140+ session_token = self .session_token
141+
142+ if not (volcengine_access_key and volcengine_secret_key ):
143+ cred = get_credential_from_vefaas_iam ()
144+ volcengine_access_key = cred .access_key_id
145+ volcengine_secret_key = cred .secret_access_key
146+ session_token = cred .session_token
147+
148+ return VeTOS (
149+ ak = volcengine_access_key ,
150+ sk = volcengine_secret_key ,
151+ session_token = session_token ,
152+ region = self .tos_vector_config .region ,
153+ bucket_name = self .tos_vector_bucket_name ,
154+ )
155+
130156 def precheck_index_naming (self ) -> None :
131157 pass
132158
@@ -144,7 +170,7 @@ def _process_and_store_documents(self, documents: list[Document]) -> bool:
144170 metadata = {"text" : node .text , "metadata" : json .dumps (node .metadata )},
145171 )
146172 )
147- result = self ._tos_client .put_vectors (
173+ result = self ._tos_vector_client .put_vectors (
148174 vector_bucket_name = self .tos_vector_bucket_name ,
149175 account_id = self .tos_vector_account_id ,
150176 index_name = self .index ,
@@ -175,7 +201,7 @@ def add_from_text(self, text: str | list[str], *args, **kwargs) -> bool:
175201 def search (self , query : str , top_k : int = 5 ) -> list [str ]:
176202 query_vector = self ._embed_model .get_text_embedding (query )
177203
178- search_result = self ._tos_client .query_vectors (
204+ search_result = self ._tos_vector_client .query_vectors (
179205 vector_bucket_name = self .tos_vector_bucket_name ,
180206 account_id = self .tos_vector_account_id ,
181207 index_name = self .index ,
0 commit comments