Skip to content

Commit 297d296

Browse files
feat(auth): support sts token auth (#236)
* feat(auth): support sts token auth * add auto token for vikingdb memory
1 parent 57c90dc commit 297d296

File tree

6 files changed

+137
-69
lines changed

6 files changed

+137
-69
lines changed

veadk/auth/veauth/ark_veauth.py

Lines changed: 43 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -14,64 +14,56 @@
1414

1515
import os
1616

17-
from typing_extensions import override
18-
19-
from veadk.auth.veauth.base_veauth import BaseVeAuth
17+
from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
2018
from veadk.utils.logger import get_logger
2119
from veadk.utils.volcengine_sign import ve_request
2220

2321
logger = get_logger(__name__)
2422

2523

26-
class ARKVeAuth(BaseVeAuth):
27-
def __init__(
28-
self,
29-
access_key: str = os.getenv("VOLCENGINE_ACCESS_KEY", ""),
30-
secret_key: str = os.getenv("VOLCENGINE_SECRET_KEY", ""),
31-
) -> None:
32-
super().__init__(access_key, secret_key)
24+
def get_ark_token(region: str = "cn-beijing") -> str:
25+
logger.info("Fetching ARK token...")
3326

34-
self._token: str = ""
27+
access_key = os.getenv("VOLCENGINE_ACCESS_KEY")
28+
secret_key = os.getenv("VOLCENGINE_SECRET_KEY")
29+
session_token = ""
3530

36-
@override
37-
def _fetch_token(self) -> None:
38-
logger.info("Fetching ARK token...")
39-
# list api keys
40-
first_api_key_id = ""
41-
res = ve_request(
42-
request_body={"ProjectName": "default", "Filter": {}},
43-
action="ListApiKeys",
44-
ak=self.access_key,
45-
sk=self.secret_key,
46-
service="ark",
47-
version="2024-01-01",
48-
region="cn-beijing",
49-
host="open.volcengineapi.com",
50-
)
51-
try:
52-
first_api_key_id = res["Result"]["Items"][0]["Id"]
53-
except KeyError:
54-
raise ValueError(f"Failed to get ARK api key list: {res}")
31+
if not (access_key and secret_key):
32+
# try to get from vefaas iam
33+
cred = get_credential_from_vefaas_iam()
34+
access_key = cred.access_key_id
35+
secret_key = cred.secret_access_key
36+
session_token = cred.session_token
5537

56-
# get raw api key
57-
res = ve_request(
58-
request_body={"Id": first_api_key_id},
59-
action="GetRawApiKey",
60-
ak=self.access_key,
61-
sk=self.secret_key,
62-
service="ark",
63-
version="2024-01-01",
64-
region="cn-beijing",
65-
host="open.volcengineapi.com",
66-
)
67-
try:
68-
self._token = res["Result"]["ApiKey"]
69-
except KeyError:
70-
raise ValueError(f"Failed to get ARK api key: {res}")
38+
res = ve_request(
39+
request_body={"ProjectName": "default", "Filter": {}},
40+
header={"X-Security-Token": session_token},
41+
action="ListApiKeys",
42+
ak=access_key,
43+
sk=secret_key,
44+
service="ark",
45+
version="2024-01-01",
46+
region=region,
47+
host="open.volcengineapi.com",
48+
)
49+
try:
50+
first_api_key_id = res["Result"]["Items"][0]["Id"]
51+
except KeyError:
52+
raise ValueError(f"Failed to get ARK api key list: {res}")
7153

72-
@property
73-
def token(self) -> str:
74-
if self._token:
75-
return self._token
76-
self._fetch_token()
77-
return self._token
54+
# get raw api key
55+
res = ve_request(
56+
request_body={"Id": first_api_key_id},
57+
header={"X-Security-Token": session_token},
58+
action="GetRawApiKey",
59+
ak=access_key,
60+
sk=secret_key,
61+
service="ark",
62+
version="2024-01-01",
63+
region=region,
64+
host="open.volcengineapi.com",
65+
)
66+
try:
67+
return res["Result"]["ApiKey"]
68+
except KeyError:
69+
raise ValueError(f"Failed to get ARK api key: {res}")

veadk/auth/veauth/utils.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import json
16+
from pathlib import Path
17+
18+
from pydantic import BaseModel
19+
20+
from veadk.consts import VEFAAS_IAM_CRIDENTIAL_PATH
21+
from veadk.utils.logger import get_logger
22+
23+
logger = get_logger(__name__)
24+
25+
26+
class VeIAMCredential(BaseModel):
27+
access_key_id: str
28+
secret_access_key: str
29+
session_token: str
30+
31+
32+
def get_credential_from_vefaas_iam() -> VeIAMCredential:
33+
"""Get credential from VeFaaS IAM file"""
34+
logger.info(
35+
f"Get Volcegnine access key or secret key from environment variables failed, try to get from VeFaaS IAM file (path={VEFAAS_IAM_CRIDENTIAL_PATH})."
36+
)
37+
38+
path = Path(VEFAAS_IAM_CRIDENTIAL_PATH)
39+
40+
if not path.exists():
41+
logger.error(
42+
f"Get Volcegnine access key or secret key from environment variables failed, and VeFaaS IAM file (path={VEFAAS_IAM_CRIDENTIAL_PATH}) not exists. Please check your configuration."
43+
)
44+
raise FileNotFoundError(
45+
f"Get Volcegnine access key or secret key from environment variables failed, and VeFaaS IAM file (path={VEFAAS_IAM_CRIDENTIAL_PATH}) not exists. Please check your configuration."
46+
)
47+
48+
with open(VEFAAS_IAM_CRIDENTIAL_PATH, "r") as f:
49+
cred_dict = json.load(f)
50+
access_key = cred_dict["access_key_id"]
51+
secret_key = cred_dict["secret_access_key"]
52+
session_token = cred_dict["session_token"]
53+
return VeIAMCredential(
54+
access_key_id=access_key,
55+
secret_access_key=secret_key,
56+
session_token=session_token,
57+
)

veadk/configs/model_configs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from pydantic_settings import BaseSettings, SettingsConfigDict
1919

20-
from veadk.auth.veauth.ark_veauth import ARKVeAuth
20+
from veadk.auth.veauth.ark_veauth import get_ark_token
2121
from veadk.consts import (
2222
DEFAULT_MODEL_AGENT_API_BASE,
2323
DEFAULT_MODEL_AGENT_NAME,
@@ -39,7 +39,7 @@ class ModelConfig(BaseSettings):
3939

4040
@cached_property
4141
def api_key(self) -> str:
42-
return os.getenv("MODEL_AGENT_API_KEY") or ARKVeAuth().token
42+
return os.getenv("MODEL_AGENT_API_KEY") or get_ark_token()
4343

4444

4545
class EmbeddingModelConfig(BaseSettings):
@@ -56,7 +56,7 @@ class EmbeddingModelConfig(BaseSettings):
5656

5757
@cached_property
5858
def api_key(self) -> str:
59-
return os.getenv("MODEL_EMBEDDING_API_KEY") or ARKVeAuth().token
59+
return os.getenv("MODEL_EMBEDDING_API_KEY") or get_ark_token()
6060

6161

6262
class NormalEmbeddingModelConfig(BaseSettings):

veadk/consts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,5 @@
6666
DEFAULT_IMAGE_EDIT_MODEL_NAME = "doubao-seededit-3-0-i2i-250628"
6767
DEFAULT_VIDEO_MODEL_NAME = "doubao-seedance-1-0-pro-250528"
6868
DEFAULT_IMAGE_GENERATE_MODEL_NAME = "doubao-seedream-4-0-250828"
69+
70+
VEFAAS_IAM_CRIDENTIAL_PATH = "/var/run/secrets/iam/credential"

veadk/memory/long_term_memory_backends/vikingdb_memory_backend.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import json
16+
import os
1617
import re
1718
import time
1819
import uuid
@@ -22,7 +23,7 @@
2223
from typing_extensions import override
2324

2425
import veadk.config # noqa E401
25-
from veadk.config import getenv
26+
from veadk.auth.veauth.utils import get_credential_from_vefaas_iam
2627
from veadk.integrations.ve_viking_db_memory.ve_viking_db_memory import (
2728
VikingDBMemoryClient,
2829
)
@@ -35,24 +36,20 @@
3536

3637

3738
class VikingDBLTMBackend(BaseLongTermMemoryBackend):
38-
volcengine_access_key: str = Field(
39-
default_factory=lambda: getenv("VOLCENGINE_ACCESS_KEY")
39+
volcengine_access_key: str | None = Field(
40+
default_factory=lambda: os.getenv("VOLCENGINE_ACCESS_KEY")
4041
)
4142

42-
volcengine_secret_key: str = Field(
43-
default_factory=lambda: getenv("VOLCENGINE_SECRET_KEY")
43+
volcengine_secret_key: str | None = Field(
44+
default_factory=lambda: os.getenv("VOLCENGINE_SECRET_KEY")
4445
)
4546

47+
session_token: str = ""
48+
4649
region: str = "cn-beijing"
4750
"""VikingDB memory region"""
4851

4952
def model_post_init(self, __context: Any) -> None:
50-
self._client = VikingDBMemoryClient(
51-
ak=self.volcengine_access_key,
52-
sk=self.volcengine_secret_key,
53-
region=self.region,
54-
)
55-
5653
# check whether collection exist, if not, create it
5754
if not self._collection_exist():
5855
self._create_collection()
@@ -69,19 +66,35 @@ def precheck_index_naming(self):
6966

7067
def _collection_exist(self) -> bool:
7168
try:
72-
self._client.get_collection(collection_name=self.index)
69+
client = self._get_client()
70+
client.get_collection(collection_name=self.index)
7371
return True
7472
except Exception:
7573
return False
7674

7775
def _create_collection(self) -> None:
78-
response = self._client.create_collection(
76+
client = self._get_client()
77+
response = client.create_collection(
7978
collection_name=self.index,
8079
description="Created by Volcengine Agent Development Kit VeADK",
8180
builtin_event_types=["sys_event_v1"],
8281
)
8382
return response
8483

84+
def _get_client(self) -> VikingDBMemoryClient:
85+
if not (self.volcengine_access_key and self.volcengine_secret_key):
86+
cred = get_credential_from_vefaas_iam()
87+
self.volcengine_access_key = cred.access_key_id
88+
self.volcengine_secret_key = cred.secret_access_key
89+
self.session_token = cred.session_token
90+
91+
return VikingDBMemoryClient(
92+
ak=self.volcengine_access_key,
93+
sk=self.volcengine_secret_key,
94+
sts_token=self.session_token,
95+
region=self.region,
96+
)
97+
8598
@override
8699
def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
87100
session_id = str(uuid.uuid1())
@@ -103,7 +116,8 @@ def save_memory(self, user_id: str, event_strings: list[str], **kwargs) -> bool:
103116
f"Request for add {len(messages)} memory to VikingDB: collection_name={self.index}, metadata={metadata}, session_id={session_id}"
104117
)
105118

106-
response = self._client.add_messages(
119+
client = self._get_client()
120+
response = client.add_messages(
107121
collection_name=self.index,
108122
messages=messages,
109123
metadata=metadata,
@@ -130,7 +144,8 @@ def search_memory(
130144
f"Request for search memory in VikingDB: filter={filter}, collection_name={self.index}, query={query}, limit={top_k}"
131145
)
132146

133-
response = self._client.search_memory(
147+
client = self._get_client()
148+
response = client.search_memory(
134149
collection_name=self.index, query=query, filter=filter, limit=top_k
135150
)
136151

veadk/utils/volcengine_sign.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ def request(method, date, query, header, ak, sk, action, body):
144144
)
145145
)
146146
header = {**header, **sign_result}
147+
if "X-Security-Token" in header and header["X-Security-Token"] == "":
148+
del header["X-Security-Token"]
147149
# header = {**header, **{"X-Security-Token": SessionToken}}
148150
# 第六步:将 Signature 签名写入 HTTP Header 中,并发送 HTTP 请求。
149151
r = requests.request(

0 commit comments

Comments
 (0)