Skip to content

Commit 9174185

Browse files
committed
feat: tos_backend
1 parent 4a92b15 commit 9174185

File tree

2 files changed

+221
-0
lines changed

2 files changed

+221
-0
lines changed

veadk/configs/database_configs.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,47 @@ class NormalTOSConfig(BaseSettings):
130130
region: str = "cn-beijing"
131131

132132
bucket: str
133+
134+
135+
class TOSVectorConfig(BaseSettings):
    """Connection settings for the TOS vector client.

    Every field can be supplied through an environment variable carrying the
    ``DATABASE_TOS_VECTOR_`` prefix; anything left unset uses the default
    declared here.

    The TOS vector knowledge backend dumps this model and forwards every field
    as a keyword argument to the TOS ``VectorClient``, so field names must stay
    exactly as they are (including ``except100_continue_threshold``).
    """

    model_config = SettingsConfigDict(env_prefix="DATABASE_TOS_VECTOR_")

    # --- Service location -------------------------------------------------
    endpoint: str = "tosvectors-cn-boe.volces.com"
    region: str = "cn-beijing"

    # --- Credentials ------------------------------------------------------
    security_token: str | None = None

    # --- Connection behaviour ---------------------------------------------
    # NOTE(review): units for the time-related values below are not stated in
    # this file -- confirm against the TOS SDK before tuning them.
    max_retry_count: int = 3
    max_connections: int = 1024
    connection_time: int = 10
    enable_verify_ssl: bool = True
    dns_cache_time: int = 15

    # --- Optional HTTP proxy ----------------------------------------------
    proxy_host: str | None = None
    proxy_port: int | None = None
    proxy_username: str | None = None
    proxy_password: str | None = None

    # --- Diagnostics and timeouts -----------------------------------------
    high_latency_log_threshold: int = 100
    socket_timeout: int = 30

    # Typed as a plain object because no concrete provider type is imported
    # in this module.
    credentials_provider: object | None = None

    except100_continue_threshold: int = 65536

    # --- User-Agent customisation -----------------------------------------
    user_agent_product_name: str | None = None
    user_agent_soft_name: str | None = None
    user_agent_soft_version: str | None = None
    user_agent_customized_key_values: dict[str, str] | None = None
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import json
15+
import os
16+
17+
from llama_index.core import (
18+
Document,
19+
SimpleDirectoryReader,
20+
)
21+
from llama_index.core.schema import BaseNode
22+
from llama_index.embeddings.openai_like import OpenAILikeEmbedding
23+
from pydantic import Field
24+
from tos.models2 import Vector, VectorData
25+
from typing_extensions import Any, override
26+
27+
import veadk.config # noqa E401
28+
from veadk.configs.database_configs import TOSVectorConfig
29+
from veadk.configs.model_configs import EmbeddingModelConfig, NormalEmbeddingModelConfig
30+
from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend
31+
from veadk.knowledgebase.backends.utils import get_llama_index_splitter
32+
33+
try:
34+
from tos.vector_client import VectorClient
35+
from tos import DataType, DistanceMetricType
36+
except ImportError:
37+
raise ImportError(
38+
"Please install VeADK extensions\npip install veadk-python[extensions]"
39+
)
40+
41+
42+
class TosVectorKnowledgeBackend(BaseKnowledgebaseBackend):
    """Knowledgebase backend that keeps embedded text chunks in a TOS vector index.

    Documents are split into nodes, each node's text is embedded with an
    OpenAI-compatible embedding model, and the resulting vectors are written to
    the index named ``self.index`` inside ``tos_vector_bucket_name``.
    ``self.index`` is not declared in this class; it is expected to come from
    the base class.
    """

    # Credentials and bucket/account identifiers fall back to environment
    # variables when they are not passed to the constructor.
    volcengine_access_key: str | None = Field(
        default_factory=lambda: os.getenv("VOLCENGINE_ACCESS_KEY")
    )
    volcengine_secret_key: str | None = Field(
        default_factory=lambda: os.getenv("VOLCENGINE_SECRET_KEY")
    )
    tos_vector_bucket_name: str | None = Field(
        default_factory=lambda: os.getenv("DATABASE_TOS_VECTOR_BUCKET")
    )
    tos_vector_account_id: str | None = Field(
        default_factory=lambda: os.getenv("DATABASE_TOS_VECTOR_ACCOUNT_ID")
    )
    tos_vector_config: TOSVectorConfig = Field(default_factory=TOSVectorConfig)
    embedding_config: EmbeddingModelConfig | NormalEmbeddingModelConfig = Field(
        default_factory=EmbeddingModelConfig
    )

    def model_post_init(self, __context: Any) -> None:
        """Create the TOS client, make sure bucket and index exist, build the embedder.

        Raises:
            ValueError: If no vector bucket name was provided, either as an
                argument or through ``DATABASE_TOS_VECTOR_BUCKET``.
        """
        # Fail early with an actionable message instead of letting a ``None``
        # bucket name reach the TOS SDK.
        if not self.tos_vector_bucket_name:
            raise ValueError(
                "TOS vector bucket name is required: pass "
                "`tos_vector_bucket_name` or set the "
                "DATABASE_TOS_VECTOR_BUCKET environment variable."
            )

        self.precheck_index_naming()
        # Every TOSVectorConfig field is forwarded verbatim as a keyword
        # argument of the client constructor.
        self._tos_client = VectorClient(
            ak=self.volcengine_access_key,
            sk=self.volcengine_secret_key,
            **self.tos_vector_config.model_dump(),
        )
        # Create the bucket and the index if they do not exist yet.
        self._create_index()

        self._embed_model = OpenAILikeEmbedding(
            model_name=self.embedding_config.name,
            api_key=self.embedding_config.api_key,
            api_base=self.embedding_config.api_base,
        )

    def _bucket_exists(self) -> bool:
        """Return whether the configured vector bucket appears in the bucket listing."""
        # NOTE(review): only the entries of this single list response are
        # inspected -- confirm whether the listing can be paginated.
        listing = self._tos_client.list_vector_buckets()
        return any(
            bucket.vector_bucket_name == self.tos_vector_bucket_name
            for bucket in listing.vector_buckets
        )

    def _index_exists(self) -> bool:
        """Return whether ``self.index`` appears in the bucket's index listing."""
        # NOTE(review): same single-response caveat as ``_bucket_exists``.
        listing = self._tos_client.list_indexes(
            vector_bucket_name=self.tos_vector_bucket_name,
            account_id=self.tos_vector_account_id,
        )
        return any(index.index_name == self.index for index in listing.indexes)

    def _create_index(self):
        """Create the vector bucket and the index when either one is missing."""
        if not self._bucket_exists():
            self._tos_client.create_vector_bucket(
                vector_bucket_name=self.tos_vector_bucket_name,
            )
        if not self._index_exists():
            # The index dimension is taken from the embedding configuration so
            # that stored vectors and query vectors agree in size.
            self._tos_client.create_index(
                vector_bucket_name=self.tos_vector_bucket_name,
                account_id=self.tos_vector_account_id,
                index_name=self.index,
                data_type=DataType.DataTypeFloat32,
                dimension=self.embedding_config.dim,
                distance_metric=DistanceMetricType.DistanceMetricCosine,
            )

    def precheck_index_naming(self) -> None:
        """Hook for validating the index name; currently performs no checks."""
        pass

    def _process_and_store_documents(self, documents: list[Document]) -> bool:
        """Split, embed and upload ``documents`` to the vector index.

        Args:
            documents: Documents to chunk and store.

        Returns:
            ``True`` when the upload response has status code 200, or when
            splitting produced no nodes (nothing to upload); ``False``
            otherwise.
        """
        nodes = self._split_documents(documents)
        if not nodes:
            # Nothing to store: skip the request rather than sending an empty
            # vector list to the service.
            return True

        vectors = [
            Vector(
                key=node.node_id,
                data=VectorData(
                    float32=self._embed_model.get_text_embedding(node.text)
                ),
                # Node metadata is JSON-encoded into a single string value.
                metadata={"text": node.text, "metadata": json.dumps(node.metadata)},
            )
            for node in nodes
        ]
        result = self._tos_client.put_vectors(
            vector_bucket_name=self.tos_vector_bucket_name,
            account_id=self.tos_vector_account_id,
            index_name=self.index,
            vectors=vectors,
        )
        return result.status_code == 200

    @override
    def add_from_directory(self, directory: str, *args, **kwargs) -> bool:
        """Load every file under ``directory`` and store its contents."""
        documents = SimpleDirectoryReader(input_dir=directory).load_data()
        return self._process_and_store_documents(documents)

    @override
    def add_from_files(self, files: list[str], *args, **kwargs) -> bool:
        """Load the given ``files`` and store their contents."""
        documents = SimpleDirectoryReader(input_files=files).load_data()
        return self._process_and_store_documents(documents)

    @override
    def add_from_text(self, text: str | list[str], *args, **kwargs) -> bool:
        """Store one text or a list of texts, one document per text."""
        texts = [text] if isinstance(text, str) else text
        documents = [Document(text=t) for t in texts]
        return self._process_and_store_documents(documents)

    @override
    def search(self, query: str, top_k: int = 5) -> list[str]:
        """Return the stored text of the ``top_k`` vectors closest to ``query``.

        Args:
            query: Natural-language query; embedded with the same model used
                for stored chunks.
            top_k: Maximum number of results requested from the index.

        Returns:
            The ``"text"`` metadata value of each returned vector.
        """
        query_vector = self._embed_model.get_text_embedding(query)

        search_result = self._tos_client.query_vectors(
            vector_bucket_name=self.tos_vector_bucket_name,
            account_id=self.tos_vector_account_id,
            index_name=self.index,
            query_vector=VectorData(float32=query_vector),
            top_k=top_k,
        )

        # NOTE(review): this relies on the query response carrying each
        # vector's metadata -- confirm the SDK returns it without an explicit
        # opt-in flag.
        return [vector.metadata["text"] for vector in search_result.vectors]

    def _split_documents(self, documents: list[Document]) -> list[BaseNode]:
        """Split each document into nodes and return them as one flat list.

        The splitter is chosen per document from its ``file_path`` metadata;
        documents without that key use the splitter selected for an empty path.
        """
        nodes = []
        for document in documents:
            splitter = get_llama_index_splitter(document.metadata.get("file_path", ""))
            nodes.extend(splitter.get_nodes_from_documents([document]))
        return nodes

0 commit comments

Comments
 (0)