Skip to content

Commit 452a59a

Browse files
committed
feat: add tos base
1 parent d2653b9 commit 452a59a

File tree

1 file changed

+33
-7
lines changed

1 file changed

+33
-7
lines changed

veadk/knowledgebase/backends/tos_vector_backend.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
from typing_extensions import Any, override
2525

2626
import veadk.config # noqa E401
27+
from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
2728
from veadk.configs.database_configs import TOSVectorConfig
2829
from veadk.configs.model_configs import EmbeddingModelConfig, NormalEmbeddingModelConfig
30+
from veadk.integrations.ve_tos.ve_tos import VeTOS
2931
from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend
3032
from veadk.knowledgebase.backends.utils import get_llama_index_splitter
3133

@@ -56,20 +58,25 @@ class TosVectorKnowledgeBackend(BaseKnowledgebaseBackend):
5658
default_factory=lambda: os.getenv("DATABASE_TOS_VECTOR_ACCOUNT_ID")
5759
)
5860
tos_vector_config: TOSVectorConfig = Field(default_factory=TOSVectorConfig)
61+
62+
session_token: str = ""
63+
5964
embedding_config: EmbeddingModelConfig | NormalEmbeddingModelConfig = Field(
6065
default_factory=EmbeddingModelConfig
6166
)
6267

6368
def model_post_init(self, __context: Any) -> None:
6469
self.precheck_index_naming()
65-
self._tos_client = VectorClient(
70+
self._tos_vector_client = VectorClient(
6671
ak=self.volcengine_access_key,
6772
sk=self.volcengine_secret_key,
6873
**self.tos_vector_config.model_dump(),
6974
)
7075
# create_bucket and index if not exist
7176
self._create_index()
7277

78+
self._tos_client = self._get_tos_client()
79+
7380
self._embed_model = OpenAILikeEmbedding(
7481
model_name=self.embedding_config.name,
7582
api_key=self.embedding_config.api_key,
@@ -78,7 +85,7 @@ def model_post_init(self, __context: Any) -> None:
7885

7986
def _bucket_exists(self) -> bool:
8087
try:
81-
bucket_exist = self._tos_client.get_vector_bucket(
88+
bucket_exist = self._tos_vector_client.get_vector_bucket(
8289
vector_bucket_name=self.tos_vector_bucket_name,
8390
account_id=self.tos_vector_account_id,
8491
)
@@ -91,7 +98,7 @@ def _bucket_exists(self) -> bool:
9198

9299
def _index_exists(self) -> bool:
93100
try:
94-
index_exist = self._tos_client.get_index(
101+
index_exist = self._tos_vector_client.get_index(
95102
vector_bucket_name=self.tos_vector_bucket_name,
96103
account_id=self.tos_vector_account_id,
97104
index_name=self.index,
@@ -114,11 +121,11 @@ def _split_documents(self, documents: list[Document]) -> list[BaseNode]:
114121

115122
def _create_index(self):
116123
if not self._bucket_exists():
117-
self._tos_client.create_vector_bucket(
124+
self._tos_vector_client.create_vector_bucket(
118125
vector_bucket_name=self.tos_vector_bucket_name,
119126
)
120127
if not self._index_exists():
121-
self._tos_client.create_index(
128+
self._tos_vector_client.create_index(
122129
vector_bucket_name=self.tos_vector_bucket_name,
123130
account_id=self.tos_vector_account_id,
124131
index_name=self.index,
@@ -127,6 +134,25 @@ def _create_index(self):
127134
distance_metric=DistanceMetricType.DistanceMetricCosine,
128135
)
129136

137+
def _get_tos_client(self) -> VeTOS:
    """Build the VeTOS client for this backend's bucket.

    The access key, secret key and session token configured on the
    instance are used as-is when both keys are set. If either key is
    missing/empty, all three values are taken from
    ``get_credential_from_vefaas_iam()`` instead (the instance's own
    ``session_token`` is then ignored).

    Returns:
        A ``VeTOS`` client bound to ``self.tos_vector_config.region`` and
        ``self.tos_vector_bucket_name``.
    """
    ak = self.volcengine_access_key
    sk = self.volcengine_secret_key
    token = self.session_token

    # Fall back to IAM-provided credentials when static keys are incomplete.
    if not ak or not sk:
        iam_cred = get_credential_from_vefaas_iam()
        ak, sk, token = (
            iam_cred.access_key_id,
            iam_cred.secret_access_key,
            iam_cred.session_token,
        )

    return VeTOS(
        ak=ak,
        sk=sk,
        session_token=token,
        region=self.tos_vector_config.region,
        bucket_name=self.tos_vector_bucket_name,
    )
155+
130156
def precheck_index_naming(self) -> None:
    """Index-name pre-check; this backend currently performs no checks.

    Called at the start of ``model_post_init`` before any client is built.
    """
    return None
132158

@@ -144,7 +170,7 @@ def _process_and_store_documents(self, documents: list[Document]) -> bool:
144170
metadata={"text": node.text, "metadata": json.dumps(node.metadata)},
145171
)
146172
)
147-
result = self._tos_client.put_vectors(
173+
result = self._tos_vector_client.put_vectors(
148174
vector_bucket_name=self.tos_vector_bucket_name,
149175
account_id=self.tos_vector_account_id,
150176
index_name=self.index,
@@ -175,7 +201,7 @@ def add_from_text(self, text: str | list[str], *args, **kwargs) -> bool:
175201
def search(self, query: str, top_k: int = 5) -> list[str]:
176202
query_vector = self._embed_model.get_text_embedding(query)
177203

178-
search_result = self._tos_client.query_vectors(
204+
search_result = self._tos_vector_client.query_vectors(
179205
vector_bucket_name=self.tos_vector_bucket_name,
180206
account_id=self.tos_vector_account_id,
181207
index_name=self.index,

0 commit comments

Comments
 (0)