Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions veadk/integrations/ve_tos/ve_tos.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@ def __init__(
self,
ak: str = "",
sk: str = "",
session_token: str = "",
region: str = "cn-beijing",
bucket_name: str = DEFAULT_TOS_BUCKET_NAME,
) -> None:
self.ak = ak if ak else os.getenv("VOLCENGINE_ACCESS_KEY", "")
self.sk = sk if sk else os.getenv("VOLCENGINE_SECRET_KEY", "")
self.session_token = session_token

# Add empty value validation
if not self.ak or not self.sk:
raise ValueError(
Expand Down Expand Up @@ -71,6 +74,7 @@ def __init__(
self._client = self._tos_module.TosClientV2(
ak=self.ak,
sk=self.sk,
security_token=self.session_token,
endpoint=f"tos-{self.region}.volces.com",
region=self.region,
)
Expand All @@ -85,6 +89,7 @@ def _refresh_client(self):
self._client = self._tos_module.TosClientV2(
self.ak,
self.sk,
security_token=self.session_token,
endpoint=f"tos-{self.region}.volces.com",
region=self.region,
)
Expand Down
3 changes: 2 additions & 1 deletion veadk/knowledgebase/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def build_vikingdb_knowledgebase_request(
path: str,
volcengine_access_key: str,
volcengine_secret_key: str,
session_token: str = "",
method: Literal["GET", "POST", "PUT", "DELETE"] = "POST",
region: str = "cn-beijing",
params=None,
Expand Down Expand Up @@ -85,7 +86,7 @@ def build_vikingdb_knowledgebase_request(
r.set_body(json.dumps(data))

credentials = Credentials(
volcengine_access_key, volcengine_secret_key, "air", region
volcengine_access_key, volcengine_secret_key, "air", region, session_token
)
SignerV4.sign(r, credentials)
return r
66 changes: 47 additions & 19 deletions veadk/knowledgebase/backends/vikingdb_knowledge_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import asyncio
import json
import os
import re
from pathlib import Path
from typing import Any, Literal
Expand All @@ -23,7 +24,7 @@
from typing_extensions import override

import veadk.config # noqa E401
from veadk.config import getenv
from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
from veadk.configs.database_configs import NormalTOSConfig, TOSConfig
from veadk.knowledgebase.backends.base_backend import BaseKnowledgebaseBackend
from veadk.knowledgebase.backends.utils import build_vikingdb_knowledgebase_request
Expand Down Expand Up @@ -58,14 +59,16 @@ def get_files_in_directory(directory: str):


class VikingDBKnowledgeBackend(BaseKnowledgebaseBackend):
volcengine_access_key: str = Field(
default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY")
volcengine_access_key: str | None = Field(
default_factory=lambda: os.getenv("VOLCENGINE_ACCESS_KEY")
)

volcengine_secret_key: str = Field(
default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY")
volcengine_secret_key: str | None = Field(
default_factory=lambda: os.getenv("VOLCENGINE_SECRET_KEY")
)

session_token: str = ""

volcengine_project: str = "default"
"""VikingDB knowledgebase project in Volcengine console platform. Default by `default`"""

Expand All @@ -75,6 +78,15 @@ class VikingDBKnowledgeBackend(BaseKnowledgebaseBackend):
tos_config: TOSConfig | NormalTOSConfig = Field(default_factory=TOSConfig)
"""TOS config, used to upload files to TOS"""

def model_post_init(self, __context: Any) -> None:
    """Post-initialisation hook: validate the index name and probe the collection.

    First runs ``precheck_index_naming()``. Then queries
    ``collection_status()`` and, when its ``"existed"`` entry is falsy, logs
    a warning asking the user to create the collection. The collection is
    NOT created here.

    Args:
        __context: Hook context argument; unused by this method.
    """
    self.precheck_index_naming()

    status = self.collection_status()
    if status["existed"]:
        return

    # Only warn: creating the collection is left to the user.
    logger.warning(
        f"VikingDB knowledgebase collection {self.index} does not exist, please create it first..."
    )

def precheck_index_naming(self):
if not (
isinstance(self.index, str)
Expand All @@ -86,18 +98,21 @@ def precheck_index_naming(self):
"it must start with an English letter, contain only letters, numbers, and underscores, and have a length of 1-128."
)

def model_post_init(self, __context: Any) -> None:
self.precheck_index_naming()

# Check whether the collection exists; if it does not, warn the user to create it
if not self.collection_status()["existed"]:
logger.warning(
f"VikingDB knowledgebase collection {self.index} does not exist, please create it first..."
)

self._tos_client = VeTOS(
ak=self.volcengine_access_key,
sk=self.volcengine_secret_key,
def _get_tos_client(self) -> VeTOS:
    """Build a ``VeTOS`` client from the backend's credentials.

    Uses ``volcengine_access_key``, ``volcengine_secret_key`` and
    ``session_token`` from this instance. If either key is empty or
    ``None``, all three values are taken instead from the credential
    returned by ``get_credential_from_vefaas_iam()``.

    Returns:
        A new ``VeTOS`` instance configured with the resolved credentials
        and the region / bucket from ``self.tos_config``.
    """
    ak = self.volcengine_access_key
    sk = self.volcengine_secret_key
    token = self.session_token

    if not ak or not sk:
        # No complete static key pair: fall back to the fetched credential,
        # which also replaces the session token.
        fetched = get_credential_from_vefaas_iam()
        ak = fetched.access_key_id
        sk = fetched.secret_access_key
        token = fetched.session_token

    tos_settings = self.tos_config
    return VeTOS(
        ak=ak,
        sk=sk,
        session_token=token,
        region=tos_settings.region,
        bucket_name=tos_settings.bucket,
    )
Expand Down Expand Up @@ -404,6 +419,8 @@ def _upload_bytes_to_tos(
metadata: dict | None = None,
) -> str:
# Here, we set the metadata via the TOS object, ref: https://www.volcengine.com/docs/84313/1254624
self._tos_client = self._get_tos_client()

self._tos_client.bucket_name = tos_bucket_name
coro = self._tos_client.upload(
object_key=object_key,
Expand Down Expand Up @@ -504,10 +521,21 @@ def _do_request(
) -> dict:
VIKINGDB_KNOWLEDGEBASE_BASE_URL = "api-knowledgebase.mlp.cn-beijing.volces.com"

volcengine_access_key = self.volcengine_access_key
volcengine_secret_key = self.volcengine_secret_key
session_token = self.session_token

if not (volcengine_access_key and volcengine_secret_key):
cred = get_credential_from_vefaas_iam()
volcengine_access_key = cred.access_key_id
volcengine_secret_key = cred.secret_access_key
session_token = cred.session_token

request = build_vikingdb_knowledgebase_request(
path=path,
volcengine_access_key=self.volcengine_access_key,
volcengine_secret_key=self.volcengine_secret_key,
volcengine_access_key=volcengine_access_key,
volcengine_secret_key=volcengine_secret_key,
session_token=session_token,
method=method,
data=body,
)
Expand Down